diff options
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/LayerSupportCommon.hpp | 6 | ||||
-rw-r--r-- | src/armnn/test/TensorHelpers.hpp | 21 | ||||
-rw-r--r-- | src/armnn/test/UnitTests.hpp | 2 |
3 files changed, 24 insertions, 5 deletions
diff --git a/src/armnn/LayerSupportCommon.hpp b/src/armnn/LayerSupportCommon.hpp index 109728cd81..70b5f182f4 100644 --- a/src/armnn/LayerSupportCommon.hpp +++ b/src/armnn/LayerSupportCommon.hpp @@ -12,13 +12,15 @@ namespace armnn { -template<typename Float16Func, typename Float32Func, typename Uint8Func, typename Int32Func, typename ... Params> +template<typename Float16Func, typename Float32Func, typename Uint8Func, typename Int32Func, typename BooleanFunc, + typename ... Params> bool IsSupportedForDataTypeGeneric(Optional<std::string&> reasonIfUnsupported, DataType dataType, Float16Func float16FuncPtr, Float32Func float32FuncPtr, Uint8Func uint8FuncPtr, Int32Func int32FuncPtr, + BooleanFunc booleanFuncPtr, Params&&... params) { switch(dataType) @@ -31,6 +33,8 @@ bool IsSupportedForDataTypeGeneric(Optional<std::string&> reasonIfUnsupported, return uint8FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...); case DataType::Signed32: return int32FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...); + case DataType::Boolean: + return booleanFuncPtr(reasonIfUnsupported, std::forward<Params>(params)...); default: return false; } diff --git a/src/armnn/test/TensorHelpers.hpp b/src/armnn/test/TensorHelpers.hpp index 06818d3918..fcaa0772a0 100644 --- a/src/armnn/test/TensorHelpers.hpp +++ b/src/armnn/test/TensorHelpers.hpp @@ -67,11 +67,16 @@ bool SelectiveCompare(T a, T b) return SelectiveComparer<T, armnn::IsQuantizedType<T>()>::Compare(a, b); }; - +template<typename T> +bool SelectiveCompareBoolean(T a, T b) +{ + return (((a == 0) && (b == 0)) || ((a != 0) && (b != 0))); +}; template <typename T, std::size_t n> boost::test_tools::predicate_result CompareTensors(const boost::multi_array<T, n>& a, - const boost::multi_array<T, n>& b) + const boost::multi_array<T, n>& b, + bool compareBoolean = false) { // Checks they are same shape. for (unsigned int i=0; i<n; i++) @@ -103,7 +108,17 @@ boost::test_tools::predicate_result CompareTensors(const boost::multi_array<T, n while (true) { - bool comparison = SelectiveCompare(a(indices), b(indices)); + bool comparison; + // As true for uint8_t is non-zero (1-255) we must have a dedicated compare for Booleans. + if(compareBoolean) + { + comparison = SelectiveCompareBoolean(a(indices), b(indices)); + } + else + { + comparison = SelectiveCompare(a(indices), b(indices)); + } + if (!comparison) { ++numFailedElements; diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp index f489ca030c..04e91ad85e 100644 --- a/src/armnn/test/UnitTests.hpp +++ b/src/armnn/test/UnitTests.hpp @@ -41,7 +41,7 @@ void CompareTestResultIfSupported(const std::string& testName, const LayerTestRe "The test name does not match the supportedness it is reporting"); if (testResult.supported) { - BOOST_TEST(CompareTensors(testResult.output, testResult.outputExpected)); + BOOST_TEST(CompareTensors(testResult.output, testResult.outputExpected, testResult.compareBoolean)); } } |