diff options
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/CompatibleTypes.hpp | 44 | ||||
-rw-r--r-- | src/armnn/test/UtilsTests.cpp | 11 |
2 files changed, 44 insertions, 11 deletions
diff --git a/src/armnn/CompatibleTypes.hpp b/src/armnn/CompatibleTypes.hpp new file mode 100644 index 0000000000..2449876544 --- /dev/null +++ b/src/armnn/CompatibleTypes.hpp @@ -0,0 +1,44 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "armnn/Types.hpp" +#include "Half.hpp" + +namespace armnn +{ + +template<typename T> +bool CompatibleTypes(DataType dataType) +{ + return false; +} + +template<> +inline bool CompatibleTypes<float>(DataType dataType) +{ + return dataType == DataType::Float32; +} + +template<> +inline bool CompatibleTypes<Half>(DataType dataType) +{ + return dataType == DataType::Float16; +} + +template<> +inline bool CompatibleTypes<uint8_t>(DataType dataType) +{ + return dataType == DataType::Boolean || dataType == DataType::QuantisedAsymm8; +} + +template<> +inline bool CompatibleTypes<int32_t>(DataType dataType) +{ + return dataType == DataType::Signed32; +} + +} //namespace armnn diff --git a/src/armnn/test/UtilsTests.cpp b/src/armnn/test/UtilsTests.cpp index 9933137edc..c81a4b67b6 100644 --- a/src/armnn/test/UtilsTests.cpp +++ b/src/armnn/test/UtilsTests.cpp @@ -23,14 +23,6 @@ BOOST_AUTO_TEST_CASE(DataTypeSize) BOOST_TEST(armnn::GetDataTypeSize(armnn::DataType::Boolean) == 1); } -BOOST_AUTO_TEST_CASE(GetDataTypeTest) -{ - BOOST_TEST((armnn::GetDataType<float>() == armnn::DataType::Float32)); - BOOST_TEST((armnn::GetDataType<uint8_t>() == armnn::DataType::QuantisedAsymm8)); - BOOST_TEST((armnn::GetDataType<int32_t>() == armnn::DataType::Signed32)); - BOOST_TEST((armnn::GetDataType<bool>() == armnn::DataType::Boolean)); -} - BOOST_AUTO_TEST_CASE(PermuteDescriptorWithTooManyMappings) { BOOST_CHECK_THROW(armnn::PermuteDescriptor({ 0u, 1u, 2u, 3u, 4u }), armnn::InvalidArgumentException); @@ -81,9 +73,6 @@ BOOST_AUTO_TEST_CASE(HalfType) constexpr bool isHalfType = std::is_same<armnn::Half, ResolvedType>::value; BOOST_CHECK(isHalfType); - armnn::DataType dt = armnn::GetDataType<armnn::Half>(); - BOOST_CHECK(dt == armnn::DataType::Float16); - //Test utility functions return correct size BOOST_CHECK(GetDataTypeSize(armnn::DataType::Float16) == 2); |