aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/armnn/TypesUtils.hpp1
-rw-r--r--src/armnn/TypeUtils.hpp2
-rw-r--r--src/backends/aclCommon/ArmComputeTensorUtils.cpp2
-rw-r--r--src/backends/cl/ClTensorHandle.hpp6
4 files changed, 9 insertions, 2 deletions
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index 3ed1dfbcb5..bb75b18c32 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -129,6 +129,7 @@ constexpr const char* GetDataTypeName(DataType dataType)
case DataType::Float32: return "Float32";
case DataType::QuantisedAsymm8: return "Unsigned8";
case DataType::Signed32: return "Signed32";
+ case DataType::Boolean: return "Boolean";
default:
return "Unknown";
diff --git a/src/armnn/TypeUtils.hpp b/src/armnn/TypeUtils.hpp
index 5bb040f780..f7d0e077c8 100644
--- a/src/armnn/TypeUtils.hpp
+++ b/src/armnn/TypeUtils.hpp
@@ -41,7 +41,7 @@ struct ResolveTypeImpl<DataType::Signed32>
template<>
struct ResolveTypeImpl<DataType::Boolean>
{
- using Type = bool;
+ using Type = uint8_t;
};
template<DataType DT>
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
index 32af42f7e1..4f69c0b7db 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
@@ -25,6 +25,8 @@ arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
return arm_compute::DataType::QASYMM8;
case armnn::DataType::Signed32:
return arm_compute::DataType::S32;
+ case armnn::DataType::Boolean:
+ return arm_compute::DataType::U8;
default:
BOOST_ASSERT_MSG(false, "Unknown data type");
return arm_compute::DataType::UNKNOWN;
diff --git a/src/backends/cl/ClTensorHandle.hpp b/src/backends/cl/ClTensorHandle.hpp
index f791ee8fc9..59a6bee7f5 100644
--- a/src/backends/cl/ClTensorHandle.hpp
+++ b/src/backends/cl/ClTensorHandle.hpp
@@ -94,6 +94,7 @@ private:
armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
static_cast<float*>(memory));
break;
+ case arm_compute::DataType::U8:
case arm_compute::DataType::QASYMM8:
armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
static_cast<uint8_t*>(memory));
@@ -120,6 +121,7 @@ private:
armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
this->GetTensor());
break;
+ case arm_compute::DataType::U8:
case arm_compute::DataType::QASYMM8:
armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
this->GetTensor());
@@ -194,6 +196,7 @@ private:
armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
static_cast<float*>(memory));
break;
+ case arm_compute::DataType::U8:
case arm_compute::DataType::QASYMM8:
armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
static_cast<uint8_t*>(memory));
@@ -220,6 +223,7 @@ private:
armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
this->GetTensor());
break;
+ case arm_compute::DataType::U8:
case arm_compute::DataType::QASYMM8:
armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
this->GetTensor());
@@ -240,4 +244,4 @@ private:
ITensorHandle* parentHandle = nullptr;
};
-} // namespace armnn \ No newline at end of file
+} // namespace armnn