aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp')
-rw-r--r--src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp11
1 files changed, 6 insertions, 5 deletions
diff --git a/src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp b/src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp
index e7c0f01cc9..cfdae63c26 100644
--- a/src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp
@@ -3,6 +3,7 @@
// SPDX-License-Identifier: MIT
//
+#include "TypeUtils.hpp"
#include "WorkloadTestUtils.hpp"
#include <backendsCommon/IBackendInternal.hpp>
@@ -220,7 +221,7 @@ LayerTestResult<uint8_t, 2> FullyConnectedUint8Test(
// Tests the fully connected layer with large values, optionally transposing weights.
// Note this is templated for consistency, but the nature of this tests makes it unlikely to be useful in Uint8 mode.
//
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
LayerTestResult<T, 2> FullyConnectedLargeTestCommon(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -252,10 +253,10 @@ LayerTestResult<T, 2> FullyConnectedLargeTestCommon(
unsigned int biasShape[] = { outputChannels };
- inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::GetDataType<T>());
- outputTensorInfo = armnn::TensorInfo(2, outputShape, armnn::GetDataType<T>());
- weightsDesc = armnn::TensorInfo(2, weightsShape, armnn::GetDataType<T>());
- biasesDesc = armnn::TensorInfo(1, biasShape, armnn::GetDataType<T>());
+ inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType);
+ outputTensorInfo = armnn::TensorInfo(2, outputShape, ArmnnType);
+ weightsDesc = armnn::TensorInfo(2, weightsShape, ArmnnType);
+ biasesDesc = armnn::TensorInfo(1, biasShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
if(armnn::IsQuantizedType<T>())