aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/backends/tosaCommon/CMakeLists.txt1
-rw-r--r--src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp5
-rw-r--r--src/backends/tosaCommon/test/TosaOperatorMappingTests.cpp2
-rw-r--r--src/backends/tosaReference/CMakeLists.txt1
-rw-r--r--src/backends/tosaReference/TosaRefLayerSupport.cpp29
5 files changed, 21 insertions, 17 deletions
diff --git a/src/backends/tosaCommon/CMakeLists.txt b/src/backends/tosaCommon/CMakeLists.txt
index 61434edc96..83737d3bd3 100644
--- a/src/backends/tosaCommon/CMakeLists.txt
+++ b/src/backends/tosaCommon/CMakeLists.txt
@@ -4,6 +4,7 @@
#
include_directories(SYSTEM ${FLATBUFFERS_INCLUDE_PATH})
+include_directories(SYSTEM ${PROJECT_SOURCE_DIR}/third-party/half)
include_directories(SYSTEM ${TOSA_SERIALIZATION_LIB_INCLUDE})
list(APPEND armnnTosaBackend_sources
diff --git a/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp b/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp
index b887721648..e11f293b12 100644
--- a/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp
+++ b/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp
@@ -19,9 +19,10 @@ DType ArmNNToDType(const DataType& type)
switch (type)
{
case DataType::Float16:
- case DataType::Float32:
case DataType::BFloat16:
- return DType_FLOAT;
+ return DType_FP16;
+ case DataType::Float32:
+ return DType_FP32;
case DataType::QAsymmU8:
return DType_UINT8;
case DataType::QSymmS8:
diff --git a/src/backends/tosaCommon/test/TosaOperatorMappingTests.cpp b/src/backends/tosaCommon/test/TosaOperatorMappingTests.cpp
index a2949d61ac..f4435bdf42 100644
--- a/src/backends/tosaCommon/test/TosaOperatorMappingTests.cpp
+++ b/src/backends/tosaCommon/test/TosaOperatorMappingTests.cpp
@@ -18,7 +18,7 @@ void AssertTosaOneToOneMappingBasicBlock(TosaSerializationBasicBlock* basicBlock
uint32_t numOutputs,
Op tosaOp,
std::string operatorString,
- DType dataType = DType_FLOAT)
+ DType dataType = DType_FP32)
{
std::string blockStr = operatorString + "_block_";
CHECK(basicBlock->GetName().find(blockStr) != std::string::npos);
diff --git a/src/backends/tosaReference/CMakeLists.txt b/src/backends/tosaReference/CMakeLists.txt
index fdec6d1106..c7de117fdd 100644
--- a/src/backends/tosaReference/CMakeLists.txt
+++ b/src/backends/tosaReference/CMakeLists.txt
@@ -4,6 +4,7 @@
#
include_directories(SYSTEM ${FLATBUFFERS_INCLUDE_PATH})
+include_directories(SYSTEM ${PROJECT_SOURCE_DIR}/third-party/half)
include_directories(SYSTEM ${TOSA_SERIALIZATION_LIB_INCLUDE})
include_directories(SYSTEM ${TOSA_REFERENCE_MODEL_INCLUDE})
diff --git a/src/backends/tosaReference/TosaRefLayerSupport.cpp b/src/backends/tosaReference/TosaRefLayerSupport.cpp
index 80e982f1c4..18530bb535 100644
--- a/src/backends/tosaReference/TosaRefLayerSupport.cpp
+++ b/src/backends/tosaReference/TosaRefLayerSupport.cpp
@@ -29,25 +29,26 @@ static bool IsTosaLayerSupported(TosaSerializationOperator* op,
bool supported = true;
std::array<Attribute, 1> supportedAttributes =
- {
- Attribute_NONE
- };
+ {
+ Attribute_NONE
+ };
// Check Attribute from operator (GetAttribute)
supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported,
std::string("TOSA Reference addition: operator has an unsupported attribute.").c_str());
- std::array<DType, 8> supportedTypes =
- {
- DType_BOOL,
- DType_UINT8,
- DType_INT4,
- DType_INT8,
- DType_INT16,
- DType_INT32,
- DType_FLOAT,
- DType_UINT16
- };
+ std::array<DType, 9> supportedTypes =
+ {
+ DType_BOOL,
+ DType_UINT8,
+ DType_UINT16,
+ DType_INT4,
+ DType_INT8,
+ DType_INT16,
+ DType_INT32,
+ DType_FP16,
+ DType_FP32
+ };
for (auto tensor : inputs)
{