aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/ArithmeticTestImpl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/test/ArithmeticTestImpl.hpp')
-rw-r--r--src/backends/backendsCommon/test/ArithmeticTestImpl.hpp10
1 files changed, 6 insertions, 4 deletions
diff --git a/src/backends/backendsCommon/test/ArithmeticTestImpl.hpp b/src/backends/backendsCommon/test/ArithmeticTestImpl.hpp
index f70bf48ca9..1d6cf1d99b 100644
--- a/src/backends/backendsCommon/test/ArithmeticTestImpl.hpp
+++ b/src/backends/backendsCommon/test/ArithmeticTestImpl.hpp
@@ -4,6 +4,8 @@
//
#pragma once
+#include "TypeUtils.hpp"
+
#include <armnn/INetwork.hpp>
#include <backendsCommon/test/CommonTestUtils.hpp>
@@ -49,7 +51,7 @@ INetworkPtr CreateArithmeticNetwork(const std::vector<TensorShape>& inputShapes,
return net;
}
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
void ArithmeticSimpleEndToEnd(const std::vector<BackendId>& backends,
const LayerType type,
const std::vector<T> expectedOutput)
@@ -60,7 +62,7 @@ void ArithmeticSimpleEndToEnd(const std::vector<BackendId>& backends,
const TensorShape& outputShape = { 2, 2, 2, 2 };
// Builds up the structure of the network
- INetworkPtr net = CreateArithmeticNetwork<GetDataType<T>()>(inputShapes, outputShape, type);
+ INetworkPtr net = CreateArithmeticNetwork<ArmnnType>(inputShapes, outputShape, type);
BOOST_TEST_CHECKPOINT("create a network");
@@ -76,7 +78,7 @@ void ArithmeticSimpleEndToEnd(const std::vector<BackendId>& backends,
EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends);
}
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
void ArithmeticBroadcastEndToEnd(const std::vector<BackendId>& backends,
const LayerType type,
const std::vector<T> expectedOutput)
@@ -87,7 +89,7 @@ void ArithmeticBroadcastEndToEnd(const std::vector<BackendId>& backends,
const TensorShape& outputShape = { 1, 2, 2, 3 };
// Builds up the structure of the network
- INetworkPtr net = CreateArithmeticNetwork<GetDataType<T>()>(inputShapes, outputShape, type);
+ INetworkPtr net = CreateArithmeticNetwork<ArmnnType>(inputShapes, outputShape, type);
BOOST_TEST_CHECKPOINT("create a network");