diff options
Diffstat (limited to 'src/backends/backendsCommon/test/MergerTestImpl.hpp')
-rw-r--r-- | src/backends/backendsCommon/test/MergerTestImpl.hpp | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/src/backends/backendsCommon/test/MergerTestImpl.hpp b/src/backends/backendsCommon/test/MergerTestImpl.hpp index e0b8233336..2bdfe286c9 100644 --- a/src/backends/backendsCommon/test/MergerTestImpl.hpp +++ b/src/backends/backendsCommon/test/MergerTestImpl.hpp @@ -4,6 +4,8 @@ // #pragma once +#include "TypeUtils.hpp" + #include <armnn/INetwork.hpp> #include <backendsCommon/test/CommonTestUtils.hpp> @@ -47,17 +49,18 @@ INetworkPtr CreateMergerNetwork(const std::vector<TensorShape>& inputShapes, return net; } -template<typename T> +template<armnn::DataType ArmnnType> void MergerDim0EndToEnd(const std::vector<BackendId>& backends) { using namespace armnn; + using T = ResolveType<ArmnnType>; unsigned int concatAxis = 0; const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }}; const TensorShape& outputShape = { 4, 3, 2, 2 }; // Builds up the structure of the network - INetworkPtr net = CreateMergerNetwork<GetDataType<T>()>(inputShapes, outputShape, concatAxis); + INetworkPtr net = CreateMergerNetwork<ArmnnType>(inputShapes, outputShape, concatAxis); BOOST_TEST_CHECKPOINT("create a network"); @@ -110,17 +113,18 @@ void MergerDim0EndToEnd(const std::vector<BackendId>& backends) EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends); } -template<typename T> +template<armnn::DataType ArmnnType> void MergerDim1EndToEnd(const std::vector<BackendId>& backends) { using namespace armnn; + using T = ResolveType<ArmnnType>; unsigned int concatAxis = 1; const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }}; const TensorShape& outputShape = { 2, 6, 2, 2 }; // Builds up the structure of the network - INetworkPtr net = CreateMergerNetwork<GetDataType<T>()>(inputShapes, outputShape, concatAxis); + INetworkPtr net = CreateMergerNetwork<ArmnnType>(inputShapes, outputShape, concatAxis); BOOST_TEST_CHECKPOINT("create a network"); @@ -173,17 +177,18 @@ void MergerDim1EndToEnd(const std::vector<BackendId>& backends) EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends); } -template<typename T> +template<armnn::DataType ArmnnType> void MergerDim2EndToEnd(const std::vector<BackendId>& backends) { using namespace armnn; + using T = ResolveType<ArmnnType>; unsigned int concatAxis = 2; const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }}; const TensorShape& outputShape = { 2, 3, 4, 2 }; // Builds up the structure of the network - INetworkPtr net = CreateMergerNetwork<GetDataType<T>()>(inputShapes, outputShape, concatAxis); + INetworkPtr net = CreateMergerNetwork<ArmnnType>(inputShapes, outputShape, concatAxis); BOOST_TEST_CHECKPOINT("create a network"); @@ -236,7 +241,7 @@ void MergerDim2EndToEnd(const std::vector<BackendId>& backends) EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends); } -template<typename T> +template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> void MergerDim3EndToEnd(const std::vector<BackendId>& backends) { using namespace armnn; @@ -246,7 +251,7 @@ void MergerDim3EndToEnd(const std::vector<BackendId>& backends) const TensorShape& outputShape = { 2, 3, 2, 4 }; // Builds up the structure of the network - INetworkPtr net = CreateMergerNetwork<GetDataType<T>()>(inputShapes, outputShape, concatAxis); + INetworkPtr net = CreateMergerNetwork<ArmnnType>(inputShapes, outputShape, concatAxis); BOOST_TEST_CHECKPOINT("create a network"); |