diff options
Diffstat (limited to 'src/backends/reference/workloads/Activation.cpp')
-rw-r--r-- | src/backends/reference/workloads/Activation.cpp | 150 |
1 files changed, 86 insertions, 64 deletions
diff --git a/src/backends/reference/workloads/Activation.cpp b/src/backends/reference/workloads/Activation.cpp index ef4903074b..760c9a0ccd 100644 --- a/src/backends/reference/workloads/Activation.cpp +++ b/src/backends/reference/workloads/Activation.cpp @@ -11,6 +11,91 @@ namespace armnn { +float Activation(float in, + ActivationFunction function, + float a, + float b) +{ + float output; + + // Compute the result of the activation function. + switch (function) + { + case ActivationFunction::Linear: + { + output = a * in + b; + break; + } + case ActivationFunction::Sigmoid: + { + output = 1.f / (1.f + expf(-in)); + break; + } + case ActivationFunction::ReLu: + { + output = std::max(0.f, in); + break; + } + case ActivationFunction::BoundedReLu: + { + output = std::min(a, std::max(b, in)); + break; + } + case ActivationFunction::SoftReLu: + { + output = logf(1.0f + expf(in)); + break; + } + case ActivationFunction::LeakyReLu: + { + output = in > 0.0f ? in : (in * a); + break; + } + case ActivationFunction::Abs: + { + output = in < 0 ? -in : in; + break; + } + case ActivationFunction::Sqrt: + { + output = sqrtf(in); + break; + } + case ActivationFunction::Square: + { + output = in * in; + break; + } + case ActivationFunction::TanH: + { + output = a * tanhf(b * in); + break; + } + default: + { + throw InvalidArgumentException("Unsupported activation function"); + } + } + + return output; +} + + +void Activation(Decoder<float>& in, + Encoder<float>& out, + const TensorInfo& tensorInfo, + ActivationFunction function, + float a, + float b) +{ + for (size_t i = 0; i<tensorInfo.GetNumElements(); i++) + { + out.Set(Activation(in.Get(), function, a, b)); + + ++in; + ++out; + } +} void Activation(const float* in, float* out, @@ -21,70 +106,7 @@ void Activation(const float* in, { for (size_t i = 0; i<tensorInfo.GetNumElements(); i++) { - float input = in[i]; - float output; - - // Compute the result of the activation function. - switch (function) - { - case ActivationFunction::Linear: - { - output = a * input + b; - break; - } - case ActivationFunction::Sigmoid: - { - output = 1.f / (1.f + expf(-input)); - break; - } - case ActivationFunction::ReLu: - { - output = std::max(0.f, input); - break; - } - case ActivationFunction::BoundedReLu: - { - output = std::min(a, std::max(b, input)); - break; - } - case ActivationFunction::SoftReLu: - { - output = logf(1.0f + expf(input)); - break; - } - case ActivationFunction::LeakyReLu: - { - output = input > 0.0f ? input : (input * a); - break; - } - case ActivationFunction::Abs: - { - output = input < 0 ? -input : input; - break; - } - case ActivationFunction::Sqrt: - { - output = sqrtf(input); - break; - } - case ActivationFunction::Square: - { - output = input * input; - break; - } - case ActivationFunction::TanH: - { - output = a * tanhf(b * input); - break; - } - default: - { - BOOST_LOG_TRIVIAL(error) << "Unsupported activation function"; - return; - } - } - - out[i] = output; + out[i] = Activation(in[i], function, a, b); } } |