diff options
Diffstat (limited to 'src/backends')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 24 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 10 | ||||
-rw-r--r-- | src/backends/reference/test/RefLayerTests.cpp | 6 | ||||
-rw-r--r-- | src/backends/reference/workloads/Debug.cpp | 5 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefDebugWorkload.cpp | 1 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefDebugWorkload.hpp | 3 |
6 files changed, 40 insertions, 9 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index ae668f31f3..f3758760fc 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -533,11 +533,25 @@ bool RefLayerSupport::IsDebugSupported(const TensorInfo& input, const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const { - ignore_unused(output); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); + bool supported = true; + + std::array<DataType,3> supportedTypes = + { + DataType::Float32, + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 + }; + + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "Reference debug: input type not supported"); + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference debug: output type not supported"); + + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, + "Reference debug: input and output types are mismatched"); + + return supported; } bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 7ae5b97dcf..d1189a6542 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -50,9 +50,9 @@ bool IsFloat16(const WorkloadInfo& info) return IsDataType<DataType::Float16>(info); } -bool IsUint8(const WorkloadInfo& info) +bool IsQSymm16(const WorkloadInfo& info) { - return IsDataType<DataType::QuantisedAsymm8>(info); + return IsDataType<DataType::QuantisedSymm16>(info); } RefWorkloadFactory::RefWorkloadFactory(const std::shared_ptr<RefMemoryManager>& memoryManager) @@ -432,7 +432,11 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGreater(const GreaterQueueD std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<RefDebugFloat32Workload, RefDebugUint8Workload>(descriptor, info); + if (IsQSymm16(info)) + { + return std::make_unique<RefDebugQSymm16Workload>(descriptor, info); + } + return MakeWorkload<RefDebugFloat32Workload, RefDebugQAsymm8Workload>(descriptor, info); } std::unique_ptr<IWorkload> RefWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor, diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index e978f4254d..5542c9ae4f 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -10,6 +10,7 @@ #include <reference/RefWorkloadFactory.hpp> +#include <backendsCommon/test/DebugTestImpl.hpp> #include <backendsCommon/test/DetectionPostProcessLayerTestImpl.hpp> #include <backendsCommon/test/LayerTests.hpp> @@ -1025,6 +1026,11 @@ ARMNN_AUTO_TEST_CASE(Debug3DUint8, Debug3DUint8Test) ARMNN_AUTO_TEST_CASE(Debug2DUint8, Debug2DUint8Test) ARMNN_AUTO_TEST_CASE(Debug1DUint8, Debug1DUint8Test) +ARMNN_AUTO_TEST_CASE(Debug4DQSymm16, Debug4DTest<armnn::DataType::QuantisedSymm16>) +ARMNN_AUTO_TEST_CASE(Debug3DQSymm16, Debug3DTest<armnn::DataType::QuantisedSymm16>) +ARMNN_AUTO_TEST_CASE(Debug2DQSymm16, Debug2DTest<armnn::DataType::QuantisedSymm16>) +ARMNN_AUTO_TEST_CASE(Debug1DQSymm16, Debug1DTest<armnn::DataType::QuantisedSymm16>) + // Gather ARMNN_AUTO_TEST_CASE(Gather1DParamsFloat, Gather1DParamsFloatTest) ARMNN_AUTO_TEST_CASE(Gather1DParamsUint8, Gather1DParamsUint8Test) diff --git a/src/backends/reference/workloads/Debug.cpp b/src/backends/reference/workloads/Debug.cpp index 594f428908..d1c9fdd8b8 100644 --- a/src/backends/reference/workloads/Debug.cpp +++ b/src/backends/reference/workloads/Debug.cpp @@ -97,4 +97,9 @@ template void Debug<uint8_t>(const TensorInfo& inputInfo, const std::string& layerName, unsigned int slotIndex); +template void Debug<int16_t>(const TensorInfo& inputInfo, + const int16_t* inputData, + LayerGuid guid, + const std::string& layerName, + unsigned int slotIndex); } // namespace armnn diff --git a/src/backends/reference/workloads/RefDebugWorkload.cpp b/src/backends/reference/workloads/RefDebugWorkload.cpp index be2d82f6dc..325817b19f 100644 --- a/src/backends/reference/workloads/RefDebugWorkload.cpp +++ b/src/backends/reference/workloads/RefDebugWorkload.cpp @@ -46,5 +46,6 @@ void RefDebugWorkload<DataType>::RegisterDebugCallback(const DebugCallbackFuncti template class RefDebugWorkload<DataType::Float32>; template class RefDebugWorkload<DataType::QuantisedAsymm8>; +template class RefDebugWorkload<DataType::QuantisedSymm16>; } // namespace armnn diff --git a/src/backends/reference/workloads/RefDebugWorkload.hpp b/src/backends/reference/workloads/RefDebugWorkload.hpp index 2985699f7b..6a1fceba0a 100644 --- a/src/backends/reference/workloads/RefDebugWorkload.hpp +++ b/src/backends/reference/workloads/RefDebugWorkload.hpp @@ -38,6 +38,7 @@ private: }; using RefDebugFloat32Workload = RefDebugWorkload<DataType::Float32>; -using RefDebugUint8Workload = RefDebugWorkload<DataType::QuantisedAsymm8>; +using RefDebugQAsymm8Workload = RefDebugWorkload<DataType::QuantisedAsymm8>; +using RefDebugQSymm16Workload = RefDebugWorkload<DataType::QuantisedSymm16>; } // namespace armnn |