aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/CommonTestUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/test/CommonTestUtils.hpp')
-rw-r--r--src/backends/backendsCommon/test/CommonTestUtils.hpp19
1 files changed, 19 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/test/CommonTestUtils.hpp b/src/backends/backendsCommon/test/CommonTestUtils.hpp
index e96edc8317..8c4da621ed 100644
--- a/src/backends/backendsCommon/test/CommonTestUtils.hpp
+++ b/src/backends/backendsCommon/test/CommonTestUtils.hpp
@@ -8,9 +8,11 @@
#include <Graph.hpp>
#include <SubgraphView.hpp>
#include <SubgraphViewSelector.hpp>
+#include <ResolveType.hpp>
#include <armnn/BackendRegistry.hpp>
+#include <armnn/Types.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>
#include <test/TestUtils.hpp>
@@ -50,6 +52,23 @@ bool Contains(const MapType& map, const typename MapType::key_type& key)
return map.find(key) != map.end();
}
+// Utility template for comparing tensor elements
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+bool Compare(T a, T b, float tolerance = 0.000001f)
+{
+ if (ArmnnType == armnn::DataType::Boolean)
+ {
+ // NOTE: Boolean is represented as uint8_t (with zero equals
+ // false and everything else equals true), therefore values
+ // need to be casted to bool before comparing them
+ return static_cast<bool>(a) == static_cast<bool>(b);
+ }
+
+ // NOTE: All other types can be cast to float and compared with
+ // a certain level of tolerance
+ return std::fabs(static_cast<float>(a) - static_cast<float>(b)) <= tolerance;
+}
+
template <typename ConvolutionLayer>
void SetWeightAndBias(ConvolutionLayer* layer, const armnn::TensorInfo& weightInfo, const armnn::TensorInfo& biasInfo)
{