diff options
Diffstat (limited to 'src')
5 files changed, 21 insertions, 17 deletions
diff --git a/src/backends/tosaCommon/CMakeLists.txt b/src/backends/tosaCommon/CMakeLists.txt index 61434edc96..83737d3bd3 100644 --- a/src/backends/tosaCommon/CMakeLists.txt +++ b/src/backends/tosaCommon/CMakeLists.txt @@ -4,6 +4,7 @@ # include_directories(SYSTEM ${FLATBUFFERS_INCLUDE_PATH}) +include_directories(SYSTEM ${PROJECT_SOURCE_DIR}/third-party/half) include_directories(SYSTEM ${TOSA_SERIALIZATION_LIB_INCLUDE}) list(APPEND armnnTosaBackend_sources diff --git a/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp b/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp index b887721648..e11f293b12 100644 --- a/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp +++ b/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp @@ -19,9 +19,10 @@ DType ArmNNToDType(const DataType& type) switch (type) { case DataType::Float16: - case DataType::Float32: case DataType::BFloat16: - return DType_FLOAT; + return DType_FP16; + case DataType::Float32: + return DType_FP32; case DataType::QAsymmU8: return DType_UINT8; case DataType::QSymmS8: diff --git a/src/backends/tosaCommon/test/TosaOperatorMappingTests.cpp b/src/backends/tosaCommon/test/TosaOperatorMappingTests.cpp index a2949d61ac..f4435bdf42 100644 --- a/src/backends/tosaCommon/test/TosaOperatorMappingTests.cpp +++ b/src/backends/tosaCommon/test/TosaOperatorMappingTests.cpp @@ -18,7 +18,7 @@ void AssertTosaOneToOneMappingBasicBlock(TosaSerializationBasicBlock* basicBlock uint32_t numOutputs, Op tosaOp, std::string operatorString, - DType dataType = DType_FLOAT) + DType dataType = DType_FP32) { std::string blockStr = operatorString + "_block_"; CHECK(basicBlock->GetName().find(blockStr) != std::string::npos); diff --git a/src/backends/tosaReference/CMakeLists.txt b/src/backends/tosaReference/CMakeLists.txt index fdec6d1106..c7de117fdd 100644 --- a/src/backends/tosaReference/CMakeLists.txt +++ b/src/backends/tosaReference/CMakeLists.txt @@ -4,6 +4,7 @@ # include_directories(SYSTEM ${FLATBUFFERS_INCLUDE_PATH}) +include_directories(SYSTEM ${PROJECT_SOURCE_DIR}/third-party/half) include_directories(SYSTEM ${TOSA_SERIALIZATION_LIB_INCLUDE}) include_directories(SYSTEM ${TOSA_REFERENCE_MODEL_INCLUDE}) diff --git a/src/backends/tosaReference/TosaRefLayerSupport.cpp b/src/backends/tosaReference/TosaRefLayerSupport.cpp index 80e982f1c4..18530bb535 100644 --- a/src/backends/tosaReference/TosaRefLayerSupport.cpp +++ b/src/backends/tosaReference/TosaRefLayerSupport.cpp @@ -29,25 +29,26 @@ static bool IsTosaLayerSupported(TosaSerializationOperator* op, bool supported = true; std::array<Attribute, 1> supportedAttributes = - { - Attribute_NONE - }; + { + Attribute_NONE + }; // Check Attribute from operator (GetAttribute) supported &= CheckSupportRule(TosaOperatorAttributeOfAny(op, supportedAttributes), reasonIfUnsupported, std::string("TOSA Reference addition: operator has an unsupported attribute.").c_str()); - std::array<DType, 8> supportedTypes = - { - DType_BOOL, - DType_UINT8, - DType_INT4, - DType_INT8, - DType_INT16, - DType_INT32, - DType_FLOAT, - DType_UINT16 - }; + std::array<DType, 9> supportedTypes = + { + DType_BOOL, + DType_UINT8, + DType_UINT16, + DType_INT4, + DType_INT8, + DType_INT16, + DType_INT32, + DType_FP16, + DType_FP32 + }; for (auto tensor : inputs) { |