ArmNN
 20.02
Activation.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "Activation.hpp"
7 
8 #include <cmath>
9 
10 namespace armnn
11 {
12 
13 float Activation(float in,
14  ActivationFunction function,
15  float a,
16  float b)
17 {
18  float output;
19 
20  // Compute the result of the activation function.
21  switch (function)
22  {
24  {
25  output = a * in + b;
26  break;
27  }
29  {
30  output = 1.f / (1.f + expf(-in));
31  break;
32  }
34  {
35  output = std::max(0.f, in);
36  break;
37  }
39  {
40  output = std::min(a, std::max(b, in));
41  break;
42  }
44  {
45  output = logf(1.0f + expf(in));
46  break;
47  }
49  {
50  output = in > 0.0f ? in : (in * a);
51  break;
52  }
54  {
55  output = in < 0 ? -in : in;
56  break;
57  }
59  {
60  output = sqrtf(in);
61  break;
62  }
64  {
65  output = in * in;
66  break;
67  }
69  {
70  output = a * tanhf(b * in);
71  break;
72  }
74  {
75  output = (in >= 0) ? in : a * (expf(in) - 1);
76  break;
77  }
79  {
80  // hard_swish(x) = x * relu6(x+3) / 6
81  // relu6(x) = min(max(x,0),6)
82  output = in * (std::min(std::max((in + 3),0.0f),6.0f)) / 6;
83  break;
84  }
85  default:
86  {
87  throw InvalidArgumentException("Unsupported activation function");
88  }
89  }
90 
91  return output;
92 }
93 
94 
96  Encoder<float>& out,
97  const TensorInfo& tensorInfo,
98  ActivationFunction function,
99  float a,
100  float b)
101 {
102  unsigned int numElements = tensorInfo.GetNumElements();
103 
104  for (unsigned int i = 0; i < numElements; i++)
105  {
106  out.Set(Activation(in.Get(), function, a, b));
107  ++in;
108  ++out;
109  }
110  in -= numElements;
111  out -= numElements;
112 }
113 
114 } //namespace armnn
virtual void Set(IType right)=0
Copyright (c) 2020 ARM Limited.
virtual IType Get() const =0
float Activation(float in, ActivationFunction function, float a, float b)
Definition: Activation.cpp:13
min(a, max(b, input)) ReLu1 & ReLu6.
unsigned int GetNumElements() const
Definition: Tensor.hpp:93
ActivationFunction
Definition: Types.hpp:55