aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2023-03-08 13:47:17 +0000
committerFrancis Murtagh <francis.murtagh@arm.com>2023-03-14 16:40:09 +0000
commit3ec3077b4eaedcc0c20ab5774bdbe365da541445 (patch)
treed601d2000897dec8691bf64cbddc9036f26b8034 /src/backends/reference/workloads
parenta088cd00b3cce672d26cdcb4965fc2a86b48f339 (diff)
downloadarmnn-3ec3077b4eaedcc0c20ab5774bdbe365da541445.tar.gz
IVGCVSW-3808 Add ElementwiseBinaryLayer
!android-nn-driver:9329 * Added ElementwiseBinaryLayer that can represent all ElementwiseBinary operations including Add, Div, Sub, Maximum, Mul and Minimum. * Updated Delegate to use ElementwiseBinaryLayer instead of the Add, Div, Sub, Maximum, Mul and Minimum layers. * Updated Deserializer to use ElementwiseBinaryLayer instead of the Add, Div, Sub, Maximum, Mul and Minimum layers. * Updated OnnxParser to use ElementwiseBinaryLayer instead of the Add layer. * Updated TfLiteParser to use ElementwiseBinaryLayer instead of the Add, Div, Sub, Maximum, Mul and Minimum layers. * Updated CL and Neon tests to use ElementwiseBinaryLayer. * Updated CL and Neon Backend Specific Optimizations to accept ElementwiseBinaryLayers as well as Add, Div, Mul, Sub, Maximum and Minimum layers. Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I7cbb96b60eb01f0e2b57b0541016d48a08b86c75
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt4
-rw-r--r--src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp120
-rw-r--r--src/backends/reference/workloads/RefElementwiseBinaryWorkload.hpp29
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp3
4 files changed, 154 insertions, 2 deletions
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index de6c042959..3592f2293d 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -1,5 +1,5 @@
#
-# Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
#
@@ -108,6 +108,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefDequantizeWorkload.hpp
RefDetectionPostProcessWorkload.cpp
RefDetectionPostProcessWorkload.hpp
+ RefElementwiseBinaryWorkload.cpp
+ RefElementwiseBinaryWorkload.hpp
RefElementwiseUnaryWorkload.cpp
RefElementwiseUnaryWorkload.hpp
RefFakeQuantizationFloat32Workload.cpp
diff --git a/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp
new file mode 100644
index 0000000000..5dc77f8496
--- /dev/null
+++ b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp
@@ -0,0 +1,120 @@
+//
+// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefElementwiseBinaryWorkload.hpp"
+
+#include "Decoders.hpp"
+#include "ElementwiseFunction.hpp"
+#include "Encoders.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "Maximum.hpp"
+#include "Minimum.hpp"
+
+#include <Profiling.hpp>
+
+#include <armnn/TypesUtils.hpp>
+
+#include <functional>
+
+namespace armnn
+{
+
+// Decodes both input tensors, applies the requested binary operation
+// element-wise, and encodes the result into the output tensor. DataType
+// selects the element type the decoders/encoders produce (int32_t or float
+// at the call sites below). All three shapes are forwarded to
+// ElementwiseBinaryFunction, presumably so it can reconcile differing input
+// shapes (broadcast) -- confirm against ElementwiseFunction.hpp.
+template<typename DataType>
+void ExecuteFunction(std::vector<ITensorHandle*> inputs,
+                     std::vector<ITensorHandle*> outputs,
+                     BinaryOperation operation)
+{
+    const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]);
+    const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]);
+    const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
+
+    const TensorShape& inShape0 = inputInfo0.GetShape();
+    const TensorShape& inShape1 = inputInfo1.GetShape();
+    const TensorShape& outShape = outputInfo.GetShape();
+
+    // NOTE(review): Map() has no matching Unmap() in this function --
+    // presumably the handle/workload lifetime takes care of it; confirm.
+    std::unique_ptr<Decoder<DataType>> input0 = MakeDecoder<DataType>(inputInfo0, inputs[0]->Map());
+    std::unique_ptr<Decoder<DataType>> input1 = MakeDecoder<DataType>(inputInfo1, inputs[1]->Map());
+    std::unique_ptr<Encoder<DataType>> output = MakeEncoder<DataType>(outputInfo, outputs[0]->Map());
+
+    // Each alias instantiates the shared element-wise loop with the matching
+    // binary functor (std function objects, or armnn's maximum/minimum).
+    using AddFunction = ElementwiseBinaryFunction<std::plus<DataType>>;
+    using DivFunction = ElementwiseBinaryFunction<std::divides<DataType>>;
+    using MaximumFunction = ElementwiseBinaryFunction<armnn::maximum<DataType>>;
+    using MinimumFunction = ElementwiseBinaryFunction<armnn::minimum<DataType>>;
+    using MulFunction = ElementwiseBinaryFunction<std::multiplies<DataType>>;
+    using SubFunction = ElementwiseBinaryFunction<std::minus<DataType>>;
+
+    // Dispatch on the operation enum; each functor runs the whole tensor
+    // computation in its constructor call.
+    switch (operation)
+    {
+        case BinaryOperation::Add:
+        {
+            AddFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+            break;
+        }
+        case BinaryOperation::Div:
+        {
+            DivFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+            break;
+        }
+        case BinaryOperation::Maximum:
+        {
+            MaximumFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+            break;
+        }
+        case BinaryOperation::Minimum:
+        {
+            MinimumFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+            break;
+        }
+        case BinaryOperation::Mul:
+        {
+            MulFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+            break;
+        }
+        case BinaryOperation::Sub:
+        {
+            SubFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+            break;
+        }
+        default:
+        {
+            // Unknown/unhandled enum value: fail loudly rather than silently
+            // leaving the output unwritten.
+            throw InvalidArgumentException(std::string("Unsupported binary operation ") +
+                                           GetBinaryOperationAsCString(operation), CHECK_LOCATION());
+        }
+    }
+}
+
+// Constructor: forwards the queue descriptor and workload info to the
+// RefBaseWorkload base class; this workload holds no additional state.
+RefElementwiseBinaryWorkload::RefElementwiseBinaryWorkload(const ElementwiseBinaryQueueDescriptor& desc,
+                                                           const WorkloadInfo& info)
+    : RefBaseWorkload<ElementwiseBinaryQueueDescriptor>(desc, info)
+{}
+
+// Synchronous execution entry point: runs on the tensor handles bound in the
+// queue descriptor (m_Data).
+void RefElementwiseBinaryWorkload::Execute() const
+{
+    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
+}
+
+// Asynchronous execution entry point: tensor handles come from the
+// caller-supplied working memory rather than from m_Data.
+void RefElementwiseBinaryWorkload::ExecuteAsync(ExecutionData& executionData)
+{
+
+    // executionData.m_Data is assumed to hold a WorkingMemDescriptor*
+    // (unchecked static_cast -- the caller must guarantee this).
+    WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
+    Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
+}
+
+// Common implementation shared by the sync and async entry points: records a
+// profiling event, then picks the element type for the templated executor
+// from the first input's data type. Signed32 runs an int32_t pipeline; every
+// other data type takes the float path (presumably the float decoders/
+// encoders handle dequantisation for quantized types -- see Decoders.hpp /
+// Encoders.hpp to confirm).
+void RefElementwiseBinaryWorkload::Execute(std::vector<ITensorHandle*> inputs,
+                                           std::vector<ITensorHandle*> outputs) const
+{
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefElementwiseBinaryWorkload_Execute");
+
+    if (GetTensorInfo(inputs[0]).GetDataType() == DataType::Signed32)
+    {
+        ExecuteFunction<int32_t>(inputs, outputs, m_Data.m_Parameters.m_Operation);
+    }
+    else
+    {
+        ExecuteFunction<float>(inputs, outputs, m_Data.m_Parameters.m_Operation);
+    }
+}
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefElementwiseBinaryWorkload.hpp b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.hpp
new file mode 100644
index 0000000000..37458a1705
--- /dev/null
+++ b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.hpp
@@ -0,0 +1,29 @@
+//
+// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "BaseIterator.hpp"
+
+#include "RefBaseWorkload.hpp"
+#include <armnn/backends/WorkloadData.hpp>
+
+namespace armnn
+{
+
+// Reference (CpuRef) workload for ElementwiseBinaryLayer: applies the
+// BinaryOperation carried in the queue descriptor's parameters to two input
+// tensors, writing the result to a single output tensor.
+class RefElementwiseBinaryWorkload : public RefBaseWorkload<ElementwiseBinaryQueueDescriptor>
+{
+public:
+    // Re-export the base class's queue-descriptor member for direct access.
+    using RefBaseWorkload<ElementwiseBinaryQueueDescriptor>::m_Data;
+
+    RefElementwiseBinaryWorkload(const ElementwiseBinaryQueueDescriptor& descriptor, const WorkloadInfo& info);
+    // Synchronous execution over the descriptor's bound tensors.
+    void Execute() const override;
+    // Asynchronous execution over tensors supplied via ExecutionData.
+    void ExecuteAsync(ExecutionData& executionData) override;
+
+private:
+    // Shared implementation used by both entry points above.
+    void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const;
+};
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index afed71bfff..dba880bafc 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -26,6 +26,7 @@
#include "RefDetectionPostProcessWorkload.hpp"
#include "RefDequantizeWorkload.hpp"
#include "RefElementwiseWorkload.hpp"
+#include "RefElementwiseBinaryWorkload.hpp"
#include "RefElementwiseUnaryWorkload.hpp"
#include "RefFakeQuantizationFloat32Workload.hpp"
#include "RefFillWorkload.hpp"