diff options
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 25 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 6 | ||||
-rw-r--r-- | src/backends/reference/test/RefLayerTests.cpp | 5 | ||||
-rw-r--r-- | src/backends/reference/workloads/Pad.cpp | 6 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefPadWorkload.cpp | 1 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefPadWorkload.hpp | 3 |
6 files changed, 39 insertions, 7 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 087acd2103..ac7f310c1d 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -1240,12 +1240,27 @@ bool RefLayerSupport::IsPadSupported(const TensorInfo& input, const PadDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported) const { - ignore_unused(output); ignore_unused(descriptor); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); + bool supported = true; + + // Define supported output and inputs types. + std::array<DataType,3> supportedTypes = + { + DataType::Float32, + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 + }; + + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "Reference pad: input is not a supported type."); + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference pad: output is not a supported type."); + + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, + "Reference pad: input and output types are mismatched."); + + return supported; } bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 9a31533cba..183103c40c 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -402,7 +402,11 @@ std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMinimum( std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<RefPadFloat32Workload, RefPadUint8Workload>(descriptor, info); + if (IsQSymm16(info)) + { + return std::make_unique<RefPadQSymm16Workload>(descriptor, info); + } + return MakeWorkload<RefPadFloat32Workload, RefPadQAsymm8Workload>(descriptor, info); } std::unique_ptr<IWorkload> RefWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor, diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 7009dec098..bc64725747 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -751,6 +751,11 @@ ARMNN_AUTO_TEST_CASE(PadUint82dCustomPadding, PadUint82dCustomPaddingTest) ARMNN_AUTO_TEST_CASE(PadUint83d, PadUint83dTest) ARMNN_AUTO_TEST_CASE(PadUint84d, PadUint84dTest) +ARMNN_AUTO_TEST_CASE(Pad2dQSymm16, Pad2dTestCommon<armnn::DataType::QuantisedSymm16>, 2.0f, 0, 0.0f) +ARMNN_AUTO_TEST_CASE(Pad2dQSymm16CustomPadding, Pad2dTestCommon<armnn::DataType::QuantisedSymm16>, 2.0f, 0, 1.0f) +ARMNN_AUTO_TEST_CASE(Pad3dQSymm16, Pad3dTestCommon<armnn::DataType::QuantisedSymm16>, 2.0f, 0) +ARMNN_AUTO_TEST_CASE(Pad4dQSymm16, Pad4dTestCommon<armnn::DataType::QuantisedSymm16>, 2.0f, 0) + // Constant ARMNN_AUTO_TEST_CASE(Constant, ConstantTest) ARMNN_AUTO_TEST_CASE(ConstantUint8, ConstantUint8CustomQuantizationScaleAndOffsetTest) diff --git a/src/backends/reference/workloads/Pad.cpp b/src/backends/reference/workloads/Pad.cpp index 1e58124627..41435f47d2 100644 --- a/src/backends/reference/workloads/Pad.cpp +++ b/src/backends/reference/workloads/Pad.cpp @@ -175,5 +175,11 @@ template void Pad<uint8_t>(const TensorInfo& inputInfo, const uint8_t* inputData, uint8_t* outData, const float padValue); +template void Pad<int16_t>(const TensorInfo& inputInfo, + const TensorInfo& outputInfo, + std::vector<std::pair<unsigned int, unsigned int>> m_PadList, + const int16_t* inputData, + int16_t* outData, + const float padValue); } //namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/RefPadWorkload.cpp b/src/backends/reference/workloads/RefPadWorkload.cpp index e9724c449f..5e59d83dc9 100644 --- a/src/backends/reference/workloads/RefPadWorkload.cpp +++ b/src/backends/reference/workloads/RefPadWorkload.cpp @@ -35,5 +35,6 @@ void RefPadWorkload<DataType>::Execute() const template class RefPadWorkload<DataType::Float32>; template class RefPadWorkload<DataType::QuantisedAsymm8>; +template class RefPadWorkload<DataType::QuantisedSymm16>; } //namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/RefPadWorkload.hpp b/src/backends/reference/workloads/RefPadWorkload.hpp index 8c6d01351b..b1de53e930 100644 --- a/src/backends/reference/workloads/RefPadWorkload.hpp +++ b/src/backends/reference/workloads/RefPadWorkload.hpp @@ -31,6 +31,7 @@ public: }; using RefPadFloat32Workload = RefPadWorkload<DataType::Float32>; -using RefPadUint8Workload = RefPadWorkload<DataType::QuantisedAsymm8>; +using RefPadQAsymm8Workload = RefPadWorkload<DataType::QuantisedAsymm8>; +using RefPadQSymm16Workload = RefPadWorkload<DataType::QuantisedSymm16>; } //namespace armnn |