aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp25
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp4
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp5
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp2
4 files changed, 27 insertions, 9 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index f2ab9edca7..b508dfd29d 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -1116,11 +1116,26 @@ bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &FalseFuncU8<>);
+ bool supported = true;
+ std::array<DataType,2> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference rsqrt: input type not supported");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference rsqrt: output type not supported");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ "Reference rsqrt: input and output types not matching");
+
+ supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
+ "Reference Rsqrt: input and output shapes have different number of total elements");
+
+ return supported;
}
bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 1ef88a090e..cb26f2642b 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -402,10 +402,6 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateRsqrt(const RsqrtQueueDescr
{
return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
}
- else if(IsUint8(info))
- {
- return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
- }
return std::make_unique<RefRsqrtWorkload>(descriptor, info);
}
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 5139888e39..dbcf20169c 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -677,6 +677,11 @@ BOOST_AUTO_TEST_CASE(CreateRsqrtFloat32)
RefCreateRsqrtTest<RefRsqrtWorkload, armnn::DataType::Float32>();
}
+BOOST_AUTO_TEST_CASE(CreateRsqrtUint8)
+{
+ RefCreateRsqrtTest<RefRsqrtWorkload, armnn::DataType::QuantisedAsymm8>();
+}
+
template <typename L2NormalizationWorkloadType, armnn::DataType DataType>
static void RefCreateL2NormalizationTest(DataLayout dataLayout)
{
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index fd01550186..8ebb725a6f 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -552,6 +552,8 @@ ARMNN_AUTO_TEST_CASE(Rsqrt2d, Rsqrt2dTest<armnn::DataType::Float32>)
ARMNN_AUTO_TEST_CASE(Rsqrt3d, Rsqrt3dTest<armnn::DataType::Float32>)
ARMNN_AUTO_TEST_CASE(RsqrtZero, RsqrtZeroTest<armnn::DataType::Float32>)
ARMNN_AUTO_TEST_CASE(RsqrtNegative, RsqrtNegativeTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(Rsqrt2dQuantisedAsymm8, Rsqrt2dTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(Rsqrt3dQuantisedAsymm8, Rsqrt3dTest<armnn::DataType::QuantisedAsymm8>)
// Permute
ARMNN_AUTO_TEST_CASE(SimplePermuteFloat32, SimplePermuteFloat32Test)