diff options
Diffstat (limited to 'src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp')
-rw-r--r-- | src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp | 11 |
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp b/src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp
index e7c0f01cc9..cfdae63c26 100644
--- a/src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/FullyConnectedTestImpl.hpp
@@ -3,6 +3,7 @@
 // SPDX-License-Identifier: MIT
 //
 
+#include "TypeUtils.hpp"
 #include "WorkloadTestUtils.hpp"
 
 #include <backendsCommon/IBackendInternal.hpp>
@@ -220,7 +221,7 @@ LayerTestResult<uint8_t, 2> FullyConnectedUint8Test(
 // Tests the fully connected layer with large values, optionally transposing weights.
 // Note this is templated for consistency, but the nature of this tests makes it unlikely to be useful in Uint8 mode.
 //
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 LayerTestResult<T, 2> FullyConnectedLargeTestCommon(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -252,10 +253,10 @@ LayerTestResult<T, 2> FullyConnectedLargeTestCommon(
 
     unsigned int biasShape[] = { outputChannels };
 
-    inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::GetDataType<T>());
-    outputTensorInfo = armnn::TensorInfo(2, outputShape, armnn::GetDataType<T>());
-    weightsDesc = armnn::TensorInfo(2, weightsShape, armnn::GetDataType<T>());
-    biasesDesc = armnn::TensorInfo(1, biasShape, armnn::GetDataType<T>());
+    inputTensorInfo = armnn::TensorInfo(4, inputShape, ArmnnType);
+    outputTensorInfo = armnn::TensorInfo(2, outputShape, ArmnnType);
+    weightsDesc = armnn::TensorInfo(2, weightsShape, ArmnnType);
+    biasesDesc = armnn::TensorInfo(1, biasShape, ArmnnType);
 
     // Set quantization parameters if the requested type is a quantized type.
     if(armnn::IsQuantizedType<T>())