diff options
Diffstat (limited to 'src/backends/reference/workloads')
4 files changed, 154 insertions, 2 deletions
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index de6c042959..3592f2293d 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +# Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. # SPDX-License-Identifier: MIT # @@ -108,6 +108,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefDequantizeWorkload.hpp RefDetectionPostProcessWorkload.cpp RefDetectionPostProcessWorkload.hpp + RefElementwiseBinaryWorkload.cpp + RefElementwiseBinaryWorkload.hpp RefElementwiseUnaryWorkload.cpp RefElementwiseUnaryWorkload.hpp RefFakeQuantizationFloat32Workload.cpp diff --git a/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp new file mode 100644 index 0000000000..5dc77f8496 --- /dev/null +++ b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp @@ -0,0 +1,120 @@ +// +// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefElementwiseBinaryWorkload.hpp" + +#include "Decoders.hpp" +#include "ElementwiseFunction.hpp" +#include "Encoders.hpp" +#include "RefWorkloadUtils.hpp" +#include "Maximum.hpp" +#include "Minimum.hpp" + +#include <Profiling.hpp> + +#include <armnn/TypesUtils.hpp> + +#include <functional> + +namespace armnn +{ + +template<typename DataType> +void ExecuteFunction(std::vector<ITensorHandle*> inputs, + std::vector<ITensorHandle*> outputs, + BinaryOperation operation) +{ + const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); + + const TensorShape& inShape0 = inputInfo0.GetShape(); + const TensorShape& inShape1 = inputInfo1.GetShape(); + const TensorShape& outShape = outputInfo.GetShape(); + + std::unique_ptr<Decoder<DataType>> input0 = MakeDecoder<DataType>(inputInfo0, inputs[0]->Map()); + std::unique_ptr<Decoder<DataType>> input1 = MakeDecoder<DataType>(inputInfo1, inputs[1]->Map()); + std::unique_ptr<Encoder<DataType>> output = MakeEncoder<DataType>(outputInfo, outputs[0]->Map()); + + using AddFunction = ElementwiseBinaryFunction<std::plus<DataType>>; + using DivFunction = ElementwiseBinaryFunction<std::divides<DataType>>; + using MaximumFunction = ElementwiseBinaryFunction<armnn::maximum<DataType>>; + using MinimumFunction = ElementwiseBinaryFunction<armnn::minimum<DataType>>; + using MulFunction = ElementwiseBinaryFunction<std::multiplies<DataType>>; + using SubFunction = ElementwiseBinaryFunction<std::minus<DataType>>; + + switch (operation) + { + case BinaryOperation::Add: + { + AddFunction(inShape0, inShape1, outShape, *input0, *input1, *output); + break; + } + case BinaryOperation::Div: + { + DivFunction(inShape0, inShape1, outShape, *input0, *input1, *output); + break; + } + case BinaryOperation::Maximum: + { + MaximumFunction(inShape0, inShape1, outShape, *input0, *input1, *output); + break; + } + case BinaryOperation::Minimum: + { + MinimumFunction(inShape0, inShape1, outShape, *input0, *input1, *output); + break; + } + case BinaryOperation::Mul: + { + MulFunction(inShape0, inShape1, outShape, *input0, *input1, *output); + break; + } + case BinaryOperation::Sub: + { + SubFunction(inShape0, inShape1, outShape, *input0, *input1, *output); + break; + } + default: + { + throw InvalidArgumentException(std::string("Unsupported binary operation ") + + GetBinaryOperationAsCString(operation), CHECK_LOCATION()); + } + } +} + +RefElementwiseBinaryWorkload::RefElementwiseBinaryWorkload(const ElementwiseBinaryQueueDescriptor& desc, + const WorkloadInfo& info) + : RefBaseWorkload<ElementwiseBinaryQueueDescriptor>(desc, info) +{} + +void RefElementwiseBinaryWorkload::Execute() const +{ + Execute(m_Data.m_Inputs, m_Data.m_Outputs); +} + +void RefElementwiseBinaryWorkload::ExecuteAsync(ExecutionData& executionData) +{ + + WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data); + Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs); +} + +void RefElementwiseBinaryWorkload::Execute(std::vector<ITensorHandle*> inputs, + std::vector<ITensorHandle*> outputs) const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefElementwiseBinaryWorkload_Execute"); + + if (GetTensorInfo(inputs[0]).GetDataType() == DataType::Signed32) + { + ExecuteFunction<int32_t>(inputs, outputs, m_Data.m_Parameters.m_Operation); + } + else + { + ExecuteFunction<float>(inputs, outputs, m_Data.m_Parameters.m_Operation); + } +} + +} // namespace armnn diff --git a/src/backends/reference/workloads/RefElementwiseBinaryWorkload.hpp b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.hpp new file mode 100644 index 0000000000..37458a1705 --- /dev/null +++ b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.hpp @@ -0,0 +1,29 @@ +// +// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "BaseIterator.hpp" + +#include "RefBaseWorkload.hpp" +#include <armnn/backends/WorkloadData.hpp> + +namespace armnn +{ + +class RefElementwiseBinaryWorkload : public RefBaseWorkload<ElementwiseBinaryQueueDescriptor> +{ +public: + using RefBaseWorkload<ElementwiseBinaryQueueDescriptor>::m_Data; + + RefElementwiseBinaryWorkload(const ElementwiseBinaryQueueDescriptor& descriptor, const WorkloadInfo& info); + void Execute() const override; + void ExecuteAsync(ExecutionData& executionData) override; + +private: + void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const; +}; + +} // namespace armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index afed71bfff..dba880bafc 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -26,6 +26,7 @@ #include "RefDetectionPostProcessWorkload.hpp" #include "RefDequantizeWorkload.hpp" #include "RefElementwiseWorkload.hpp" +#include "RefElementwiseBinaryWorkload.hpp" #include "RefElementwiseUnaryWorkload.hpp" #include "RefFakeQuantizationFloat32Workload.hpp" #include "RefFillWorkload.hpp" |