author    Sadik Armagan <sadik.armagan@arm.com>  2019-04-09 14:20:12 +0100
committer Sadik Armagan <sadik.armagan@arm.com>  2019-04-09 14:24:05 +0100
commit    2999a02f0c6a6f290ce45f28c998a1c000d48f67 (patch)
tree      f9d13cec08ab8c6c47e68df512cddc613552a7d2 /src/backends/reference
parent    998517647d699d602e36f06b40d3f1d1ddaae7be (diff)
download  armnn-2999a02f0c6a6f290ce45f28c998a1c000d48f67.tar.gz
IVGCVSW-2862 Extend the Elementwise Workload to support QSymm16 Data Type
IVGCVSW-2863 Unit test per Elementwise operator with QSymm16 Data Type

* Added QSymm16 support for Elementwise Operators
* Added QSymm16 unit tests for Elementwise Operators

Change-Id: I4e4e2938f9ed2cbbb1f05fb0f7dc476768550277
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Diffstat (limited to 'src/backends/reference')
-rw-r--r--  src/backends/reference/RefLayerSupport.cpp                   | 170
-rw-r--r--  src/backends/reference/test/RefCreateWorkloadTests.cpp       |  32
-rw-r--r--  src/backends/reference/test/RefLayerTests.cpp                |  21
-rw-r--r--  src/backends/reference/workloads/BaseIterator.hpp            |  32
-rw-r--r--  src/backends/reference/workloads/RefElementwiseWorkload.cpp  |  22
5 files changed, 245 insertions, 32 deletions
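
For context before the diffs: QSymm16 is symmetric 16-bit quantization, so the offset is always zero and values round-trip through the scale alone. Below is a minimal standalone sketch of the arithmetic the new decoder/encoder rely on (hypothetical helper names; armnn's own armnn::Quantize<int16_t> and armnn::Dequantize templates do the equivalent with range checking):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Symmetric 16-bit quantization: the offset is 0, so 0.0f maps exactly to 0.
    int16_t QuantizeSymm16(float value, float scale)
    {
        int32_t q = static_cast<int32_t>(std::round(value / scale));
        q = std::min<int32_t>(std::max<int32_t>(q, INT16_MIN), INT16_MAX); // clamp to the int16 range
        return static_cast<int16_t>(q);
    }

    float DequantizeSymm16(int16_t value, float scale)
    {
        return static_cast<float>(value) * scale;
    }
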
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index d2cf6f904a..3512d52acf 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -228,9 +228,10 @@ bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
{
bool supported = true;
- std::array<DataType,2> supportedTypes = {
+ std::array<DataType,3> supportedTypes = {
DataType::Float32,
- DataType::QuantisedAsymm8
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
};
supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
@@ -432,12 +433,33 @@ bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference division: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference division: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference division: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference division: input 0 and input 1 types are mismatched.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference division: input and output types are mismatched.");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference division: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
@@ -606,12 +628,33 @@ bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference maximum: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference maximum: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference maximum: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference maximum: input 0 and input 1 types are mismatched.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference maximum: input and output types are mismatched.");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference maximum: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
@@ -659,12 +702,33 @@ bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference minimum: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference minimum: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference minimum: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference minimum: input 0 and input 1 types are mismatched.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference minimum: input and output types are mismatched.");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference minimum: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
@@ -672,12 +736,33 @@ bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference multiplication: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference multiplication: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference multiplication: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference multiplication: input 0 and input 1 types are mismatched.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference multiplication: input and output types are mismatched.");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference multiplication: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
@@ -860,12 +945,33 @@ bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,3> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference subtraction: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference subtraction: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference subtraction: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference subtraction: input 0 and input 1 types are mismatched.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference subtraction: input and output types are mismatched.");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference subtraction: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
} // namespace armnn
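
The rewritten checks above replace the old TrueFunc-based answers with explicit per-rule validation of types and broadcast shapes. A hedged sketch of how a caller might query the new QSymm16 support (illustrative only; the shape and scale are made up):

    #include <string>
    #include <armnn/Optional.hpp>
    #include <armnn/Tensor.hpp>
    #include "RefLayerSupport.hpp"

    // Ask the reference backend whether a QSymm16 division is supported.
    bool IsQSymm16DivisionSupported()
    {
        using namespace armnn;
        unsigned int dims[] = {2, 2};
        TensorInfo info(TensorShape(2, dims), DataType::QuantisedSymm16, 0.1f, 0);

        RefLayerSupport layerSupport;
        std::string reason;
        return layerSupport.IsDivisionSupported(info, info, info, Optional<std::string&>(reason));
    }
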
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 09b0246895..4b4e5449b4 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -96,6 +96,14 @@ BOOST_AUTO_TEST_CASE(CreateAdditionUint8Workload)
armnn::DataType::QuantisedAsymm8>();
}
+BOOST_AUTO_TEST_CASE(CreateAdditionInt16Workload)
+{
+ RefCreateElementwiseWorkloadTest<RefAdditionWorkload,
+ AdditionQueueDescriptor,
+ AdditionLayer,
+ armnn::DataType::QuantisedSymm16>();
+}
+
BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload)
{
RefCreateElementwiseWorkloadTest<RefSubtractionWorkload,
@@ -112,6 +120,14 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionUint8Workload)
armnn::DataType::QuantisedAsymm8>();
}
+BOOST_AUTO_TEST_CASE(CreateSubtractionInt16Workload)
+{
+ RefCreateElementwiseWorkloadTest<RefSubtractionWorkload,
+ SubtractionQueueDescriptor,
+ SubtractionLayer,
+ armnn::DataType::QuantisedSymm16>();
+}
+
BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload)
{
RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload,
@@ -128,6 +144,14 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8Workload)
armnn::DataType::QuantisedAsymm8>();
}
+BOOST_AUTO_TEST_CASE(CreateMultiplicationInt16Workload)
+{
+ RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload,
+ MultiplicationQueueDescriptor,
+ MultiplicationLayer,
+ armnn::DataType::QuantisedSymm16>();
+}
+
BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkload)
{
RefCreateElementwiseWorkloadTest<RefDivisionWorkload,
@@ -144,6 +168,14 @@ BOOST_AUTO_TEST_CASE(CreateDivisionUint8Workload)
armnn::DataType::QuantisedAsymm8>();
}
+BOOST_AUTO_TEST_CASE(CreateDivisionInt16Workload)
+{
+ RefCreateElementwiseWorkloadTest<RefDivisionWorkload,
+ DivisionQueueDescriptor,
+ DivisionLayer,
+ armnn::DataType::QuantisedSymm16>();
+}
+
template <typename BatchNormalizationWorkloadType, armnn::DataType DataType>
static void RefCreateBatchNormalizationWorkloadTest(DataLayout dataLayout)
{
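
Each new Int16 case above instantiates the same helper with a different workload/descriptor/layer triple. A hedged outline of what that helper verifies (an illustrative reconstruction that leans on this file's existing CreateElementwiseWorkloadTest and CheckInputsOutput helpers, not the verbatim definition):

    // Outline: build a single-layer graph, create the workload through the
    // reference factory, and confirm the inputs and output carry DataType.
    template <typename WorkloadType, typename DescriptorType, typename LayerType, armnn::DataType DataType>
    static void ElementwiseWorkloadTestOutline()
    {
        Graph graph;
        RefWorkloadFactory factory;
        auto workload = CreateElementwiseWorkloadTest<WorkloadType, DescriptorType,
                                                      LayerType, DataType>(factory, graph);
        CheckInputsOutput(std::move(workload),
                          TensorInfo({2, 3}, DataType),
                          TensorInfo({2, 3}, DataType),
                          TensorInfo({2, 3}, DataType));
    }
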
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 3206b762ff..cbc56d14b7 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -228,6 +228,10 @@ ARMNN_AUTO_TEST_CASE(AdditionUint8, AdditionUint8Test)
ARMNN_AUTO_TEST_CASE(AddBroadcastUint8, AdditionBroadcastUint8Test)
ARMNN_AUTO_TEST_CASE(AddBroadcast1ElementUint8, AdditionBroadcast1ElementUint8Test)
+ARMNN_AUTO_TEST_CASE(AdditionInt16, AdditionInt16Test)
+ARMNN_AUTO_TEST_CASE(AddBroadcastInt16, AdditionBroadcastInt16Test)
+ARMNN_AUTO_TEST_CASE(AddBroadcast1ElementInt16, AdditionBroadcast1ElementInt16Test)
+
// Sub
ARMNN_AUTO_TEST_CASE(SimpleSub, SubtractionTest)
ARMNN_AUTO_TEST_CASE(SubBroadcast1Element, SubtractionBroadcast1ElementTest)
@@ -237,6 +241,10 @@ ARMNN_AUTO_TEST_CASE(SubtractionUint8, SubtractionUint8Test)
ARMNN_AUTO_TEST_CASE(SubBroadcastUint8, SubtractionBroadcastUint8Test)
ARMNN_AUTO_TEST_CASE(SubBroadcast1ElementUint8, SubtractionBroadcast1ElementUint8Test)
+ARMNN_AUTO_TEST_CASE(SubtractionInt16, SubtractionInt16Test)
+ARMNN_AUTO_TEST_CASE(SubBroadcastInt16, SubtractionBroadcastInt16Test)
+ARMNN_AUTO_TEST_CASE(SubBroadcast1ElementInt16, SubtractionBroadcast1ElementInt16Test)
+
// Div
ARMNN_AUTO_TEST_CASE(SimpleDivision, DivisionTest)
ARMNN_AUTO_TEST_CASE(DivisionByZero, DivisionByZeroTest)
@@ -248,6 +256,10 @@ ARMNN_AUTO_TEST_CASE(DivisionUint8, DivisionUint8Test)
ARMNN_AUTO_TEST_CASE(DivisionUint8Broadcast1Element, DivisionBroadcast1ElementUint8Test)
ARMNN_AUTO_TEST_CASE(DivisionUint8Broadcast1DVector, DivisionBroadcast1DVectorUint8Test)
+ARMNN_AUTO_TEST_CASE(DivisionInt16, DivisionInt16Test)
+ARMNN_AUTO_TEST_CASE(DivisionInt16Broadcast1Element, DivisionBroadcast1ElementInt16Test)
+ARMNN_AUTO_TEST_CASE(DivisionInt16Broadcast1DVector, DivisionBroadcast1DVectorInt16Test)
+
// Equal
ARMNN_AUTO_TEST_CASE(SimpleEqual, EqualSimpleTest)
ARMNN_AUTO_TEST_CASE(EqualBroadcast1Element, EqualBroadcast1ElementTest)
@@ -271,11 +283,17 @@ ARMNN_AUTO_TEST_CASE(MaximumBroadcast1DVector, MaximumBroadcast1DVectorTest)
ARMNN_AUTO_TEST_CASE(MaximumUint8, MaximumUint8Test)
ARMNN_AUTO_TEST_CASE(MaximumBroadcast1ElementUint8, MaximumBroadcast1ElementUint8Test)
ARMNN_AUTO_TEST_CASE(MaximumBroadcast1DVectorUint8, MaximumBroadcast1DVectorUint8Test)
+ARMNN_AUTO_TEST_CASE(MaximumInt16, MaximumInt16Test)
+ARMNN_AUTO_TEST_CASE(MaximumBroadcast1ElementInt16, MaximumBroadcast1ElementInt16Test)
+ARMNN_AUTO_TEST_CASE(MaximumBroadcast1DVectorInt16, MaximumBroadcast1DVectorInt16Test)
// Min
ARMNN_AUTO_TEST_CASE(SimpleMinimum1, MinimumBroadcast1ElementTest1)
ARMNN_AUTO_TEST_CASE(SimpleMinimum2, MinimumBroadcast1ElementTest2)
ARMNN_AUTO_TEST_CASE(Minimum1DVectorUint8, MinimumBroadcast1DVectorUint8Test)
+ARMNN_AUTO_TEST_CASE(MinimumInt16, MinimumInt16Test)
+ARMNN_AUTO_TEST_CASE(MinimumBroadcast1ElementInt16, MinimumBroadcast1ElementInt16Test)
+ARMNN_AUTO_TEST_CASE(MinimumBroadcast1DVectorInt16, MinimumBroadcast1DVectorInt16Test)
// Mul
ARMNN_AUTO_TEST_CASE(SimpleMultiplication, MultiplicationTest)
@@ -284,6 +302,9 @@ ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1DVector, MultiplicationBroadcast1DV
ARMNN_AUTO_TEST_CASE(MultiplicationUint8, MultiplicationUint8Test)
ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1ElementUint8, MultiplicationBroadcast1ElementUint8Test)
ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1DVectorUint8, MultiplicationBroadcast1DVectorUint8Test)
+ARMNN_AUTO_TEST_CASE(MultiplicationInt16, MultiplicationInt16Test)
+ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1ElementInt16, MultiplicationBroadcast1ElementInt16Test)
+ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1DVectorInt16, MultiplicationBroadcast1DVectorInt16Test)
// Batch Norm
ARMNN_AUTO_TEST_CASE(BatchNorm, BatchNormTest)
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index cfa8ce7e91..95c75a576a 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -112,6 +112,22 @@ public:
}
};
+class QSymm16Decoder : public TypedIterator<const int16_t, Decoder>
+{
+public:
+ QSymm16Decoder(const int16_t* data, const float scale, const int32_t offset)
+ : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
+
+ float Get() const override
+ {
+ return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
+ }
+
+private:
+ const float m_Scale;
+ const int32_t m_Offset;
+};
+
class FloatEncoder : public TypedIterator<float, Encoder>
{
public:
@@ -152,4 +168,20 @@ public:
}
};
+class QSymm16Encoder : public TypedIterator<int16_t, Encoder>
+{
+public:
+ QSymm16Encoder(int16_t* data, const float scale, const int32_t offset)
+ : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
+
+ void Set(const float& right) override
+ {
+ *m_Iterator = armnn::Quantize<int16_t>(right, m_Scale, m_Offset);
+ }
+
+private:
+ const float m_Scale;
+ const int32_t m_Offset;
+};
+
} //namespace armnn
\ No newline at end of file
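
The decoder/encoder pair added above plugs into the same TypedIterator machinery as the existing Float and QAsymm8 iterators. A hedged sketch of driving them directly, outside the workload (assumes TypedIterator's operator++ advances the underlying pointer, as it does for the other iterators in this header):

    #include <cstdint>
    #include "BaseIterator.hpp"

    // Dequantize two QSymm16 buffers, add them elementwise, and requantize.
    void AddQSymm16(const int16_t* a, const int16_t* b, int16_t* out,
                    unsigned int count, float scale)
    {
        armnn::QSymm16Decoder decodeA(a, scale, 0); // offset is always 0 for symmetric types
        armnn::QSymm16Decoder decodeB(b, scale, 0);
        armnn::QSymm16Encoder encode(out, scale, 0);

        for (unsigned int i = 0; i < count; ++i)
        {
            encode.Set(decodeA.Get() + decodeB.Get());
            ++decodeA;
            ++decodeB;
            ++encode;
        }
    }
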
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index 6e6e1d5f21..1a30e7c9fb 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -64,6 +64,28 @@ void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute() c
encodeIterator0);
break;
}
+ case armnn::DataType::QuantisedSymm16:
+ {
+ QSymm16Decoder decodeIterator0(GetInputTensorData<int16_t>(0, m_Data),
+ inputInfo0.GetQuantizationScale(),
+ inputInfo0.GetQuantizationOffset());
+
+ QSymm16Decoder decodeIterator1(GetInputTensorData<int16_t>(1, m_Data),
+ inputInfo1.GetQuantizationScale(),
+ inputInfo1.GetQuantizationOffset());
+
+ QSymm16Encoder encodeIterator0(GetOutputTensorData<int16_t>(0, m_Data),
+ outputInfo.GetQuantizationScale(),
+ outputInfo.GetQuantizationOffset());
+
+ ElementwiseFunction<Functor, Decoder, Encoder>(inShape0,
+ inShape1,
+ outShape,
+ decodeIterator0,
+ decodeIterator1,
+ encodeIterator0);
+ break;
+ }
default:
BOOST_ASSERT_MSG(false, "RefElementwiseWorkload: Not supported Data Type!");
break;
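
The new QuantisedSymm16 case mirrors the QuantisedAsymm8 branch above it, differing only in the iterator pair and the int16_t element type. A worked example of the round trip one element takes through Execute() (made-up scale and values):

    #include <cmath>
    #include <cstdint>

    // With scale 0.5f and offset 0: input 3 decodes to 1.5f, input 5 to 2.5f.
    // The Functor (addition here) runs in float, then the result is re-encoded:
    // 1.5f + 2.5f = 4.0f, and round(4.0f / 0.5f) = 8.
    int16_t AddOneQSymm16Element()
    {
        const float scale = 0.5f;
        const float lhs = 3 * scale;  // decode input 0
        const float rhs = 5 * scale;  // decode input 1
        return static_cast<int16_t>(std::round((lhs + rhs) / scale)); // encode output
    }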