aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/workloads/NeonConstantWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon/workloads/NeonConstantWorkload.cpp')
-rw-r--r--src/backends/neon/workloads/NeonConstantWorkload.cpp42
1 files changed, 42 insertions, 0 deletions
diff --git a/src/backends/neon/workloads/NeonConstantWorkload.cpp b/src/backends/neon/workloads/NeonConstantWorkload.cpp
index 1cffbe1448..f7c8a73f78 100644
--- a/src/backends/neon/workloads/NeonConstantWorkload.cpp
+++ b/src/backends/neon/workloads/NeonConstantWorkload.cpp
@@ -19,6 +19,32 @@
namespace armnn
{
+arm_compute::Status NeonConstantWorkloadValidate(const TensorInfo& output)
+{
+ const arm_compute::TensorInfo neonOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+
+ std::array<arm_compute::DataType,8> supportedTypes = {
+ arm_compute::DataType::BFLOAT16,
+ arm_compute::DataType::F16,
+ arm_compute::DataType::F32,
+ arm_compute::DataType::QASYMM8,
+ arm_compute::DataType::QASYMM8_SIGNED,
+ arm_compute::DataType::QSYMM16,
+ arm_compute::DataType::QSYMM8,
+ arm_compute::DataType::QSYMM8_PER_CHANNEL
+ };
+ auto it = std::find(begin(supportedTypes), end(supportedTypes), neonOutputInfo.data_type());
+
+ if (it != end(supportedTypes))
+ {
+ return arm_compute::Status{};
+ }
+ else
+ {
+ return arm_compute::Status{arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported DataType"};
+ }
+}
+
NeonConstantWorkload::NeonConstantWorkload(const ConstantQueueDescriptor& descriptor,
const WorkloadInfo& info)
: BaseWorkload<ConstantQueueDescriptor>(descriptor, info)
@@ -68,6 +94,22 @@ void NeonConstantWorkload::Execute() const
CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<uint8_t>(), output);
break;
}
+ case arm_compute::DataType::QASYMM8_SIGNED:
+ {
+ CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<int8_t>(), output);
+ break;
+ }
+ case arm_compute::DataType::QSYMM16:
+ {
+ CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<int16_t>(), output);
+ break;
+ }
+ case arm_compute::DataType::QSYMM8:
+ case arm_compute::DataType::QSYMM8_PER_CHANNEL:
+ {
+ CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<int8_t>(), output);
+ break;
+ }
default:
{
ARMNN_ASSERT_MSG(false, "Unknown data type");