diff options
-rw-r--r-- | include/armnn/TypesUtils.hpp | 1 | ||||
-rw-r--r-- | src/armnn/TypeUtils.hpp | 2 | ||||
-rw-r--r-- | src/backends/aclCommon/ArmComputeTensorUtils.cpp | 2 | ||||
-rw-r--r-- | src/backends/cl/ClTensorHandle.hpp | 6 |
4 files changed, 9 insertions, 2 deletions
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp index 3ed1dfbcb5..bb75b18c32 100644 --- a/include/armnn/TypesUtils.hpp +++ b/include/armnn/TypesUtils.hpp @@ -129,6 +129,7 @@ constexpr const char* GetDataTypeName(DataType dataType) case DataType::Float32: return "Float32"; case DataType::QuantisedAsymm8: return "Unsigned8"; case DataType::Signed32: return "Signed32"; + case DataType::Boolean: return "Boolean"; default: return "Unknown"; diff --git a/src/armnn/TypeUtils.hpp b/src/armnn/TypeUtils.hpp index 5bb040f780..f7d0e077c8 100644 --- a/src/armnn/TypeUtils.hpp +++ b/src/armnn/TypeUtils.hpp @@ -41,7 +41,7 @@ struct ResolveTypeImpl<DataType::Signed32> template<> struct ResolveTypeImpl<DataType::Boolean> { - using Type = bool; + using Type = uint8_t; }; template<DataType DT> diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp index 32af42f7e1..4f69c0b7db 100644 --- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp +++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp @@ -25,6 +25,8 @@ arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType) return arm_compute::DataType::QASYMM8; case armnn::DataType::Signed32: return arm_compute::DataType::S32; + case armnn::DataType::Boolean: + return arm_compute::DataType::U8; default: BOOST_ASSERT_MSG(false, "Unknown data type"); return arm_compute::DataType::UNKNOWN; diff --git a/src/backends/cl/ClTensorHandle.hpp b/src/backends/cl/ClTensorHandle.hpp index f791ee8fc9..59a6bee7f5 100644 --- a/src/backends/cl/ClTensorHandle.hpp +++ b/src/backends/cl/ClTensorHandle.hpp @@ -94,6 +94,7 @@ private: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast<float*>(memory)); break; + case arm_compute::DataType::U8: case arm_compute::DataType::QASYMM8: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast<uint8_t*>(memory)); @@ -120,6 +121,7 @@ private: armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), this->GetTensor()); break; + case arm_compute::DataType::U8: case arm_compute::DataType::QASYMM8: armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), this->GetTensor()); @@ -194,6 +196,7 @@ private: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast<float*>(memory)); break; + case arm_compute::DataType::U8: case arm_compute::DataType::QASYMM8: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast<uint8_t*>(memory)); @@ -220,6 +223,7 @@ private: armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), this->GetTensor()); break; + case arm_compute::DataType::U8: case arm_compute::DataType::QASYMM8: armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), this->GetTensor()); @@ -240,4 +244,4 @@ private: ITensorHandle* parentHandle = nullptr; }; -} // namespace armnn
\ No newline at end of file +} // namespace armnn |