From 1342bed114effa8b183ba683189117cd730ab635 Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Wed, 17 Nov 2021 15:01:00 +0000 Subject: IVGCVSW-6452 'Move CompatibleTypes.hpp to the armnnUtils library' * Moved CompatibleTypes.hpp to include folder * Added implementation file to source CompatibleTypes.cpp Signed-off-by: Sadik Armagan Change-Id: I94d2bffdb82a0592943f497d4f57972151d9f2db --- src/armnn/CompatibleTypes.hpp | 65 ---------------------------- src/armnnUtils/CompatibleTypes.cpp | 65 ++++++++++++++++++++++++++++ src/backends/backendsCommon/TensorHandle.hpp | 22 +++++++--- 3 files changed, 82 insertions(+), 70 deletions(-) delete mode 100644 src/armnn/CompatibleTypes.hpp create mode 100644 src/armnnUtils/CompatibleTypes.cpp (limited to 'src') diff --git a/src/armnn/CompatibleTypes.hpp b/src/armnn/CompatibleTypes.hpp deleted file mode 100644 index e24d5dfc4c..0000000000 --- a/src/armnn/CompatibleTypes.hpp +++ /dev/null @@ -1,65 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include - -#include -#include - -namespace armnn -{ - -template -bool CompatibleTypes(DataType) -{ - return false; -} - -template<> -inline bool CompatibleTypes(DataType dataType) -{ - return dataType == DataType::Float32; -} - -template<> -inline bool CompatibleTypes(DataType dataType) -{ - return dataType == DataType::Float16; -} - -template<> -inline bool CompatibleTypes(DataType dataType) -{ - return dataType == DataType::BFloat16; -} - -template<> -inline bool CompatibleTypes(DataType dataType) -{ - return dataType == DataType::Boolean || dataType == DataType::QAsymmU8; -} - -template<> -inline bool CompatibleTypes(DataType dataType) -{ - return dataType == DataType::QSymmS8 - || dataType == DataType::QAsymmS8; -} - -template<> -inline bool CompatibleTypes(DataType dataType) -{ - return dataType == DataType::QSymmS16; -} - -template<> -inline bool CompatibleTypes(DataType dataType) -{ - return dataType == DataType::Signed32; -} - -} //namespace armnn diff --git a/src/armnnUtils/CompatibleTypes.cpp b/src/armnnUtils/CompatibleTypes.cpp new file mode 100644 index 0000000000..9a3251d293 --- /dev/null +++ b/src/armnnUtils/CompatibleTypes.cpp @@ -0,0 +1,65 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// +#include +#include + +#include "BFloat16.hpp" +#include "Half.hpp" + +using namespace armnn; + +namespace armnnUtils +{ + +template +bool CompatibleTypes(DataType) +{ + return false; +} + +template<> +bool CompatibleTypes(DataType dataType) +{ + return dataType == DataType::Float32; +} + +template<> +bool CompatibleTypes(DataType dataType) +{ + return dataType == DataType::Float16; +} + +template<> +bool CompatibleTypes(DataType dataType) +{ + return dataType == DataType::BFloat16; +} + +template<> +bool CompatibleTypes(DataType dataType) +{ + return dataType == DataType::Boolean || dataType == DataType::QAsymmU8; +} + +template<> +bool CompatibleTypes(DataType dataType) +{ + return dataType == DataType::QSymmS8 + || dataType == DataType::QAsymmS8; +} + +template<> +bool CompatibleTypes(DataType dataType) +{ + return dataType == DataType::QSymmS16; +} + +template<> +bool CompatibleTypes(DataType dataType) +{ + return dataType == DataType::Signed32; +} + +} //namespace armnnUtils diff --git a/src/backends/backendsCommon/TensorHandle.hpp b/src/backends/backendsCommon/TensorHandle.hpp index b898bd11a5..ba1fc16378 100644 --- a/src/backends/backendsCommon/TensorHandle.hpp +++ b/src/backends/backendsCommon/TensorHandle.hpp @@ -10,7 +10,7 @@ #include -#include +#include #include @@ -30,8 +30,14 @@ public: template const T* GetConstTensor() const { - ARMNN_ASSERT(CompatibleTypes(GetTensorInfo().GetDataType())); - return reinterpret_cast(m_Memory); + if (armnnUtils::CompatibleTypes(GetTensorInfo().GetDataType())) + { + return reinterpret_cast(m_Memory); + } + else + { + throw armnn::Exception("Attempting to get not compatible type tensor!"); + } } const TensorInfo& GetTensorInfo() const @@ -79,8 +85,14 @@ public: template T* GetTensor() const { - ARMNN_ASSERT(CompatibleTypes(GetTensorInfo().GetDataType())); - return reinterpret_cast(m_MutableMemory); + if (armnnUtils::CompatibleTypes(GetTensorInfo().GetDataType())) + { + return reinterpret_cast(m_MutableMemory); + } + else + { + throw armnn::Exception("Attempting to get not compatible type tensor!"); + } } protected: -- cgit v1.2.1