aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEActivationLayerKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEActivationLayerKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEActivationLayerKernel.cpp9
1 files changed, 9 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/NEActivationLayerKernel.cpp b/src/core/NEON/kernels/NEActivationLayerKernel.cpp
index 8de8db9ad9..bc6a281353 100644
--- a/src/core/NEON/kernels/NEActivationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEActivationLayerKernel.cpp
@@ -118,6 +118,7 @@ void NEActivationLayerKernel::configure(ITensor *input, ITensor *output, Activat
{ ActivationFunction::SQRT, &NEActivationLayerKernel::activation<ActivationFunction::SQRT, float> },
{ ActivationFunction::SQUARE, &NEActivationLayerKernel::activation<ActivationFunction::SQUARE, float> },
{ ActivationFunction::TANH, &NEActivationLayerKernel::activation<ActivationFunction::TANH, float> },
+ { ActivationFunction::IDENTITY, &NEActivationLayerKernel::activation<ActivationFunction::IDENTITY, float> },
};
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -135,6 +136,7 @@ void NEActivationLayerKernel::configure(ITensor *input, ITensor *output, Activat
{ ActivationFunction::SQRT, &NEActivationLayerKernel::activation<ActivationFunction::SQRT, float16_t> },
{ ActivationFunction::SQUARE, &NEActivationLayerKernel::activation<ActivationFunction::SQUARE, float16_t> },
{ ActivationFunction::TANH, &NEActivationLayerKernel::activation<ActivationFunction::TANH, float16_t> },
+ { ActivationFunction::IDENTITY, &NEActivationLayerKernel::activation<ActivationFunction::IDENTITY, float16_t> },
};
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC*/
@@ -145,6 +147,7 @@ void NEActivationLayerKernel::configure(ITensor *input, ITensor *output, Activat
{ ActivationFunction::BOUNDED_RELU, &NEActivationLayerKernel::activation<ActivationFunction::BOUNDED_RELU, qasymm8_t> },
{ ActivationFunction::LU_BOUNDED_RELU, &NEActivationLayerKernel::activation<ActivationFunction::LU_BOUNDED_RELU, qasymm8_t> },
{ ActivationFunction::RELU, &NEActivationLayerKernel::activation<ActivationFunction::RELU, qasymm8_t> },
+ { ActivationFunction::IDENTITY, &NEActivationLayerKernel::activation<ActivationFunction::IDENTITY, qasymm8_t> },
};
switch(input->info()->data_type())
@@ -242,6 +245,9 @@ NEActivationLayerKernel::activation(const Window &window)
case ActivationFunction::TANH:
tmp = wrapper::vmul(va, wrapper::vtanh(wrapper::vmul(vb, vin)));
break;
+ case ActivationFunction::IDENTITY:
+ tmp = vin;
+ break;
default:
ARM_COMPUTE_ERROR("Unsupported activation function");
}
@@ -288,6 +294,9 @@ NEActivationLayerKernel::activation(const Window &window)
case ActivationFunction::TANH:
tmp = a * std::tanh(b * in);
break;
+ case ActivationFunction::IDENTITY:
+ tmp = in;
+ break;
default:
ARM_COMPUTE_ERROR("Unsupported activation function");
}