From ae2c5f0350a7033f58578f9c509345445a639865 Mon Sep 17 00:00:00 2001 From: Nattapat Chaimanowong Date: Wed, 24 Apr 2019 16:19:57 +0100 Subject: IVGCVSW-2982 Refactor reference Activation workload Change-Id: Ia3b9a56787cc68822a3c1635de82e03ecc0aae27 Signed-off-by: Nattapat Chaimanowong --- src/backends/reference/workloads/Activation.cpp | 150 ++++++++++++++---------- 1 file changed, 86 insertions(+), 64 deletions(-) (limited to 'src/backends/reference/workloads/Activation.cpp') diff --git a/src/backends/reference/workloads/Activation.cpp b/src/backends/reference/workloads/Activation.cpp index ef4903074b..760c9a0ccd 100644 --- a/src/backends/reference/workloads/Activation.cpp +++ b/src/backends/reference/workloads/Activation.cpp @@ -11,6 +11,91 @@ namespace armnn { +float Activation(float in, + ActivationFunction function, + float a, + float b) +{ + float output; + + // Compute the result of the activation function. + switch (function) + { + case ActivationFunction::Linear: + { + output = a * in + b; + break; + } + case ActivationFunction::Sigmoid: + { + output = 1.f / (1.f + expf(-in)); + break; + } + case ActivationFunction::ReLu: + { + output = std::max(0.f, in); + break; + } + case ActivationFunction::BoundedReLu: + { + output = std::min(a, std::max(b, in)); + break; + } + case ActivationFunction::SoftReLu: + { + output = logf(1.0f + expf(in)); + break; + } + case ActivationFunction::LeakyReLu: + { + output = in > 0.0f ? in : (in * a); + break; + } + case ActivationFunction::Abs: + { + output = in < 0 ? -in : in; + break; + } + case ActivationFunction::Sqrt: + { + output = sqrtf(in); + break; + } + case ActivationFunction::Square: + { + output = in * in; + break; + } + case ActivationFunction::TanH: + { + output = a * tanhf(b * in); + break; + } + default: + { + throw InvalidArgumentException("Unsupported activation function"); + } + } + + return output; +} + + +void Activation(Decoder& in, + Encoder& out, + const TensorInfo& tensorInfo, + ActivationFunction function, + float a, + float b) +{ + for (size_t i = 0; i 0.0f ? input : (input * a); - break; - } - case ActivationFunction::Abs: - { - output = input < 0 ? -input : input; - break; - } - case ActivationFunction::Sqrt: - { - output = sqrtf(input); - break; - } - case ActivationFunction::Square: - { - output = input * input; - break; - } - case ActivationFunction::TanH: - { - output = a * tanhf(b * input); - break; - } - default: - { - BOOST_LOG_TRIVIAL(error) << "Unsupported activation function"; - return; - } - } - - out[i] = output; + out[i] = Activation(in[i], function, a, b); } } -- cgit v1.2.1