aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp')
-rw-r--r--src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp120
1 files changed, 120 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp
new file mode 100644
index 0000000000..5dc77f8496
--- /dev/null
+++ b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp
@@ -0,0 +1,120 @@
+//
+// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefElementwiseBinaryWorkload.hpp"
+
+#include "Decoders.hpp"
+#include "ElementwiseFunction.hpp"
+#include "Encoders.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "Maximum.hpp"
+#include "Minimum.hpp"
+
+#include <Profiling.hpp>
+
+#include <armnn/TypesUtils.hpp>
+
+#include <functional>
+
+namespace armnn
+{
+
+template<typename DataType>
+void ExecuteFunction(std::vector<ITensorHandle*> inputs,
+ std::vector<ITensorHandle*> outputs,
+ BinaryOperation operation)
+{
+ const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]);
+ const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
+
+ const TensorShape& inShape0 = inputInfo0.GetShape();
+ const TensorShape& inShape1 = inputInfo1.GetShape();
+ const TensorShape& outShape = outputInfo.GetShape();
+
+ std::unique_ptr<Decoder<DataType>> input0 = MakeDecoder<DataType>(inputInfo0, inputs[0]->Map());
+ std::unique_ptr<Decoder<DataType>> input1 = MakeDecoder<DataType>(inputInfo1, inputs[1]->Map());
+ std::unique_ptr<Encoder<DataType>> output = MakeEncoder<DataType>(outputInfo, outputs[0]->Map());
+
+ using AddFunction = ElementwiseBinaryFunction<std::plus<DataType>>;
+ using DivFunction = ElementwiseBinaryFunction<std::divides<DataType>>;
+ using MaximumFunction = ElementwiseBinaryFunction<armnn::maximum<DataType>>;
+ using MinimumFunction = ElementwiseBinaryFunction<armnn::minimum<DataType>>;
+ using MulFunction = ElementwiseBinaryFunction<std::multiplies<DataType>>;
+ using SubFunction = ElementwiseBinaryFunction<std::minus<DataType>>;
+
+ switch (operation)
+ {
+ case BinaryOperation::Add:
+ {
+ AddFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+ break;
+ }
+ case BinaryOperation::Div:
+ {
+ DivFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+ break;
+ }
+ case BinaryOperation::Maximum:
+ {
+ MaximumFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+ break;
+ }
+ case BinaryOperation::Minimum:
+ {
+ MinimumFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+ break;
+ }
+ case BinaryOperation::Mul:
+ {
+ MulFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+ break;
+ }
+ case BinaryOperation::Sub:
+ {
+ SubFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+ break;
+ }
+ default:
+ {
+ throw InvalidArgumentException(std::string("Unsupported binary operation ") +
+ GetBinaryOperationAsCString(operation), CHECK_LOCATION());
+ }
+ }
+}
+
+RefElementwiseBinaryWorkload::RefElementwiseBinaryWorkload(const ElementwiseBinaryQueueDescriptor& desc,
+ const WorkloadInfo& info)
+ : RefBaseWorkload<ElementwiseBinaryQueueDescriptor>(desc, info)
+{}
+
+void RefElementwiseBinaryWorkload::Execute() const
+{
+ Execute(m_Data.m_Inputs, m_Data.m_Outputs);
+}
+
+void RefElementwiseBinaryWorkload::ExecuteAsync(ExecutionData& executionData)
+{
+
+ WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
+ Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
+}
+
+void RefElementwiseBinaryWorkload::Execute(std::vector<ITensorHandle*> inputs,
+ std::vector<ITensorHandle*> outputs) const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefElementwiseBinaryWorkload_Execute");
+
+ if (GetTensorInfo(inputs[0]).GetDataType() == DataType::Signed32)
+ {
+ ExecuteFunction<int32_t>(inputs, outputs, m_Data.m_Parameters.m_Operation);
+ }
+ else
+ {
+ ExecuteFunction<float>(inputs, outputs, m_Data.m_Parameters.m_Operation);
+ }
+}
+
+} // namespace armnn