aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/PermuteTestImpl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/test/PermuteTestImpl.hpp')
-rw-r--r--src/backends/backendsCommon/test/PermuteTestImpl.hpp38
1 files changed, 28 insertions, 10 deletions
diff --git a/src/backends/backendsCommon/test/PermuteTestImpl.hpp b/src/backends/backendsCommon/test/PermuteTestImpl.hpp
index 529f9d34e0..c8120a41d8 100644
--- a/src/backends/backendsCommon/test/PermuteTestImpl.hpp
+++ b/src/backends/backendsCommon/test/PermuteTestImpl.hpp
@@ -5,6 +5,7 @@
#pragma once
#include "QuantizeHelper.hpp"
+#include "WorkloadTestUtils.hpp"
#include <armnn/ArmNN.hpp>
#include <armnn/Tensor.hpp>
@@ -13,11 +14,13 @@
#include <test/TensorHelpers.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>
+#include <backendsCommon/IBackendInternal.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
template<typename T>
LayerTestResult<T, 4> SimplePermuteTestImpl(
armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
armnn::PermuteDescriptor descriptor,
armnn::TensorInfo inputTensorInfo,
armnn::TensorInfo outputTensorInfo,
@@ -52,7 +55,9 @@ LayerTestResult<T, 4> SimplePermuteTestImpl(
return ret;
}
-LayerTestResult<float, 4> SimplePermuteFloat32TestCommon(armnn::IWorkloadFactory& workloadFactory)
+LayerTestResult<float, 4> SimplePermuteFloat32TestCommon(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
{
armnn::TensorInfo inputTensorInfo;
armnn::TensorInfo outputTensorInfo;
@@ -81,11 +86,14 @@ LayerTestResult<float, 4> SimplePermuteFloat32TestCommon(armnn::IWorkloadFactory
3.0f, 7.0f, 4.0f, 8.0f
});
- return SimplePermuteTestImpl<float>(workloadFactory, descriptor, inputTensorInfo,
+ return SimplePermuteTestImpl<float>(workloadFactory, memoryManager,
+ descriptor, inputTensorInfo,
outputTensorInfo, input, outputExpected);
}
-LayerTestResult<uint8_t, 4> SimplePermuteUint8TestCommon(armnn::IWorkloadFactory& workloadFactory)
+LayerTestResult<uint8_t, 4> SimplePermuteUint8TestCommon(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
{
armnn::TensorInfo inputTensorInfo;
armnn::TensorInfo outputTensorInfo;
@@ -116,12 +124,15 @@ LayerTestResult<uint8_t, 4> SimplePermuteUint8TestCommon(armnn::IWorkloadFactory
3, 7, 4, 8
});
- return SimplePermuteTestImpl<uint8_t>(workloadFactory, descriptor, inputTensorInfo,
+ return SimplePermuteTestImpl<uint8_t>(workloadFactory, memoryManager,
+ descriptor, inputTensorInfo,
outputTensorInfo, input, outputExpected);
}
LayerTestResult<float, 4>
-PermuteFloat32ValueSet1TestCommon(armnn::IWorkloadFactory& workloadFactory)
+PermuteFloat32ValueSet1TestCommon(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
{
armnn::TensorInfo inputTensorInfo;
armnn::TensorInfo outputTensorInfo;
@@ -150,12 +161,15 @@ PermuteFloat32ValueSet1TestCommon(armnn::IWorkloadFactory& workloadFactory)
3.0f, 13.0f, 23.0f, 33.0f,
});
- return SimplePermuteTestImpl<float>(workloadFactory, descriptor, inputTensorInfo,
+ return SimplePermuteTestImpl<float>(workloadFactory, memoryManager,
+ descriptor, inputTensorInfo,
outputTensorInfo, input, outputExpected);
}
LayerTestResult<float, 4>
-PermuteFloat32ValueSet2TestCommon(armnn::IWorkloadFactory& workloadFactory)
+PermuteFloat32ValueSet2TestCommon(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
{
armnn::TensorInfo inputTensorInfo;
armnn::TensorInfo outputTensorInfo;
@@ -184,12 +198,15 @@ PermuteFloat32ValueSet2TestCommon(armnn::IWorkloadFactory& workloadFactory)
31.0f, 32.0f, 33.0f,
});
- return SimplePermuteTestImpl<float>(workloadFactory, descriptor, inputTensorInfo,
+ return SimplePermuteTestImpl<float>(workloadFactory, memoryManager,
+ descriptor, inputTensorInfo,
outputTensorInfo, input, outputExpected);
}
LayerTestResult<float, 4>
-PermuteFloat32ValueSet3TestCommon(armnn::IWorkloadFactory& workloadFactory)
+PermuteFloat32ValueSet3TestCommon(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
{
armnn::TensorInfo inputTensorInfo;
armnn::TensorInfo outputTensorInfo;
@@ -220,6 +237,7 @@ PermuteFloat32ValueSet3TestCommon(armnn::IWorkloadFactory& workloadFactory)
3.0f, 13.0f, 23.0f, 33.0f, 43.0f, 53.0f,
});
- return SimplePermuteTestImpl<float>(workloadFactory, descriptor, inputTensorInfo,
+ return SimplePermuteTestImpl<float>(workloadFactory, memoryManager,
+ descriptor, inputTensorInfo,
outputTensorInfo, input, outputExpected);
}