aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Beck <david.beck@arm.com>2018-09-07 16:19:24 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-09-25 14:54:29 +0100
commit4a8692cf18ebd3c4de125274d5c840d7be64e3cd (patch)
treeb504b5f42a83a89c7a40bc3dea13f230c847cc0e
parenta6bf9121e7c26561ca7cb950020db6cb665596a2 (diff)
downloadarmnn-4a8692cf18ebd3c4de125274d5c840d7be64e3cd.tar.gz
IVGCVSW-1801 : Cl implementation for SUB
Change-Id: Ia2e1dda8653197454a50679d49020397f5327979
-rw-r--r--Android.mk3
-rw-r--r--CMakeLists.txt14
-rw-r--r--src/armnn/backends/ClLayerSupport.cpp7
-rw-r--r--src/armnn/backends/ClWorkloadFactory.cpp2
-rw-r--r--src/armnn/backends/ClWorkloads.hpp2
-rw-r--r--src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.cpp64
-rw-r--r--src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.hpp29
-rw-r--r--src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.cpp22
-rw-r--r--src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.hpp20
-rw-r--r--src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.cpp18
-rw-r--r--src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.hpp20
-rw-r--r--src/armnn/backends/test/ArmComputeCl.cpp19
-rw-r--r--src/armnn/backends/test/CreateWorkloadCl.cpp101
-rw-r--r--src/armnn/backends/test/Reference.cpp2
-rw-r--r--src/armnn/test/CreateWorkload.hpp33
15 files changed, 311 insertions, 45 deletions
diff --git a/Android.mk b/Android.mk
index 796b4d8fc0..9c2373678d 100644
--- a/Android.mk
+++ b/Android.mk
@@ -48,6 +48,9 @@ LOCAL_SRC_FILES := \
src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.cpp \
src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.cpp \
src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.cpp \
+ src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.cpp \
+ src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.cpp \
+ src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.cpp \
src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.cpp \
src/armnn/backends/ClWorkloads/ClBatchNormalizationFloatWorkload.cpp \
src/armnn/backends/ClWorkloads/ClConstantFloatWorkload.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ecf30b1ab6..777c3153e6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -491,14 +491,20 @@ if(ARMCOMPUTECL)
src/armnn/backends/ClWorkloads/ClActivationUint8Workload.hpp
src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.cpp
src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.hpp
- src/armnn/backends/ClWorkloads/ClConvertFp16ToFp32Workload.cpp
- src/armnn/backends/ClWorkloads/ClConvertFp16ToFp32Workload.hpp
- src/armnn/backends/ClWorkloads/ClConvertFp32ToFp16Workload.cpp
- src/armnn/backends/ClWorkloads/ClConvertFp32ToFp16Workload.hpp
src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.cpp
src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.hpp
src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.cpp
src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.hpp
+ src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.cpp
+ src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.hpp
+ src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.cpp
+ src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.hpp
+ src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.cpp
+ src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.hpp
+ src/armnn/backends/ClWorkloads/ClConvertFp16ToFp32Workload.cpp
+ src/armnn/backends/ClWorkloads/ClConvertFp16ToFp32Workload.hpp
+ src/armnn/backends/ClWorkloads/ClConvertFp32ToFp16Workload.cpp
+ src/armnn/backends/ClWorkloads/ClConvertFp32ToFp16Workload.hpp
src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.cpp
src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.hpp
src/armnn/backends/ClWorkloads/ClBaseMergerWorkload.hpp
diff --git a/src/armnn/backends/ClLayerSupport.cpp b/src/armnn/backends/ClLayerSupport.cpp
index 7b5fee2175..3dba1ec94c 100644
--- a/src/armnn/backends/ClLayerSupport.cpp
+++ b/src/armnn/backends/ClLayerSupport.cpp
@@ -29,6 +29,7 @@
#include "ClWorkloads/ClPermuteWorkload.hpp"
#include "ClWorkloads/ClNormalizationFloatWorkload.hpp"
#include "ClWorkloads/ClSoftmaxBaseWorkload.hpp"
+#include "ClWorkloads/ClSubtractionFloatWorkload.hpp"
#include "ClWorkloads/ClLstmFloatWorkload.hpp"
#endif
@@ -255,8 +256,10 @@ bool IsSubtractionSupportedCl(const TensorInfo& input0,
const TensorInfo& output,
std::string* reasonIfUnsupported)
{
- // At the moment subtraction is not supported
- return false;
+ return FORWARD_CL_LAYER_SUPPORT_FUNC(ClSubtractionValidate(input0,
+ input1,
+ output,
+ reasonIfUnsupported));
}
bool IsFullyConnectedSupportedCl(const TensorInfo& input,
diff --git a/src/armnn/backends/ClWorkloadFactory.cpp b/src/armnn/backends/ClWorkloadFactory.cpp
index 75a2af8b5a..056a201783 100644
--- a/src/armnn/backends/ClWorkloadFactory.cpp
+++ b/src/armnn/backends/ClWorkloadFactory.cpp
@@ -172,7 +172,7 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateDivision(
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<ClSubtractionFloatWorkload, ClSubtractionUint8Workload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateBatchNormalization(
diff --git a/src/armnn/backends/ClWorkloads.hpp b/src/armnn/backends/ClWorkloads.hpp
index e524a70c69..0800401a22 100644
--- a/src/armnn/backends/ClWorkloads.hpp
+++ b/src/armnn/backends/ClWorkloads.hpp
@@ -36,5 +36,7 @@
#include "backends/ClWorkloads/ClSoftmaxUint8Workload.hpp"
#include "backends/ClWorkloads/ClSplitterFloatWorkload.hpp"
#include "backends/ClWorkloads/ClSplitterUint8Workload.hpp"
+#include "backends/ClWorkloads/ClSubtractionFloatWorkload.hpp"
+#include "backends/ClWorkloads/ClSubtractionUint8Workload.hpp"
#include "backends/ClWorkloads/ClConvertFp16ToFp32Workload.hpp"
#include "backends/ClWorkloads/ClConvertFp32ToFp16Workload.hpp"
diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.cpp b/src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.cpp
new file mode 100644
index 0000000000..2145ed4a2a
--- /dev/null
+++ b/src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.cpp
@@ -0,0 +1,64 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ClSubtractionBaseWorkload.hpp"
+
+#include "backends/ClTensorHandle.hpp"
+#include "backends/CpuTensorHandle.hpp"
+#include "backends/ArmComputeTensorUtils.hpp"
+
+namespace armnn
+{
+using namespace armcomputetensorutils;
+
+static constexpr arm_compute::ConvertPolicy g_AclConvertPolicy = arm_compute::ConvertPolicy::SATURATE;
+
+template <armnn::DataType... T>
+ClSubtractionBaseWorkload<T...>::ClSubtractionBaseWorkload(const SubtractionQueueDescriptor& descriptor,
+ const WorkloadInfo& info)
+ : TypedWorkload<SubtractionQueueDescriptor, T...>(descriptor, info)
+{
+ this->m_Data.ValidateInputsOutputs("ClSubtractionBaseWorkload", 2, 1);
+
+ arm_compute::ICLTensor& input0 = static_cast<IClTensorHandle*>(this->m_Data.m_Inputs[0])->GetTensor();
+ arm_compute::ICLTensor& input1 = static_cast<IClTensorHandle*>(this->m_Data.m_Inputs[1])->GetTensor();
+ arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(this->m_Data.m_Outputs[0])->GetTensor();
+ m_Layer.configure(&input0, &input1, &output, g_AclConvertPolicy);
+}
+
+template <armnn::DataType... T>
+void ClSubtractionBaseWorkload<T...>::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT_CL("ClSubtractionBaseWorkload_Execute");
+ m_Layer.run();
+}
+
+bool ClSubtractionValidate(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ std::string* reasonIfUnsupported)
+{
+ const arm_compute::TensorInfo aclInput0Info = BuildArmComputeTensorInfo(input0);
+ const arm_compute::TensorInfo aclInput1Info = BuildArmComputeTensorInfo(input1);
+ const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
+
+ const arm_compute::Status aclStatus = arm_compute::CLArithmeticSubtraction::validate(&aclInput0Info,
+ &aclInput1Info,
+ &aclOutputInfo,
+ g_AclConvertPolicy);
+
+ const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
+ if (!supported && reasonIfUnsupported)
+ {
+ *reasonIfUnsupported = aclStatus.error_description();
+ }
+
+ return supported;
+}
+
+} //namespace armnn
+
+template class armnn::ClSubtractionBaseWorkload<armnn::DataType::Float16, armnn::DataType::Float32>;
+template class armnn::ClSubtractionBaseWorkload<armnn::DataType::QuantisedAsymm8>;
diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.hpp b/src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.hpp
new file mode 100644
index 0000000000..e4595d405a
--- /dev/null
+++ b/src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.hpp
@@ -0,0 +1,29 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "backends/ClWorkloadUtils.hpp"
+
+namespace armnn
+{
+
+template <armnn::DataType... dataTypes>
+class ClSubtractionBaseWorkload : public TypedWorkload<SubtractionQueueDescriptor, dataTypes...>
+{
+public:
+ ClSubtractionBaseWorkload(const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info);
+
+ void Execute() const override;
+
+private:
+ mutable arm_compute::CLArithmeticSubtraction m_Layer;
+};
+
+bool ClSubtractionValidate(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ std::string* reasonIfUnsupported);
+} //namespace armnn
diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.cpp b/src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.cpp
new file mode 100644
index 0000000000..3321e20100
--- /dev/null
+++ b/src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.cpp
@@ -0,0 +1,22 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ClSubtractionFloatWorkload.hpp"
+
+#include "backends/ClTensorHandle.hpp"
+#include "backends/CpuTensorHandle.hpp"
+#include "backends/ArmComputeTensorUtils.hpp"
+
+namespace armnn
+{
+using namespace armcomputetensorutils;
+
+void ClSubtractionFloatWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT_CL("ClSubtractionFloatWorkload_Execute");
+ ClSubtractionBaseWorkload::Execute();
+}
+
+} //namespace armnn
diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.hpp b/src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.hpp
new file mode 100644
index 0000000000..34a5e40983
--- /dev/null
+++ b/src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.hpp
@@ -0,0 +1,20 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "ClSubtractionBaseWorkload.hpp"
+
+namespace armnn
+{
+
+class ClSubtractionFloatWorkload : public ClSubtractionBaseWorkload<DataType::Float16, DataType::Float32>
+{
+public:
+ using ClSubtractionBaseWorkload<DataType::Float16, DataType::Float32>::ClSubtractionBaseWorkload;
+ void Execute() const override;
+};
+
+} //namespace armnn
diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.cpp b/src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.cpp
new file mode 100644
index 0000000000..966068d648
--- /dev/null
+++ b/src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.cpp
@@ -0,0 +1,18 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ClSubtractionUint8Workload.hpp"
+
+namespace armnn
+{
+using namespace armcomputetensorutils;
+
+void ClSubtractionUint8Workload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT_CL("ClSubtractionUint8Workload_Execute");
+ ClSubtractionBaseWorkload::Execute();
+}
+
+} //namespace armnn
diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.hpp b/src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.hpp
new file mode 100644
index 0000000000..15b2059615
--- /dev/null
+++ b/src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.hpp
@@ -0,0 +1,20 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "ClSubtractionBaseWorkload.hpp"
+
+namespace armnn
+{
+
+class ClSubtractionUint8Workload : public ClSubtractionBaseWorkload<DataType::QuantisedAsymm8>
+{
+public:
+ using ClSubtractionBaseWorkload<DataType::QuantisedAsymm8>::ClSubtractionBaseWorkload;
+ void Execute() const override;
+};
+
+} //namespace armnn
diff --git a/src/armnn/backends/test/ArmComputeCl.cpp b/src/armnn/backends/test/ArmComputeCl.cpp
index 275d570a6c..3303c3fb51 100644
--- a/src/armnn/backends/test/ArmComputeCl.cpp
+++ b/src/armnn/backends/test/ArmComputeCl.cpp
@@ -139,6 +139,25 @@ ARMNN_AUTO_TEST_CASE(UNSUPPORTED_L2Pooling2dSize9Uint8, L2Pooling2dSize9Uint8Tes
// Add
ARMNN_AUTO_TEST_CASE(SimpleAdd, AdditionTest)
ARMNN_AUTO_TEST_CASE(AddBroadcast1Element, AdditionBroadcast1ElementTest)
+ARMNN_AUTO_TEST_CASE(AddBroadcast, AdditionBroadcastTest)
+
+ARMNN_AUTO_TEST_CASE(AdditionUint8, AdditionUint8Test)
+ARMNN_AUTO_TEST_CASE(AddBroadcastUint8, AdditionBroadcastUint8Test)
+ARMNN_AUTO_TEST_CASE(AddBroadcast1ElementUint8, AdditionBroadcast1ElementUint8Test)
+
+// Sub
+ARMNN_AUTO_TEST_CASE(SimpleSub, SubtractionTest)
+
+// TODO :
+// 1, enable broadcast tests for SUB when COMPMID-1566 is implemented (IVGCVSW-1837)
+// 2, enable quantized tests for SUB when COMPMID-1564 is implemented (IVGCVSW-1836)
+
+// ARMNN_AUTO_TEST_CASE(SubBroadcast1Element, SubtractionBroadcast1ElementTest)
+// ARMNN_AUTO_TEST_CASE(SubBroadcast, SubtractionBroadcastTest)
+
+// ARMNN_AUTO_TEST_CASE(SubtractionUint8, SubtractionUint8Test)
+// ARMNN_AUTO_TEST_CASE(SubBroadcastUint8, SubtractionBroadcastUint8Test)
+// ARMNN_AUTO_TEST_CASE(SubBroadcast1ElementUint8, SubtractionBroadcast1ElementUint8Test)
// Div
ARMNN_AUTO_TEST_CASE(SimpleDivision, DivisionTest)
diff --git a/src/armnn/backends/test/CreateWorkloadCl.cpp b/src/armnn/backends/test/CreateWorkloadCl.cpp
index 96001a4b78..340279e619 100644
--- a/src/armnn/backends/test/CreateWorkloadCl.cpp
+++ b/src/armnn/backends/test/CreateWorkloadCl.cpp
@@ -47,15 +47,18 @@ BOOST_AUTO_TEST_CASE(CreateActivationFloat16Workload)
ClCreateActivationWorkloadTest<ClActivationFloatWorkload, armnn::DataType::Float16>();
}
-template <typename AdditionWorkloadType, armnn::DataType DataType>
-static void ClCreateAdditionWorkloadTest()
+template <typename WorkloadType,
+ typename DescriptorType,
+ typename LayerType,
+ armnn::DataType DataType>
+static void ClCreateArithmeticWorkloadTest()
{
Graph graph;
ClWorkloadFactory factory;
- auto workload = CreateAdditionWorkloadTest<AdditionWorkloadType, DataType>(factory, graph);
+ auto workload = CreateArithmeticWorkloadTest<WorkloadType, DescriptorType, LayerType, DataType>(factory, graph);
- // Checks that inputs/outputs are as we expect them (see definition of CreateAdditionWorkloadTest).
- AdditionQueueDescriptor queueDescriptor = workload->GetData();
+ // Checks that inputs/outputs are as we expect them (see definition of CreateArithmeticWorkloadTest).
+ DescriptorType queueDescriptor = workload->GetData();
auto inputHandle1 = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[0]);
auto inputHandle2 = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[1]);
auto outputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Outputs[0]);
@@ -66,12 +69,66 @@ static void ClCreateAdditionWorkloadTest()
BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload)
{
- ClCreateAdditionWorkloadTest<ClAdditionFloatWorkload, armnn::DataType::Float32>();
+ ClCreateArithmeticWorkloadTest<ClAdditionFloatWorkload,
+ AdditionQueueDescriptor,
+ AdditionLayer,
+ armnn::DataType::Float32>();
}
BOOST_AUTO_TEST_CASE(CreateAdditionFloat16Workload)
{
- ClCreateAdditionWorkloadTest<ClAdditionFloatWorkload, armnn::DataType::Float16>();
+ ClCreateArithmeticWorkloadTest<ClAdditionFloatWorkload,
+ AdditionQueueDescriptor,
+ AdditionLayer,
+ armnn::DataType::Float16>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload)
+{
+ ClCreateArithmeticWorkloadTest<ClSubtractionFloatWorkload,
+ SubtractionQueueDescriptor,
+ SubtractionLayer,
+ armnn::DataType::Float32>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateSubtractionFloat16Workload)
+{
+ ClCreateArithmeticWorkloadTest<ClSubtractionFloatWorkload,
+ SubtractionQueueDescriptor,
+ SubtractionLayer,
+ armnn::DataType::Float16>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkloadTest)
+{
+ ClCreateArithmeticWorkloadTest<ClMultiplicationFloatWorkload,
+ MultiplicationQueueDescriptor,
+ MultiplicationLayer,
+ armnn::DataType::Float32>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateMultiplicationFloat16WorkloadTest)
+{
+ ClCreateArithmeticWorkloadTest<ClMultiplicationFloatWorkload,
+ MultiplicationQueueDescriptor,
+ MultiplicationLayer,
+ armnn::DataType::Float16>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkloadTest)
+{
+ ClCreateArithmeticWorkloadTest<ClDivisionFloatWorkload,
+ DivisionQueueDescriptor,
+ DivisionLayer,
+ armnn::DataType::Float32>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateDivisionFloat16WorkloadTest)
+{
+ ClCreateArithmeticWorkloadTest<ClDivisionFloatWorkload,
+ DivisionQueueDescriptor,
+ DivisionLayer,
+ armnn::DataType::Float16>();
}
template <typename BatchNormalizationWorkloadType, armnn::DataType DataType>
@@ -219,36 +276,6 @@ BOOST_AUTO_TEST_CASE(CreateFullyConnectedFloat16WorkloadTest)
ClCreateFullyConnectedWorkloadTest<ClFullyConnectedFloatWorkload, armnn::DataType::Float16>();
}
-
-template <typename MultiplicationWorkloadType, typename armnn::DataType DataType>
-static void ClCreateMultiplicationWorkloadTest()
-{
- Graph graph;
- ClWorkloadFactory factory;
-
- auto workload =
- CreateMultiplicationWorkloadTest<MultiplicationWorkloadType, DataType>(factory, graph);
-
- // Checks that inputs/outputs are as we expect them (see definition of CreateMultiplicationWorkloadTest).
- MultiplicationQueueDescriptor queueDescriptor = workload->GetData();
- auto inputHandle1 = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[0]);
- auto inputHandle2 = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[1]);
- auto outputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Outputs[0]);
- BOOST_TEST(CompareIClTensorHandleShape(inputHandle1, {2, 3}));
- BOOST_TEST(CompareIClTensorHandleShape(inputHandle2, {2, 3}));
- BOOST_TEST(CompareIClTensorHandleShape(outputHandle, {2, 3}));
-}
-
-BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkloadTest)
-{
- ClCreateMultiplicationWorkloadTest<ClMultiplicationFloatWorkload, armnn::DataType::Float32>();
-}
-
-BOOST_AUTO_TEST_CASE(CreateMultiplicationFloat16WorkloadTest)
-{
- ClCreateMultiplicationWorkloadTest<ClMultiplicationFloatWorkload, armnn::DataType::Float16>();
-}
-
template <typename NormalizationWorkloadType, typename armnn::DataType DataType>
static void ClNormalizationWorkloadTest()
{
diff --git a/src/armnn/backends/test/Reference.cpp b/src/armnn/backends/test/Reference.cpp
index 5a5f79d965..20e68d0ea1 100644
--- a/src/armnn/backends/test/Reference.cpp
+++ b/src/armnn/backends/test/Reference.cpp
@@ -151,7 +151,7 @@ ARMNN_AUTO_TEST_CASE(SimpleSub, SubtractionTest)
ARMNN_AUTO_TEST_CASE(SubBroadcast1Element, SubtractionBroadcast1ElementTest)
ARMNN_AUTO_TEST_CASE(SubBroadcast, SubtractionBroadcastTest)
-ARMNN_AUTO_TEST_CASE(SubitionUint8, SubtractionUint8Test)
+ARMNN_AUTO_TEST_CASE(SubtractionUint8, SubtractionUint8Test)
ARMNN_AUTO_TEST_CASE(SubBroadcastUint8, SubtractionBroadcastUint8Test)
ARMNN_AUTO_TEST_CASE(SubBroadcast1ElementUint8, SubtractionBroadcast1ElementUint8Test)
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index 6d975d4011..fb562e2ad0 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -126,6 +126,39 @@ std::unique_ptr<AdditionWorkload> CreateAdditionWorkloadTest(armnn::IWorkloadFac
return workload;
}
+template <typename WorkloadType,
+ typename DescriptorType,
+ typename LayerType,
+ armnn::DataType DataType>
+std::unique_ptr<WorkloadType> CreateArithmeticWorkloadTest(armnn::IWorkloadFactory& factory,
+ armnn::Graph& graph)
+{
+ // Creates the layer we're testing.
+ Layer* const layer = graph.AddLayer<LayerType>("layer");
+
+ // Creates extra layers.
+ Layer* const input1 = graph.AddLayer<InputLayer>(1, "input1");
+ Layer* const input2 = graph.AddLayer<InputLayer>(2, "input2");
+ Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
+
+ // Connects up.
+ armnn::TensorInfo tensorInfo({2, 3}, DataType);
+ Connect(input1, layer, tensorInfo, 0, 0);
+ Connect(input2, layer, tensorInfo, 0, 1);
+ Connect(layer, output, tensorInfo);
+ CreateTensorHandles(graph, factory);
+
+ // Makes the workload and checks it.
+ auto workload = MakeAndCheckWorkload<WorkloadType>(*layer, graph, factory);
+
+ DescriptorType queueDescriptor = workload->GetData();
+ BOOST_TEST(queueDescriptor.m_Inputs.size() == 2);
+ BOOST_TEST(queueDescriptor.m_Outputs.size() == 1);
+
+ // Returns so we can do extra, backend-specific tests.
+ return workload;
+}
+
template <typename BatchNormalizationFloat32Workload, armnn::DataType DataType>
std::unique_ptr<BatchNormalizationFloat32Workload> CreateBatchNormalizationWorkloadTest(
armnn::IWorkloadFactory& factory, armnn::Graph& graph)