author     David Beck <david.beck@arm.com>             2018-09-06 16:46:34 +0100
committer  Matthew Bentham <matthew.bentham@arm.com>   2018-09-25 14:54:29 +0100
commit     f195f03e095a5c4dc6880be11af64cab83b5c94b (patch)
tree       d1c6d7d46ed70b915772bd50c5074d13443d9bca
parent     c2044fe9d26a8b6afca48aee04bd5d29f8e27b8d (diff)
IVGCVSW-1803 : add Ref Subtraction layer
Change-Id: I4c019d626f9369245eca6d549bbe7a28e141f198
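
The quantized (Uint8) workload follows the same pattern as the existing reference Addition and Multiplication workloads: both inputs are dequantized to float, the subtraction runs in float (with broadcasting where the shapes differ), and the result is requantized with the output tensor's scale and offset. As a rough illustration only (not part of the patch), a standalone sketch of that arithmetic using the values from the SubtractionUint8Test added below:

    // Sketch of the dequantize -> subtract -> requantize flow assumed to be
    // equivalent to RefSubtractionUint8Workload; it uses no ArmNN APIs.
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main()
    {
        // Values from SubtractionUint8Test: input0 has scale 0.5 / offset 2,
        // input1 and the output both have scale 1.0 / offset 0.
        const std::vector<uint8_t> in0 = { 10, 12, 14, 16 };
        const std::vector<uint8_t> in1 = { 1, 2, 1, 2 };
        const float scale0 = 0.5f;   const int32_t offset0 = 2;
        const float scale1 = 1.0f;   const int32_t offset1 = 0;
        const float outScale = 1.0f; const int32_t outOffset = 0;

        for (std::size_t i = 0; i < in0.size(); ++i)
        {
            // Dequantize: real = scale * (quantized - offset)
            const float a = scale0 * (static_cast<int32_t>(in0[i]) - offset0); // 4, 5, 6, 7
            const float b = scale1 * (static_cast<int32_t>(in1[i]) - offset1); // 1, 2, 1, 2
            // Requantize the float difference back to uint8
            const auto q = static_cast<uint8_t>(std::lround((a - b) / outScale) + outOffset);
            std::printf("%u ", static_cast<unsigned>(q)); // prints 3 3 5 5, the test's expected output
        }
        std::printf("\n");
        return 0;
    }
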
-rw-r--r--  Android.mk                                                         |   7
-rw-r--r--  CMakeLists.txt                                                     |  16
-rw-r--r--  src/armnn/backends/RefLayerSupport.cpp                             |   8
-rw-r--r--  src/armnn/backends/RefWorkloadFactory.cpp                          |   2
-rw-r--r--  src/armnn/backends/RefWorkloads.hpp                                |   2
-rw-r--r--  src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp  |  31
-rw-r--r--  src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.hpp  |  21
-rw-r--r--  src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp    |  41
-rw-r--r--  src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.hpp    |  21
-rw-r--r--  src/armnn/backends/RefWorkloads/Subtraction.cpp                    |  44
-rw-r--r--  src/armnn/backends/RefWorkloads/Subtraction.hpp                    |  20
-rw-r--r--  src/armnn/backends/test/LayerTests.cpp                             | 160
-rw-r--r--  src/armnn/backends/test/LayerTests.hpp                             |  10
-rw-r--r--  src/armnn/backends/test/Reference.cpp                              |   9
14 files changed, 380 insertions(+), 12 deletions(-)
diff --git a/Android.mk b/Android.mk
index a164535418..796b4d8fc0 100644
--- a/Android.mk
+++ b/Android.mk
@@ -128,11 +128,15 @@ LOCAL_SRC_FILES := \
src/armnn/backends/RefWorkloads/Multiplication.cpp \
src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp \
src/armnn/backends/RefWorkloads/RefBaseConstantWorkload.cpp \
- src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp \
src/armnn/backends/RefWorkloads/RefResizeBilinearFloat32Workload.cpp \
src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp \
src/armnn/backends/RefWorkloads/Broadcast.cpp \
src/armnn/backends/RefWorkloads/Addition.cpp \
+ src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp \
+ src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp \
+ src/armnn/backends/RefWorkloads/Subtraction.cpp \
+ src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp \
+ src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp \
src/armnn/backends/RefWorkloads/RefFakeQuantizationFloat32Workload.cpp \
src/armnn/backends/RefWorkloads/ResizeBilinear.cpp \
src/armnn/backends/RefWorkloads/RefSoftmaxUint8Workload.cpp \
@@ -158,7 +162,6 @@ LOCAL_SRC_FILES := \
src/armnn/backends/RefWorkloads/RefConstantUint8Workload.cpp \
src/armnn/backends/RefWorkloads/RefConstantFloat32Workload.cpp \
src/armnn/backends/RefWorkloads/Pooling2d.cpp \
- src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp \
src/armnn/backends/RefWorkloads/RefMergerFloat32Workload.cpp \
src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp \
src/armnn/backends/RefWorkloads/RefPermuteWorkload.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7890cdfd02..ecf30b1ab6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -186,7 +186,18 @@ list(APPEND armnn_sources
src/armnn/backends/RefWorkloads/Broadcast.cpp
src/armnn/backends/RefWorkloads/RefMergerUint8Workload.cpp
src/armnn/backends/RefWorkloads/RefConstantUint8Workload.hpp
+ src/armnn/backends/RefWorkloads/Addition.cpp
src/armnn/backends/RefWorkloads/Addition.hpp
+ src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp
+ src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.hpp
+ src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp
+ src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.hpp
+ src/armnn/backends/RefWorkloads/Subtraction.cpp
+ src/armnn/backends/RefWorkloads/Subtraction.hpp
+ src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp
+ src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.hpp
+ src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp
+ src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.hpp
src/armnn/backends/RefWorkloads/ConvImpl.hpp
src/armnn/backends/RefWorkloads/RefResizeBilinearUint8Workload.cpp
src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.hpp
@@ -207,7 +218,6 @@ list(APPEND armnn_sources
src/armnn/backends/RefWorkloads/Multiplication.hpp
src/armnn/backends/RefWorkloads/RefActivationUint8Workload.hpp
src/armnn/backends/RefWorkloads/RefBaseConstantWorkload.cpp
- src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp
src/armnn/backends/RefWorkloads/RefResizeBilinearFloat32Workload.cpp
src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp
src/armnn/backends/RefWorkloads/RefPooling2dFloat32Workload.hpp
@@ -216,7 +226,6 @@ list(APPEND armnn_sources
src/armnn/backends/RefWorkloads/RefFullyConnectedFloat32Workload.hpp
src/armnn/backends/RefWorkloads/Softmax.hpp
src/armnn/backends/RefWorkloads/RefMergerFloat32Workload.hpp
- src/armnn/backends/RefWorkloads/Addition.cpp
src/armnn/backends/RefWorkloads/RefFakeQuantizationFloat32Workload.cpp
src/armnn/backends/RefWorkloads/TensorBufferArrayView.hpp
src/armnn/backends/RefWorkloads/ResizeBilinear.cpp
@@ -237,7 +246,6 @@ list(APPEND armnn_sources
src/armnn/backends/RefWorkloads/RefReshapeUint8Workload.hpp
src/armnn/backends/RefWorkloads/Activation.cpp
src/armnn/backends/RefWorkloads/RefResizeBilinearFloat32Workload.hpp
- src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.hpp
src/armnn/backends/RefWorkloads/RefReshapeUint8Workload.cpp
src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.hpp
src/armnn/backends/RefWorkloads/RefL2NormalizationFloat32Workload.cpp
@@ -266,9 +274,7 @@ list(APPEND armnn_sources
src/armnn/backends/RefWorkloads/RefConstantUint8Workload.cpp
src/armnn/backends/RefWorkloads/RefConstantFloat32Workload.cpp
src/armnn/backends/RefWorkloads/Pooling2d.cpp
- src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp
src/armnn/backends/RefWorkloads/RefConvolution2dFloat32Workload.hpp
- src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.hpp
src/armnn/backends/RefWorkloads/RefMergerFloat32Workload.cpp
src/armnn/backends/RefWorkloads/Pooling2d.hpp
src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp
diff --git a/src/armnn/backends/RefLayerSupport.cpp b/src/armnn/backends/RefLayerSupport.cpp
index 5437574789..41f57f1677 100644
--- a/src/armnn/backends/RefLayerSupport.cpp
+++ b/src/armnn/backends/RefLayerSupport.cpp
@@ -135,8 +135,12 @@ bool IsSubtractionSupportedRef(const TensorInfo& input0,
const TensorInfo& output,
std::string* reasonIfUnsupported)
{
- // At the moment subtraction is not supported
- return false;
+ ignore_unused(input1);
+ ignore_unused(output);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input0.GetDataType(),
+ &TrueFunc<>,
+ &TrueFunc<>);
}
bool IsFullyConnectedSupportedRef(const TensorInfo& input,
diff --git a/src/armnn/backends/RefWorkloadFactory.cpp b/src/armnn/backends/RefWorkloadFactory.cpp
index 4de9274eb8..92e2506935 100644
--- a/src/armnn/backends/RefWorkloadFactory.cpp
+++ b/src/armnn/backends/RefWorkloadFactory.cpp
@@ -230,7 +230,7 @@ std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateDivision(
std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateSubtraction(
const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
- return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<RefSubtractionFloat32Workload, RefSubtractionUint8Workload>(descriptor, info);
}
} // namespace armnn
diff --git a/src/armnn/backends/RefWorkloads.hpp b/src/armnn/backends/RefWorkloads.hpp
index 98385ad5ac..910610c72e 100644
--- a/src/armnn/backends/RefWorkloads.hpp
+++ b/src/armnn/backends/RefWorkloads.hpp
@@ -57,3 +57,5 @@
#include "backends/RefWorkloads/RefConvertFp32ToFp16Workload.hpp"
#include "backends/RefWorkloads/RefDivisionFloat32Workload.hpp"
#include "backends/RefWorkloads/RefDivisionUint8Workload.hpp"
+#include "backends/RefWorkloads/RefSubtractionFloat32Workload.hpp"
+#include "backends/RefWorkloads/RefSubtractionUint8Workload.hpp"
diff --git a/src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp b/src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp
new file mode 100644
index 0000000000..4440eedab7
--- /dev/null
+++ b/src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp
@@ -0,0 +1,31 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefSubtractionFloat32Workload.hpp"
+
+#include "Subtraction.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include "Profiling.hpp"
+
+namespace armnn
+{
+
+void RefSubtractionFloat32Workload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefSubtractionFloat32Workload_Execute");
+
+ const TensorShape& inShape0 = GetTensorInfo(m_Data.m_Inputs[0]).GetShape();
+ const TensorShape& inShape1 = GetTensorInfo(m_Data.m_Inputs[1]).GetShape();
+ const TensorShape& outShape = GetTensorInfo(m_Data.m_Outputs[0]).GetShape();
+
+ const float* inData0 = GetInputTensorDataFloat(0, m_Data);
+ const float* inData1 = GetInputTensorDataFloat(1, m_Data);
+ float* outData = GetOutputTensorDataFloat(0, m_Data);
+
+ Subtraction(inShape0, inShape1, outShape, inData0, inData1, outData);
+}
+
+} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.hpp b/src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.hpp
new file mode 100644
index 0000000000..b3f5ed9474
--- /dev/null
+++ b/src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.hpp
@@ -0,0 +1,21 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "backends/Workload.hpp"
+#include "backends/WorkloadData.hpp"
+
+namespace armnn
+{
+
+class RefSubtractionFloat32Workload : public Float32Workload<SubtractionQueueDescriptor>
+{
+public:
+ using Float32Workload<SubtractionQueueDescriptor>::Float32Workload;
+ virtual void Execute() const override;
+};
+
+} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp b/src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp
new file mode 100644
index 0000000000..8066762e48
--- /dev/null
+++ b/src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp
@@ -0,0 +1,41 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefSubtractionUint8Workload.hpp"
+
+#include "Subtraction.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include "Profiling.hpp"
+
+#include <vector>
+
+namespace armnn
+{
+
+void RefSubtractionUint8Workload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefSubtractionUint8Workload_Execute");
+
+ const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ auto dequant0 = Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo0);
+ auto dequant1 = Dequantize(GetInputTensorDataU8(1, m_Data), inputInfo1);
+
+ std::vector<float> results(outputInfo.GetNumElements());
+
+ Subtraction(inputInfo0.GetShape(),
+ inputInfo1.GetShape(),
+ outputInfo.GetShape(),
+ dequant0.data(),
+ dequant1.data(),
+ results.data());
+
+ Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
+}
+
+} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.hpp b/src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.hpp
new file mode 100644
index 0000000000..582533253b
--- /dev/null
+++ b/src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.hpp
@@ -0,0 +1,21 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "backends/Workload.hpp"
+#include "backends/WorkloadData.hpp"
+
+namespace armnn
+{
+
+class RefSubtractionUint8Workload : public Uint8Workload<SubtractionQueueDescriptor>
+{
+public:
+ using Uint8Workload<SubtractionQueueDescriptor>::Uint8Workload;
+ virtual void Execute() const override;
+};
+
+} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Subtraction.cpp b/src/armnn/backends/RefWorkloads/Subtraction.cpp
new file mode 100644
index 0000000000..f25c8adb1c
--- /dev/null
+++ b/src/armnn/backends/RefWorkloads/Subtraction.cpp
@@ -0,0 +1,44 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "Subtraction.hpp"
+#include "Broadcast.hpp"
+
+#include <functional>
+
+namespace
+{
+
+void ElementwiseSubtraction(unsigned int numElements, const float* inData0, const float* inData1, float* outData)
+{
+ for (unsigned int i = 0; i < numElements; ++i)
+ {
+ outData[i] = inData0[i] - inData1[i];
+ }
+}
+
+} // namespace
+
+namespace armnn
+{
+
+void Subtraction(const TensorShape& inShape0,
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ const float* inData0,
+ const float* inData1,
+ float* outData)
+{
+ if (inShape0 == inShape1)
+ {
+ ElementwiseSubtraction(inShape0.GetNumElements(), inData0, inData1, outData);
+ }
+ else
+ {
+ BroadcastLoop(inShape0, inShape1, outShape).Unroll(std::minus<float>(), 0, inData0, inData1, outData);
+ }
+}
+
+} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Subtraction.hpp b/src/armnn/backends/RefWorkloads/Subtraction.hpp
new file mode 100644
index 0000000000..3956797185
--- /dev/null
+++ b/src/armnn/backends/RefWorkloads/Subtraction.hpp
@@ -0,0 +1,20 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/Tensor.hpp>
+
+namespace armnn
+{
+
+void Subtraction(const TensorShape& inShape0,
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ const float* inData0,
+ const float* inData1,
+ float* outData);
+
+} //namespace armnn
diff --git a/src/armnn/backends/test/LayerTests.cpp b/src/armnn/backends/test/LayerTests.cpp
index 8683f116cf..b39daf6bbf 100644
--- a/src/armnn/backends/test/LayerTests.cpp
+++ b/src/armnn/backends/test/LayerTests.cpp
@@ -1002,7 +1002,7 @@ LayerTestResult<uint8_t, 4> AdditionBroadcast1ElementUint8Test(armnn::IWorkloadF
}
LayerTestResult<float,4> CompareAdditionTest(armnn::IWorkloadFactory& workloadFactory,
- armnn::IWorkloadFactory& refWorkloadFactory)
+ armnn::IWorkloadFactory& refWorkloadFactory)
{
unsigned int batchSize = 4;
unsigned int channels = 1;
@@ -3935,6 +3935,164 @@ LayerTestResult<uint8_t, 4> MultiplicationBroadcast1DVectorUint8Test(armnn::IWor
0);
}
+namespace
+{
+template <typename T>
+LayerTestResult<T, 4> SubtractionTestHelper(armnn::IWorkloadFactory& workloadFactory,
+ const unsigned int shape0[4],
+ const std::vector<T>& values0,
+ float scale0,
+ int32_t offset0,
+ const unsigned int shape1[4],
+ const std::vector<T> & values1,
+ float scale1,
+ int32_t offset1,
+ const unsigned int outShape[4],
+ const std::vector<T> & outValues,
+ float outScale,
+ int32_t outOffset)
+{
+ auto dataType = (std::is_same<T, uint8_t>::value ?
+ armnn::DataType::QuantisedAsymm8 :
+ armnn::DataType::Float32);
+
+ armnn::TensorInfo inputTensorInfo0(4, shape0, dataType);
+ armnn::TensorInfo inputTensorInfo1(4, shape1, dataType);
+ armnn::TensorInfo outputTensorInfo(4, outShape, dataType);
+
+ inputTensorInfo0.SetQuantizationScale(scale0);
+ inputTensorInfo0.SetQuantizationOffset(offset0);
+
+ inputTensorInfo1.SetQuantizationScale(scale1);
+ inputTensorInfo1.SetQuantizationOffset(offset1);
+
+ outputTensorInfo.SetQuantizationScale(outScale);
+ outputTensorInfo.SetQuantizationOffset(outOffset);
+
+ auto input0 = MakeTensor<T, 4>(inputTensorInfo0, values0);
+ auto input1 = MakeTensor<T, 4>(inputTensorInfo1, values1);
+
+ LayerTestResult<T, 4> result(outputTensorInfo);
+ result.outputExpected = MakeTensor<T, 4>(outputTensorInfo, outValues);
+
+ std::unique_ptr<armnn::ITensorHandle> inputHandle0 = workloadFactory.CreateTensorHandle(inputTensorInfo0);
+ std::unique_ptr<armnn::ITensorHandle> inputHandle1 = workloadFactory.CreateTensorHandle(inputTensorInfo1);
+ std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
+
+ armnn::SubtractionQueueDescriptor data;
+ armnn::WorkloadInfo info;
+ AddInputToWorkload(data, info, inputTensorInfo0, inputHandle0.get());
+ AddInputToWorkload(data, info, inputTensorInfo1, inputHandle1.get());
+ AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
+
+ std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateSubtraction(data, info);
+
+ inputHandle0->Allocate();
+ inputHandle1->Allocate();
+ outputHandle->Allocate();
+
+ CopyDataToITensorHandle(inputHandle0.get(), &input0[0][0][0][0]);
+ CopyDataToITensorHandle(inputHandle1.get(), &input1[0][0][0][0]);
+
+ workloadFactory.Finalize();
+ workload->Execute();
+
+ CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
+
+ return result;
+}
+} // anonymous namespace
+
+LayerTestResult<uint8_t, 4> SubtractionUint8Test(armnn::IWorkloadFactory& workloadFactory)
+{
+ const unsigned int shape0[] = { 1, 1, 2, 2 };
+ const unsigned int shape1[] = { 1, 1, 2, 2 };
+
+ std::vector<uint8_t> input0({ 10, 12, 14, 16 });
+ std::vector<uint8_t> input1({ 1, 2, 1, 2 });
+ std::vector<uint8_t> output({ 3, 3, 5, 5 });
+
+ return SubtractionTestHelper(workloadFactory,
+ shape0, input0, 0.5f, 2,
+ shape1, input1, 1.0f, 0,
+ shape0, output, 1.0f, 0);
+}
+
+LayerTestResult<uint8_t, 4> SubtractionBroadcast1ElementUint8Test(armnn::IWorkloadFactory& workloadFactory)
+{
+ const unsigned int shape0[] = { 1, 1, 2, 2 };
+ const unsigned int shape1[] = { 1, 1, 1, 1 };
+
+ std::vector<uint8_t> input0({ 10, 12, 14, 16 });
+ std::vector<uint8_t> input1({ 2 });
+ std::vector<uint8_t> output({ 5, 6, 7, 8 });
+
+ return SubtractionTestHelper(workloadFactory,
+ shape0, input0, 0.5f, 2,
+ shape1, input1, 1.0f, 0,
+ shape0, output, 1.0f, 3);
+}
+
+LayerTestResult<uint8_t, 4> SubtractionBroadcastUint8Test(armnn::IWorkloadFactory& workloadFactory)
+{
+ const unsigned int shape0[] = { 1, 1, 2, 2 };
+ const unsigned int shape1[] = { 1, 1, 1, 2 };
+
+ std::vector<uint8_t> input0({ 10, 12, 14, 16 });
+ std::vector<uint8_t> input1({ 2, 1 });
+ std::vector<uint8_t> output({ 8, 11, 12, 15 });
+
+ return SubtractionTestHelper(workloadFactory,
+ shape0, input0, 1.0f, 0,
+ shape1, input1, 1.0f, 0,
+ shape0, output, 1.0f, 0);
+}
+
+LayerTestResult<float, 4> SubtractionTest(armnn::IWorkloadFactory& workloadFactory)
+{
+ const unsigned int shape0[] = { 1, 1, 2, 2 };
+ const unsigned int shape1[] = { 1, 1, 2, 2 };
+
+ std::vector<float> input0({ 1, 2, 3, 4 });
+ std::vector<float> input1({ 1, -1, 0, 2 });
+ std::vector<float> output({ 0, 3, 3, 2 });
+
+ return SubtractionTestHelper(workloadFactory,
+ shape0, input0, 1.0f, 0,
+ shape1, input1, 1.0f, 0,
+ shape0, output, 1.0f, 0);
+}
+
+LayerTestResult<float, 4> SubtractionBroadcast1ElementTest(armnn::IWorkloadFactory& workloadFactory)
+{
+ const unsigned int shape0[] = { 1, 1, 2, 2 };
+ const unsigned int shape1[] = { 1, 1, 1, 1 };
+
+ std::vector<float> input0({ 1, 2, 3, 4 });
+ std::vector<float> input1({ 10 });
+ std::vector<float> output({ -9, -8, -7, -6 });
+
+ return SubtractionTestHelper(workloadFactory,
+ shape0, input0, 1.0f, 0,
+ shape1, input1, 1.0f, 0,
+ shape0, output, 1.0f, 0);
+}
+
+LayerTestResult<float, 4> SubtractionBroadcastTest(armnn::IWorkloadFactory& workloadFactory)
+{
+ const unsigned int shape0[] = { 1, 1, 2, 2 };
+ const unsigned int shape1[] = { 1, 1, 1, 2 };
+
+ std::vector<float> input0({ 1, 2, 3, 4 });
+ std::vector<float> input1({ 10, -5 });
+ std::vector<float> output({ -9, 7, -7, 9 });
+
+ return SubtractionTestHelper(workloadFactory,
+ shape0, input0, 1.0f, 0,
+ shape1, input1, 1.0f, 0,
+ shape0, output, 1.0f, 0);
+}
+
LayerTestResult<uint8_t, 4> ResizeBilinearNopUint8Test(armnn::IWorkloadFactory& workloadFactory)
{
constexpr unsigned int inputWidth = 4;
diff --git a/src/armnn/backends/test/LayerTests.hpp b/src/armnn/backends/test/LayerTests.hpp
index 06d789e783..5ca4c49c88 100644
--- a/src/armnn/backends/test/LayerTests.hpp
+++ b/src/armnn/backends/test/LayerTests.hpp
@@ -185,7 +185,11 @@ LayerTestResult<float, 4> AdditionBroadcast1ElementTest(armnn::IWorkloadFactory&
LayerTestResult<float, 4> AdditionBroadcastTest(armnn::IWorkloadFactory& workloadFactory);
LayerTestResult<float, 4> CompareAdditionTest(armnn::IWorkloadFactory& workloadFactory,
- armnn::IWorkloadFactory& refWorkloadFactory);
+ armnn::IWorkloadFactory& refWorkloadFactory);
+
+LayerTestResult<float, 4> SubtractionTest(armnn::IWorkloadFactory& workloadFactory);
+LayerTestResult<float, 4> SubtractionBroadcast1ElementTest(armnn::IWorkloadFactory& workloadFactory);
+LayerTestResult<float, 4> SubtractionBroadcastTest(armnn::IWorkloadFactory& workloadFactory);
LayerTestResult<float, 4> CompareActivationTest(armnn::IWorkloadFactory& workloadFactory,
armnn::IWorkloadFactory& refWorkloadFactory,
@@ -264,6 +268,10 @@ LayerTestResult<uint8_t, 4> AdditionUint8Test(armnn::IWorkloadFactory& workloadF
LayerTestResult<uint8_t, 4> AdditionBroadcast1ElementUint8Test(armnn::IWorkloadFactory& workloadFactory);
LayerTestResult<uint8_t, 4> AdditionBroadcastUint8Test(armnn::IWorkloadFactory& workloadFactory);
+LayerTestResult<uint8_t, 4> SubtractionUint8Test(armnn::IWorkloadFactory& workloadFactory);
+LayerTestResult<uint8_t, 4> SubtractionBroadcast1ElementUint8Test(armnn::IWorkloadFactory& workloadFactory);
+LayerTestResult<uint8_t, 4> SubtractionBroadcastUint8Test(armnn::IWorkloadFactory& workloadFactory);
+
LayerTestResult<uint8_t, 4> CompareActivationUint8Test(armnn::IWorkloadFactory& workloadFactory,
armnn::IWorkloadFactory& refWorkloadFactory,
armnn::ActivationFunction f);
diff --git a/src/armnn/backends/test/Reference.cpp b/src/armnn/backends/test/Reference.cpp
index 5b17bf3012..5a5f79d965 100644
--- a/src/armnn/backends/test/Reference.cpp
+++ b/src/armnn/backends/test/Reference.cpp
@@ -146,6 +146,15 @@ ARMNN_AUTO_TEST_CASE(AdditionUint8, AdditionUint8Test)
ARMNN_AUTO_TEST_CASE(AddBroadcastUint8, AdditionBroadcastUint8Test)
ARMNN_AUTO_TEST_CASE(AddBroadcast1ElementUint8, AdditionBroadcast1ElementUint8Test)
+// Sub
+ARMNN_AUTO_TEST_CASE(SimpleSub, SubtractionTest)
+ARMNN_AUTO_TEST_CASE(SubBroadcast1Element, SubtractionBroadcast1ElementTest)
+ARMNN_AUTO_TEST_CASE(SubBroadcast, SubtractionBroadcastTest)
+
+ARMNN_AUTO_TEST_CASE(SubtractionUint8, SubtractionUint8Test)
+ARMNN_AUTO_TEST_CASE(SubBroadcastUint8, SubtractionBroadcastUint8Test)
+ARMNN_AUTO_TEST_CASE(SubBroadcast1ElementUint8, SubtractionBroadcast1ElementUint8Test)
+
// Div
ARMNN_AUTO_TEST_CASE(SimpleDivision, DivisionTest)
ARMNN_AUTO_TEST_CASE(DivisionByZero, DivisionByZeroTest)
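
For completeness, a similar standalone sanity check (again, not part of the patch) of the broadcast path that Subtraction.cpp implements via BroadcastLoop and std::minus, reproducing the expected values of the float SubtractionBroadcastTest above, where the {1,1,1,2} second input is reused across both rows of the {1,1,2,2} first input:

    #include <cstdio>
    #include <vector>

    int main()
    {
        // Values from SubtractionBroadcastTest: input0 is 2x2, input1 is 1x2.
        const std::vector<float> in0 = { 1.0f, 2.0f, 3.0f, 4.0f }; // [[1, 2], [3, 4]]
        const std::vector<float> in1 = { 10.0f, -5.0f };           // broadcast along the last dimension
        const unsigned int rows = 2;
        const unsigned int cols = 2;

        for (unsigned int r = 0; r < rows; ++r)
        {
            for (unsigned int c = 0; c < cols; ++c)
            {
                // A dimension of size 1 in the second input keeps index 0,
                // so in1 varies only with the column index here.
                const float out = in0[r * cols + c] - in1[c];
                std::printf("%g ", out); // prints -9 7 -7 9, the test's expected output
            }
        }
        std::printf("\n");
        return 0;
    }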