aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp
diff options
context:
space:
mode:
authorgiuros01 <giuseppe.rossini@arm.com>2019-05-14 16:12:53 +0100
committerGiuseppe Rossini <giuseppe.rossini@arm.com>2019-06-11 10:38:21 +0000
commitd5134364fc4ca40ea65635192e7959327d690a01 (patch)
treed6781cc0319e54e538ea2b02ea59e842acfd6e49 /src/core/NEON/kernels/NEElementwiseOperationKernel.cpp
parente7510622419a63315e5ad5ed7de61a2ce4bd0b49 (diff)
downloadComputeLibrary-d5134364fc4ca40ea65635192e7959327d690a01.tar.gz
COMPMID-2321: PRELU support in NEActivationLayer
Change-Id: Ib320ee7772492cd1b86eba624438da826d47b984 Signed-off-by: giuros01 <giuseppe.rossini@arm.com> Reviewed-on: https://review.mlplatform.org/c/1224 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Manuel Bottini <manuel.bottini@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEElementwiseOperationKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEElementwiseOperationKernel.cpp84
1 files changed, 56 insertions, 28 deletions
diff --git a/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp b/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp
index 0fe05d2044..8bd37d5913 100644
--- a/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp
+++ b/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp
@@ -125,6 +125,11 @@ inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const Scalar
res = (a - b) * (a - b);
break;
}
+ case ArithmeticOperation::PRELU:
+ {
+ res = (a > 0 ? a : a * b);
+ break;
+ }
case ArithmeticOperation::DIV:
{
res = a / b;
@@ -147,10 +152,14 @@ inline uint8_t elementwise_arithm_op_quantized_scalar(const float &a, const floa
return quantize_qasymm8(elementwise_arithm_op_scalar<op>(a, b), qinfo);
}
-template <ArithmeticOperation op, typename VectorType>
-inline VectorType elementwise_arithm_op(const VectorType &a, const VectorType &b)
+template <ArithmeticOperation op, typename VectorType>
+inline typename VectorType::type elementwise_arithm_op(const typename VectorType::type &a, const typename VectorType::type &b)
{
- VectorType res = { 0, 0, 0, 0 };
+ using vec_type = typename VectorType::type;
+ using scalar_type = typename VectorType::scalar_type;
+ using tag_type = typename VectorType::tag_type;
+
+ vec_type res = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
switch(op)
{
@@ -162,10 +171,20 @@ inline VectorType elementwise_arithm_op(const VectorType &a, const VectorType &b
break;
case ArithmeticOperation::SQUARED_DIFF:
{
- const VectorType tmp = wrapper::vsub(a, b);
- res = wrapper::vmul(tmp, tmp);
+ const vec_type tmp = wrapper::vsub(a, b);
+ res = wrapper::vmul(tmp, tmp);
+ break;
+ }
+ case ArithmeticOperation::PRELU:
+ {
+ const vec_type zero = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
+ const vec_type tmp = wrapper::vmul(a, b);
+ const auto gt = wrapper::vcgt(a, zero);
+
+ res = wrapper::vbsl(gt, a, tmp);
break;
}
+
default:
ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
}
@@ -174,26 +193,26 @@ inline VectorType elementwise_arithm_op(const VectorType &a, const VectorType &b
}
template <>
-inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, float32x4_t>(const float32x4_t &a, const float32x4_t &b)
+inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
return wrapper::vdiv(a, b);
}
template <>
-inline float32x4_t elementwise_arithm_op<ArithmeticOperation::POWER, float32x4_t>(const float32x4_t &a, const float32x4_t &b)
+inline float32x4_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
return wrapper::vpow(a, b);
}
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
-inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, float16x8_t>(const float16x8_t &a, const float16x8_t &b)
+inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector_t<float16_t, 4>>(const float16x8_t &a, const float16x8_t &b)
{
return wrapper::vdiv(a, b);
}
template <>
-inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, float16x8_t>(const float16x8_t &a, const float16x8_t &b)
+inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector_t<float16_t, 4>>(const float16x8_t &a, const float16x8_t &b)
{
return wrapper::vpow(a, b);
}
@@ -202,23 +221,27 @@ inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, float16x8_t
template <ArithmeticOperation op>
inline float32x4x4_t elementwise_arithm_op(const float32x4x4_t &a, const float32x4x4_t &b)
{
+ using neon_vector_float = wrapper::traits::neon_vector<float, 4>;
float32x4x4_t out =
{
{
- elementwise_arithm_op<op>(a.val[0], b.val[0]),
- elementwise_arithm_op<op>(a.val[1], b.val[1]),
- elementwise_arithm_op<op>(a.val[2], b.val[2]),
- elementwise_arithm_op<op>(a.val[3], b.val[3]),
+ elementwise_arithm_op<op, neon_vector_float>(a.val[0], b.val[0]),
+ elementwise_arithm_op<op, neon_vector_float>(a.val[1], b.val[1]),
+ elementwise_arithm_op<op, neon_vector_float>(a.val[2], b.val[2]),
+ elementwise_arithm_op<op, neon_vector_float>(a.val[3], b.val[3]),
}
};
return out;
}
-template <ArithmeticOperation op, typename ScalarType, typename VectorType>
-inline VectorType elementwise_arithm_op_broadcast(const VectorType &a, const ScalarType &broadcast_value, const bool reorder)
+template <ArithmeticOperation op, typename ScalarType, typename VectorType>
+inline typename VectorType::type elementwise_arithm_op_broadcast(const typename VectorType::type &a, const ScalarType &broadcast_value, const bool reorder)
{
- VectorType broadcast_vector = wrapper::vdup_n(broadcast_value, wrapper::traits::vector_128_tag());
- return elementwise_arithm_op<op>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
+ using tag_type = typename VectorType::tag_type;
+ using vec_type = typename VectorType::type;
+
+ vec_type broadcast_vector = wrapper::vdup_n(broadcast_value, tag_type{});
+ return elementwise_arithm_op<op, VectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
}
template <ComparisonOperation op, typename InputScalarType>
@@ -322,7 +345,7 @@ inline int elementwise_arithm_op_loop(int window_start_x, int window_end_x, int
{
const auto a = wrapper::vloadq(input1_ptr + x);
const auto b = wrapper::vloadq(input2_ptr + x);
- wrapper::vstore(output_ptr + x, elementwise_arithm_op<op>(a, b));
+ wrapper::vstore(output_ptr + x, elementwise_arithm_op<op, VectorType>(a, b));
}
return x;
}
@@ -353,7 +376,7 @@ inline int elementwise_arithm_op_broadcast_loop(int window_start_x, int window_e
for(; x <= (window_end_x - window_step_x); x += window_step_x)
{
const auto a = wrapper::vloadq((non_broadcast_input_ptr + x));
- wrapper::vstore(output_ptr + x, elementwise_arithm_op_broadcast<op>(a, broadcast_value, reorder));
+ wrapper::vstore(output_ptr + x, elementwise_arithm_op_broadcast<op, ScalarType, VectorType>(a, broadcast_value, reorder));
}
return x;
}
@@ -692,13 +715,15 @@ void elementwise_comp_op_32(const ITensor *in1, const ITensor *in2, ITensor *out
&elementwise_comp_op_32_loop<op, InputScalarType, InputVectorType>);
}
-template <ArithmeticOperation op, typename ScalarType, typename VectorType>
+template <ArithmeticOperation op, typename VectorType>
void elementwise_arithm_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- elementwise_op<ScalarType, ScalarType, VectorType>(in1, in2, out, window,
- &elementwise_arithm_op_scalar<op, ScalarType>,
- &elementwise_arithm_op_broadcast_loop<op, ScalarType, VectorType>,
- &elementwise_arithm_op_loop<op, ScalarType, VectorType>);
+ using scalar_type = typename VectorType::scalar_type;
+
+ elementwise_op<scalar_type, scalar_type, VectorType>(in1, in2, out, window,
+ &elementwise_arithm_op_scalar<op, scalar_type>,
+ &elementwise_arithm_op_broadcast_loop<op, scalar_type, VectorType>,
+ &elementwise_arithm_op_loop<op, scalar_type, VectorType>);
}
template <ArithmeticOperation op>
@@ -745,13 +770,13 @@ configure_arithm_func(const ITensor *input1, const ITensor *input2, ITensor *out
{
static std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function =
{
- { "op_F32_F32_F32", &elementwise_arithm_op<op, float, float32x4_t> },
- { "op_S16_S16_S16", &elementwise_arithm_op<op, int16_t, int16x8_t> },
- { "op_S32_S32_S32", &elementwise_arithm_op<op, int32_t, int32x4_t> },
+ { "op_F32_F32_F32", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<float, 4>> },
+ { "op_S16_S16_S16", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<int16_t, 8>> },
+ { "op_S32_S32_S32", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<int32_t, 4>> },
{ "op_QASYMM8_QASYMM8_QASYMM8", &elementwise_arithm_op_quantized<op> }
};
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- map_function["op_F16_F16_F16"] = &elementwise_arithm_op<op, float16_t, float16x8_t>;
+ map_function["op_F16_F16_F16"] = &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<float16_t, 8>>;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
return configure_func(input1, input2, output, map_function);
@@ -849,6 +874,9 @@ void NEArithmeticOperationKernel::configure(ArithmeticOperation op, const ITenso
case ArithmeticOperation::SQUARED_DIFF:
_function = configure_arithm_func<ArithmeticOperation::SQUARED_DIFF>(input1, input2, output);
break;
+ case ArithmeticOperation::PRELU:
+ _function = configure_arithm_func<ArithmeticOperation::PRELU>(input1, input2, output);
+ break;
default:
ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
}