From ab9e52563f624d9782b97400f643d2632cc8d770 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Thu, 13 Jun 2019 17:27:46 +0100 Subject: IVGCVSW-3268 Add Reference workload support for the new Prelu Activation layer * Added reference workload for the PReLU Activation layer * Added factory methods * Added validation support * Added Int16 support * Added unit tests Change-Id: Ic950d908c5e0a335dccd2960a3ffab0f8b599876 Signed-off-by: Matteo Martincigh --- src/backends/reference/workloads/CMakeLists.txt | 4 +++ src/backends/reference/workloads/PreluImpl.cpp | 35 ++++++++++++++++++++++ src/backends/reference/workloads/PreluImpl.hpp | 21 +++++++++++++ .../reference/workloads/RefPreluWorkload.cpp | 35 ++++++++++++++++++++++ .../reference/workloads/RefPreluWorkload.hpp | 22 ++++++++++++++ src/backends/reference/workloads/RefWorkloads.hpp | 1 + 6 files changed, 118 insertions(+) create mode 100644 src/backends/reference/workloads/PreluImpl.cpp create mode 100644 src/backends/reference/workloads/PreluImpl.hpp create mode 100644 src/backends/reference/workloads/RefPreluWorkload.cpp create mode 100644 src/backends/reference/workloads/RefPreluWorkload.hpp (limited to 'src/backends/reference/workloads') diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 1ab38ccbcb..db0daa0310 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -36,6 +36,8 @@ list(APPEND armnnRefBackendWorkloads_sources Pad.hpp Pooling2d.cpp Pooling2d.hpp + PreluImpl.cpp + PreluImpl.hpp RefActivationWorkload.cpp RefActivationWorkload.hpp RefBatchNormalizationWorkload.cpp @@ -84,6 +86,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefPermuteWorkload.hpp RefPooling2dWorkload.cpp RefPooling2dWorkload.hpp + RefPreluWorkload.cpp + RefPreluWorkload.hpp RefQuantizeWorkload.cpp RefQuantizeWorkload.hpp RefReshapeWorkload.cpp diff --git a/src/backends/reference/workloads/PreluImpl.cpp b/src/backends/reference/workloads/PreluImpl.cpp new file mode 100644 index 0000000000..458025bb0a --- /dev/null +++ b/src/backends/reference/workloads/PreluImpl.cpp @@ -0,0 +1,35 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "PreluImpl.hpp" +#include "RefWorkloadUtils.hpp" +#include "Broadcast.hpp" + +namespace armnn +{ + +void PreluImpl(const PreluQueueDescriptor& data, + Decoder& inputData, + Decoder& alphaData, + Encoder& outputData) +{ + const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]); + const TensorInfo& alphaInfo = GetTensorInfo(data.m_Inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[0]); + + const TensorShape& inputShape = inputInfo.GetShape(); + const TensorShape& alphaShape = alphaInfo.GetShape(); + const TensorShape& outputShape = outputInfo.GetShape(); + + // PReLU activation: f(x) = alpha * x for x < 0, f(x) = x for x >= 0 + auto prelu = [](float x, float alpha) + { + return x < 0 ? alpha * x : x; + }; + + BroadcastLoop(inputShape, alphaShape, outputShape).Unroll(prelu, 0, inputData, alphaData, outputData); +} + +} // namespace armnn diff --git a/src/backends/reference/workloads/PreluImpl.hpp b/src/backends/reference/workloads/PreluImpl.hpp new file mode 100644 index 0000000000..9299b1c7f7 --- /dev/null +++ b/src/backends/reference/workloads/PreluImpl.hpp @@ -0,0 +1,21 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "Encoders.hpp" +#include "Decoders.hpp" + +#include + +namespace armnn +{ + +void PreluImpl(const PreluQueueDescriptor& data, + Decoder& inputData, + Decoder& alphaData, + Encoder& outputData); + +} // namespace armnn diff --git a/src/backends/reference/workloads/RefPreluWorkload.cpp b/src/backends/reference/workloads/RefPreluWorkload.cpp new file mode 100644 index 0000000000..cdc0a63711 --- /dev/null +++ b/src/backends/reference/workloads/RefPreluWorkload.cpp @@ -0,0 +1,35 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefPreluWorkload.hpp" + +#include "RefWorkloadUtils.hpp" +#include "PreluImpl.hpp" + +#include + +namespace armnn +{ + +RefPreluWorkload::RefPreluWorkload(const PreluQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload(descriptor, info) +{} + +void RefPreluWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefPreluWorkload_Execute"); + + std::unique_ptr> inputDecoder = MakeDecoder(GetTensorInfo(m_Data.m_Inputs[0]), + m_Data.m_Inputs[0]->Map()); + std::unique_ptr> alphaDecoder = MakeDecoder(GetTensorInfo(m_Data.m_Inputs[1]), + m_Data.m_Inputs[1]->Map()); + std::unique_ptr> outputEncoder = MakeEncoder(GetTensorInfo(m_Data.m_Outputs[0]), + m_Data.m_Outputs[0]->Map()); + + PreluImpl(m_Data, *inputDecoder, *alphaDecoder, *outputEncoder); +} + +} // namespace armnn diff --git a/src/backends/reference/workloads/RefPreluWorkload.hpp b/src/backends/reference/workloads/RefPreluWorkload.hpp new file mode 100644 index 0000000000..72839e67dc --- /dev/null +++ b/src/backends/reference/workloads/RefPreluWorkload.hpp @@ -0,0 +1,22 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include +#include + +namespace armnn +{ + +class RefPreluWorkload : public BaseWorkload +{ +public: + explicit RefPreluWorkload(const PreluQueueDescriptor& descriptor, + const WorkloadInfo& info); + virtual void Execute() const override; +}; + +} // namespace armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index b14129146a..41b16fa56f 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -51,3 +51,4 @@ #include "RefDequantizeWorkload.hpp" #include "RefQuantizeWorkload.hpp" #include "RefReshapeWorkload.hpp" +#include "RefPreluWorkload.hpp" -- cgit v1.2.1