From 0421e7f22d9ccd5d810b345731b766a96c841492 Mon Sep 17 00:00:00 2001 From: nikraj01 Date: Fri, 14 Jun 2019 09:40:34 +0100 Subject: IVGCVSW-3224 Add Uint8 support for Rsqrt Change-Id: I45598fc9b6d408b19d8d050e64c12b1d48535fa3 Signed-off-by: nikraj01 --- src/backends/backendsCommon/WorkloadData.cpp | 15 +++ src/backends/backendsCommon/test/LayerTests.hpp | 133 ++++++++++++--------- src/backends/reference/RefLayerSupport.cpp | 25 +++- src/backends/reference/RefWorkloadFactory.cpp | 4 - .../reference/test/RefCreateWorkloadTests.cpp | 5 + src/backends/reference/test/RefLayerTests.cpp | 2 + 6 files changed, 119 insertions(+), 65 deletions(-) diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 1e14b65c6c..20e125293a 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1468,6 +1468,21 @@ void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const "RsqrtQueueDescriptor", "input", "output"); + + std::vector supportedTypes = + { + DataType::Float16, + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "RsqrtQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + {workloadInfo.m_InputTensorInfos[0].GetDataType()}, + "RsqrtQueueDescriptor"); } void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const diff --git a/src/backends/backendsCommon/test/LayerTests.hpp b/src/backends/backendsCommon/test/LayerTests.hpp index 8bbd0d47c8..8a5a61145c 100644 --- a/src/backends/backendsCommon/test/LayerTests.hpp +++ b/src/backends/backendsCommon/test/LayerTests.hpp @@ -887,8 +887,8 @@ LayerTestResult Rsqrt2dTestCommon( const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, const armnn::TensorInfo inputTensorInfo, const armnn::TensorInfo outputTensorInfo, - std::vector inputValues, - std::vector expectedOutputValues); + const std::vector& inputValues, + const std::vector& expectedOutputValues); template> LayerTestResult Rsqrt2dTest( @@ -1941,19 +1941,21 @@ std::vector ConvertToDataType(const std::vector& input, return output; } -template +template LayerTestResult Rsqrt2dTestCommon( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, const armnn::TensorInfo inputTensorInfo, const armnn::TensorInfo outputTensorInfo, - std::vector inputValues, - std::vector expectedOutputValues) + const std::vector& inputValues, + const std::vector& expectedOutputValues) { - auto inputTensor = MakeTensor(inputTensorInfo, std::vector(inputValues)); + auto inputTensor = MakeTensor(inputTensorInfo, ConvertToDataType(inputValues,inputTensorInfo)); LayerTestResult result(outputTensorInfo); - result.outputExpected = MakeTensor(outputTensorInfo, std::vector(expectedOutputValues)); + + result.outputExpected = MakeTensor(outputTensorInfo, + ConvertToDataType(expectedOutputValues,outputTensorInfo)); std::unique_ptr inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo); std::unique_ptr outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo); @@ -1988,22 +1990,27 @@ LayerTestResult Rsqrt2dTest( const armnn::TensorShape inputShape{ 2, 2 }; const armnn::TensorShape outputShape{ 2, 2 }; - const armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType); - const armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType); + armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType); + inputTensorInfo.SetQuantizationScale(0.1f); + inputTensorInfo.SetQuantizationOffset(0); + + armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType); + outputTensorInfo.SetQuantizationScale(0.1f); + outputTensorInfo.SetQuantizationOffset(0); - std::vector inputValues - { - 1.f, 4.f, - 16.f, 25.f - }; + std::vector inputValues + { + 1.f, 4.f, + 16.f, 25.f + }; - std::vector expectedOutputValues - { - 1.f, 0.5f, - 0.25f, 0.2f - }; + std::vector expectedOutputValues + { + 1.f, 0.5f, + 0.25f, 0.2f + }; - return Rsqrt2dTestCommon(workloadFactory, memoryManager, + return Rsqrt2dTestCommon(workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, inputValues, expectedOutputValues); } @@ -2016,25 +2023,31 @@ LayerTestResult Rsqrt3dTest( const armnn::TensorShape inputShape{ 3, 1, 2 }; const armnn::TensorShape outputShape{ 3, 1, 2 }; - const armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType); - const armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType); + armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType); + inputTensorInfo.SetQuantizationScale(0.1f); + inputTensorInfo.SetQuantizationOffset(0); - std::vector inputValues - { - 1.f, 4.f, 16.f, - 25.f, 64.f, 100.f - }; + armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType); + outputTensorInfo.SetQuantizationScale(0.1f); + outputTensorInfo.SetQuantizationOffset(0); - std::vector expectedOutputValues - { - 1.f, 0.5f, 0.25f, - 0.2f, 0.125f, 0.1f - }; + std::vector inputValues + { + 1.f, 4.f, 16.f, + 25.f, 64.f, 100.f + }; - auto inputTensor = MakeTensor(inputTensorInfo, std::vector(inputValues)); + std::vector expectedOutputValues + { + 1.f, 0.5f, 0.25f, + 0.2f, 0.125f, 0.1f + }; + + auto inputTensor = MakeTensor(inputTensorInfo, ConvertToDataType(inputValues,inputTensorInfo)); LayerTestResult result(outputTensorInfo); - result.outputExpected = MakeTensor(outputTensorInfo, std::vector(expectedOutputValues)); + result.outputExpected = MakeTensor(outputTensorInfo, + ConvertToDataType(expectedOutputValues,outputTensorInfo)); std::unique_ptr inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo); std::unique_ptr outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo); @@ -2069,20 +2082,23 @@ LayerTestResult RsqrtZeroTest( const armnn::TensorShape inputShape{ 1, 2 }; const armnn::TensorShape outputShape{ 1, 2 }; - const armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType); - const armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType); + armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType); + inputTensorInfo.SetQuantizationScale(0.1f); + + armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType); + outputTensorInfo.SetQuantizationScale(0.1f); - std::vector inputValues - { - 0.f, -0.f - }; + std::vector inputValues + { + 0.f, -0.f + }; - std::vector expectedOutputValues - { - INFINITY, -INFINITY - }; + std::vector expectedOutputValues + { + INFINITY, -INFINITY + }; - return Rsqrt2dTestCommon(workloadFactory, memoryManager, + return Rsqrt2dTestCommon(workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, inputValues, expectedOutputValues); } @@ -2095,20 +2111,25 @@ LayerTestResult RsqrtNegativeTest( const armnn::TensorShape inputShape{ 1, 2 }; const armnn::TensorShape outputShape{ 1, 2 }; - const armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType); - const armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType); + armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType); + inputTensorInfo.SetQuantizationScale(0.1f); + inputTensorInfo.SetQuantizationOffset(0); - std::vector inputValues - { - -25.f, -16.f - }; + armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType); + outputTensorInfo.SetQuantizationScale(0.1f); + outputTensorInfo.SetQuantizationOffset(0); - std::vector expectedOutputValues - { - -NAN, -NAN - }; + std::vector inputValues + { + -25.f, -16.f + }; + + std::vector expectedOutputValues + { + -NAN, -NAN + }; - return Rsqrt2dTestCommon(workloadFactory, memoryManager, + return Rsqrt2dTestCommon(workloadFactory, memoryManager, inputTensorInfo, outputTensorInfo, inputValues, expectedOutputValues); } diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index f2ab9edca7..b508dfd29d 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -1116,11 +1116,26 @@ bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input, const TensorInfo& output, Optional reasonIfUnsupported) const { - ignore_unused(output); - return IsSupportedForDataTypeRef(reasonIfUnsupported, - input.GetDataType(), - &TrueFunc<>, - &FalseFuncU8<>); + bool supported = true; + std::array supportedTypes = + { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "Reference rsqrt: input type not supported"); + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference rsqrt: output type not supported"); + + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, + "Reference rsqrt: input and output types not matching"); + + supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported, + "Reference Rsqrt: input and output shapes have different number of total elements"); + + return supported; } bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 1ef88a090e..cb26f2642b 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -402,10 +402,6 @@ std::unique_ptr RefWorkloadFactory::CreateRsqrt(const RsqrtQueueDescr { return MakeWorkload(descriptor, info); } - else if(IsUint8(info)) - { - return MakeWorkload(descriptor, info); - } return std::make_unique(descriptor, info); } diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index 5139888e39..dbcf20169c 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -677,6 +677,11 @@ BOOST_AUTO_TEST_CASE(CreateRsqrtFloat32) RefCreateRsqrtTest(); } +BOOST_AUTO_TEST_CASE(CreateRsqrtUint8) +{ + RefCreateRsqrtTest(); +} + template static void RefCreateL2NormalizationTest(DataLayout dataLayout) { diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index fd01550186..8ebb725a6f 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -552,6 +552,8 @@ ARMNN_AUTO_TEST_CASE(Rsqrt2d, Rsqrt2dTest) ARMNN_AUTO_TEST_CASE(Rsqrt3d, Rsqrt3dTest) ARMNN_AUTO_TEST_CASE(RsqrtZero, RsqrtZeroTest) ARMNN_AUTO_TEST_CASE(RsqrtNegative, RsqrtNegativeTest) +ARMNN_AUTO_TEST_CASE(Rsqrt2dQuantisedAsymm8, Rsqrt2dTest) +ARMNN_AUTO_TEST_CASE(Rsqrt3dQuantisedAsymm8, Rsqrt3dTest) // Permute ARMNN_AUTO_TEST_CASE(SimplePermuteFloat32, SimplePermuteFloat32Test) -- cgit v1.2.1