diff options
Diffstat (limited to 'src/backends/aclCommon/ArmComputeTensorUtils.cpp')
-rw-r--r-- | src/backends/aclCommon/ArmComputeTensorUtils.cpp | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp index 9ed7b7b437..2dc6d2a2b2 100644 --- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp +++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp @@ -45,6 +45,38 @@ arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType, bool multi } } +armnn::DataType GetArmNNDataType(arm_compute::DataType dataType) +{ + switch(dataType) + { + case arm_compute::DataType::BFLOAT16: + return armnn::DataType::BFloat16; + case arm_compute::DataType::U8: + return armnn::DataType::Boolean; + case arm_compute::DataType::F16: + return armnn::DataType::Float16; + case arm_compute::DataType::F32: + return armnn::DataType::Float32; + case arm_compute::DataType::QASYMM8_SIGNED: + return armnn::DataType::QAsymmS8; + case arm_compute::DataType::QASYMM8: + return armnn::DataType::QAsymmU8; + case arm_compute::DataType::QSYMM16: + return armnn::DataType::QSymmS16; + case arm_compute::DataType::S64: + return armnn::DataType::Signed64; + case arm_compute::DataType::QSYMM8_PER_CHANNEL: + return armnn::DataType::QSymmS8; + case arm_compute::DataType::QSYMM8: + return armnn::DataType::QSymmS8; + case arm_compute::DataType::S32: + return armnn::DataType::Signed32; + default: + ARMNN_ASSERT_MSG(false, "Unknown data type"); + return armnn::DataType::Float32; + } +} + arm_compute::Coordinates BuildArmComputeReductionCoordinates(size_t inputDimensions, unsigned int originalInputRank, const std::vector<unsigned int>& armnnAxes) |