diff options
Diffstat (limited to 'src/backends/backendsCommon/test/StridedSliceTestImpl.hpp')
-rw-r--r-- | src/backends/backendsCommon/test/StridedSliceTestImpl.hpp | 55 |
1 files changed, 28 insertions, 27 deletions
diff --git a/src/backends/backendsCommon/test/StridedSliceTestImpl.hpp b/src/backends/backendsCommon/test/StridedSliceTestImpl.hpp index 1633151108..1bf5c642ad 100644 --- a/src/backends/backendsCommon/test/StridedSliceTestImpl.hpp +++ b/src/backends/backendsCommon/test/StridedSliceTestImpl.hpp @@ -4,6 +4,7 @@ // #pragma once +#include "TypeUtils.hpp" #include "WorkloadTestUtils.hpp" #include <armnn/ArmNN.hpp> @@ -71,7 +72,7 @@ LayerTestResult<T, OutDim> StridedSliceTestImpl( return ret; } -template <typename T> +template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> LayerTestResult<T, 4> StridedSlice4DTest( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) @@ -87,8 +88,8 @@ LayerTestResult<T, 4> StridedSlice4DTest( desc.m_Parameters.m_End = {2, 2, 3, 1}; desc.m_Parameters.m_Stride = {1, 1, 1, 1}; - inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::GetDataType<T>()); - outputTensorInfo = armnn::TensorInfo(4, outputShape, armnn::GetDataType<T>()); + inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType); + outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType); std::vector<float> input = std::vector<float>( { @@ -108,7 +109,7 @@ LayerTestResult<T, 4> StridedSlice4DTest( workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, input, outputExpected, desc); } -template <typename T> +template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> LayerTestResult<T, 4> StridedSlice4DReverseTest( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) @@ -124,8 +125,8 @@ LayerTestResult<T, 4> StridedSlice4DReverseTest( desc.m_Parameters.m_End = {2, -3, 3, 1}; desc.m_Parameters.m_Stride = {1, -1, 1, 1}; - inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::GetDataType<T>()); - outputTensorInfo = armnn::TensorInfo(4, outputShape, armnn::GetDataType<T>()); + inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType); + outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType); std::vector<float> input = std::vector<float>( { @@ -145,7 +146,7 @@ LayerTestResult<T, 4> StridedSlice4DReverseTest( workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, input, outputExpected, desc); } -template <typename T> +template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> LayerTestResult<T, 4> StridedSliceSimpleStrideTest( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) @@ -161,8 +162,8 @@ LayerTestResult<T, 4> StridedSliceSimpleStrideTest( desc.m_Parameters.m_End = {3, 2, 3, 1}; desc.m_Parameters.m_Stride = {2, 2, 2, 1}; - inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::GetDataType<T>()); - outputTensorInfo = armnn::TensorInfo(4, outputShape, armnn::GetDataType<T>()); + inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType); + outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType); std::vector<float> input = std::vector<float>( { @@ -184,7 +185,7 @@ LayerTestResult<T, 4> StridedSliceSimpleStrideTest( workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, input, outputExpected, desc); } -template <typename T> +template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> LayerTestResult<T, 4> StridedSliceSimpleRangeMaskTest( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) @@ -202,8 +203,8 @@ LayerTestResult<T, 4> StridedSliceSimpleRangeMaskTest( desc.m_Parameters.m_BeginMask = (1 << 4) - 1; desc.m_Parameters.m_EndMask = (1 << 4) - 1; - inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::GetDataType<T>()); - outputTensorInfo = armnn::TensorInfo(4, outputShape, armnn::GetDataType<T>()); + inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType); + outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType); std::vector<float> input = std::vector<float>( { @@ -227,7 +228,7 @@ LayerTestResult<T, 4> StridedSliceSimpleRangeMaskTest( workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, input, outputExpected, desc); } -template <typename T> +template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> LayerTestResult<T, 2> StridedSliceShrinkAxisMaskTest( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) @@ -245,8 +246,8 @@ LayerTestResult<T, 2> StridedSliceShrinkAxisMaskTest( desc.m_Parameters.m_EndMask = (1 << 4) - 1; desc.m_Parameters.m_ShrinkAxisMask = (1 << 1) | (1 << 2); - inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::GetDataType<T>()); - outputTensorInfo = armnn::TensorInfo(2, outputShape, armnn::GetDataType<T>()); + inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType); + outputTensorInfo = armnn::TensorInfo(2, outputShape, ArmnnType); std::vector<float> input = std::vector<float>( { @@ -266,7 +267,7 @@ LayerTestResult<T, 2> StridedSliceShrinkAxisMaskTest( workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, input, outputExpected, desc); } -template <typename T> +template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> LayerTestResult<T, 3> StridedSlice3DTest( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) @@ -283,8 +284,8 @@ LayerTestResult<T, 3> StridedSlice3DTest( desc.m_Parameters.m_Stride = {2, 2, 2}; desc.m_Parameters.m_EndMask = (1 << 3) - 1; - inputTensorInfo = armnn::TensorInfo(3, inputShape, armnn::GetDataType<T>()); - outputTensorInfo = armnn::TensorInfo(3, outputShape, armnn::GetDataType<T>()); + inputTensorInfo = armnn::TensorInfo(3, inputShape, ArmnnType); + outputTensorInfo = armnn::TensorInfo(3, outputShape, ArmnnType); std::vector<float> input = std::vector<float>( { @@ -306,7 +307,7 @@ LayerTestResult<T, 3> StridedSlice3DTest( workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, input, outputExpected, desc); } -template <typename T> +template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> LayerTestResult<T, 3> StridedSlice3DReverseTest( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) @@ -322,8 +323,8 @@ LayerTestResult<T, 3> StridedSlice3DReverseTest( desc.m_Parameters.m_End = {-4, -4, -4}; desc.m_Parameters.m_Stride = {-2, -2, -2}; - inputTensorInfo = armnn::TensorInfo(3, inputShape, armnn::GetDataType<T>()); - outputTensorInfo = armnn::TensorInfo(3, outputShape, armnn::GetDataType<T>()); + inputTensorInfo = armnn::TensorInfo(3, inputShape, ArmnnType); + outputTensorInfo = armnn::TensorInfo(3, outputShape, ArmnnType); std::vector<float> input = std::vector<float>( { @@ -345,7 +346,7 @@ LayerTestResult<T, 3> StridedSlice3DReverseTest( workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, input, outputExpected, desc); } -template <typename T> +template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> LayerTestResult<T, 2> StridedSlice2DTest( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) @@ -362,8 +363,8 @@ LayerTestResult<T, 2> StridedSlice2DTest( desc.m_Parameters.m_Stride = {2, 2}; desc.m_Parameters.m_EndMask = (1 << 2) - 1; - inputTensorInfo = armnn::TensorInfo(2, inputShape, armnn::GetDataType<T>()); - outputTensorInfo = armnn::TensorInfo(2, outputShape, armnn::GetDataType<T>()); + inputTensorInfo = armnn::TensorInfo(2, inputShape, ArmnnType); + outputTensorInfo = armnn::TensorInfo(2, outputShape, ArmnnType); std::vector<float> input = std::vector<float>( { @@ -385,7 +386,7 @@ LayerTestResult<T, 2> StridedSlice2DTest( workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, input, outputExpected, desc); } -template <typename T> +template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> LayerTestResult<T, 2> StridedSlice2DReverseTest( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) @@ -403,8 +404,8 @@ LayerTestResult<T, 2> StridedSlice2DReverseTest( desc.m_Parameters.m_BeginMask = (1 << 2) - 1; desc.m_Parameters.m_EndMask = (1 << 2) - 1; - inputTensorInfo = armnn::TensorInfo(2, inputShape, armnn::GetDataType<T>()); - outputTensorInfo = armnn::TensorInfo(2, outputShape, armnn::GetDataType<T>()); + inputTensorInfo = armnn::TensorInfo(2, inputShape, ArmnnType); + outputTensorInfo = armnn::TensorInfo(2, outputShape, ArmnnType); std::vector<float> input = std::vector<float>( { |