aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgiuros01 <giuseppe.rossini@arm.com>2019-05-14 16:12:53 +0100
committerGiuseppe Rossini <giuseppe.rossini@arm.com>2019-06-11 10:38:21 +0000
commitd5134364fc4ca40ea65635192e7959327d690a01 (patch)
treed6781cc0319e54e538ea2b02ea59e842acfd6e49
parente7510622419a63315e5ad5ed7de61a2ce4bd0b49 (diff)
downloadComputeLibrary-d5134364fc4ca40ea65635192e7959327d690a01.tar.gz
COMPMID-2321: PRELU support in NEActivationLayer
Change-Id: Ib320ee7772492cd1b86eba624438da826d47b984 Signed-off-by: giuros01 <giuseppe.rossini@arm.com> Reviewed-on: https://review.mlplatform.org/c/1224 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Manuel Bottini <manuel.bottini@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
-rw-r--r--arm_compute/core/NEON/wrapper/traits.h42
-rw-r--r--arm_compute/runtime/NEON/NEFunctions.h1
-rw-r--r--arm_compute/runtime/NEON/functions/NEPReluLayer.h59
-rw-r--r--src/core/NEON/kernels/NEElementwiseOperationKernel.cpp84
-rw-r--r--src/runtime/NEON/functions/NEPReluLayer.cpp43
-rw-r--r--tests/validation/NEON/PReluLayer.cpp218
6 files changed, 398 insertions, 49 deletions
diff --git a/arm_compute/core/NEON/wrapper/traits.h b/arm_compute/core/NEON/wrapper/traits.h
index 0dbd90ddf8..cc22597c29 100644
--- a/arm_compute/core/NEON/wrapper/traits.h
+++ b/arm_compute/core/NEON/wrapper/traits.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -44,27 +44,27 @@ struct vector_128_tag {};
template <typename T, int S> struct neon_vector;
// Specializations
#ifndef DOXYGEN_SKIP_THIS
-template <> struct neon_vector<uint8_t, 8>{ using type = uint8x8_t; using tag_type = vector_64_tag; };
-template <> struct neon_vector<int8_t, 8>{ using type = int8x8_t; using tag_type = vector_64_tag; };
-template <> struct neon_vector<uint8_t, 16>{ using type = uint8x16_t; using tag_type = vector_128_tag; };
-template <> struct neon_vector<int8_t, 16>{ using type = int8x16_t; using tag_type = vector_128_tag; };
-template <> struct neon_vector<uint16_t, 4>{ using type = uint16x4_t; using tag_type = vector_64_tag; };
-template <> struct neon_vector<int16_t, 4>{ using type = int16x4_t; using tag_type = vector_64_tag; };
-template <> struct neon_vector<uint16_t, 8>{ using type = uint16x8_t; using tag_type = vector_128_tag; };
-template <> struct neon_vector<int16_t, 8>{ using type = int16x8_t; using tag_type = vector_128_tag; };
-template <> struct neon_vector<uint32_t, 2>{ using type = uint32x2_t; using tag_type = vector_64_tag; };
-template <> struct neon_vector<int32_t, 2>{ using type = int32x2_t; using tag_type = vector_64_tag; };
-template <> struct neon_vector<uint32_t, 4>{ using type = uint32x4_t; using tag_type = vector_128_tag; };
-template <> struct neon_vector<int32_t, 4>{ using type = int32x4_t; using tag_type = vector_128_tag; };
-template <> struct neon_vector<uint64_t, 1>{ using type = uint64x1_t; using tag_type = vector_64_tag; };
-template <> struct neon_vector<int64_t, 1>{ using type = int64x1_t; using tag_type = vector_64_tag; };
-template <> struct neon_vector<uint64_t, 2>{ using type = uint64x2_t; using tag_type = vector_128_tag; };
-template <> struct neon_vector<int64_t, 2>{ using type = int64x2_t; using tag_type = vector_128_tag; };
-template <> struct neon_vector<float_t, 2>{ using type = float32x2_t; using tag_type = vector_64_tag; };
-template <> struct neon_vector<float_t, 4>{ using type = float32x4_t; using tag_type = vector_128_tag; };
+template <> struct neon_vector<uint8_t, 8>{ using scalar_type = uint8_t; using type = uint8x8_t; using tag_type = vector_64_tag; };
+template <> struct neon_vector<int8_t, 8>{ using scalar_type = int8_t; using type = int8x8_t; using tag_type = vector_64_tag; };
+template <> struct neon_vector<uint8_t, 16>{ using scalar_type = uint8_t; using type = uint8x16_t; using tag_type = vector_128_tag; };
+template <> struct neon_vector<int8_t, 16>{ using scalar_type = int8_t; using type = int8x16_t; using tag_type = vector_128_tag; };
+template <> struct neon_vector<uint16_t, 4>{ using scalar_type = uint16_t; using type = uint16x4_t; using tag_type = vector_64_tag; };
+template <> struct neon_vector<int16_t, 4>{ using scalar_type = int16_t; using type = int16x4_t; using tag_type = vector_64_tag; };
+template <> struct neon_vector<uint16_t, 8>{ using scalar_type = uint16_t; using type = uint16x8_t; using tag_type = vector_128_tag; };
+template <> struct neon_vector<int16_t, 8>{ using scalar_type = int16_t; using type = int16x8_t; using tag_type = vector_128_tag; };
+template <> struct neon_vector<uint32_t, 2>{ using scalar_type = uint32_t; using type = uint32x2_t; using tag_type = vector_64_tag; };
+template <> struct neon_vector<int32_t, 2>{ using scalar_type = int32_t; using type = int32x2_t; using tag_type = vector_64_tag; };
+template <> struct neon_vector<uint32_t, 4>{ using scalar_type = uint32_t; using type = uint32x4_t; using tag_type = vector_128_tag; };
+template <> struct neon_vector<int32_t, 4>{ using scalar_type = int32_t; using type = int32x4_t; using tag_type = vector_128_tag; };
+template <> struct neon_vector<uint64_t, 1>{ using scalar_type = uint64_t;using type = uint64x1_t; using tag_type = vector_64_tag; };
+template <> struct neon_vector<int64_t, 1>{ using scalar_type = int64_t; using type = int64x1_t; using tag_type = vector_64_tag; };
+template <> struct neon_vector<uint64_t, 2>{ using scalar_type = uint64_t; using type = uint64x2_t; using tag_type = vector_128_tag; };
+template <> struct neon_vector<int64_t, 2>{ using scalar_type = int64_t; using type = int64x2_t; using tag_type = vector_128_tag; };
+template <> struct neon_vector<float_t, 2>{ using scalar_type = float_t; using type = float32x2_t; using tag_type = vector_64_tag; };
+template <> struct neon_vector<float_t, 4>{ using scalar_type = float_t; using type = float32x4_t; using tag_type = vector_128_tag; };
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-template <> struct neon_vector<float16_t, 4>{ using type = float16x4_t; using tag_type = vector_64_tag; };
-template <> struct neon_vector<float16_t, 8>{ using type = float16x8_t; using tag_type = vector_128_tag; };
+template <> struct neon_vector<float16_t, 4>{ using scalar_type = float16_t; using type = float16x4_t; using tag_type = vector_64_tag; };
+template <> struct neon_vector<float16_t, 8>{ using scalar_type = float16_t; using type = float16x8_t; using tag_type = vector_128_tag; };
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#endif /* DOXYGEN_SKIP_THIS */
diff --git a/arm_compute/runtime/NEON/NEFunctions.h b/arm_compute/runtime/NEON/NEFunctions.h
index 4e0cdd7a0a..94607364b3 100644
--- a/arm_compute/runtime/NEON/NEFunctions.h
+++ b/arm_compute/runtime/NEON/NEFunctions.h
@@ -106,6 +106,7 @@
#include "arm_compute/runtime/NEON/functions/NENonMaximaSuppression3x3.h"
#include "arm_compute/runtime/NEON/functions/NENormalizationLayer.h"
#include "arm_compute/runtime/NEON/functions/NEOpticalFlow.h"
+#include "arm_compute/runtime/NEON/functions/NEPReluLayer.h"
#include "arm_compute/runtime/NEON/functions/NEPadLayer.h"
#include "arm_compute/runtime/NEON/functions/NEPermute.h"
#include "arm_compute/runtime/NEON/functions/NEPhase.h"
diff --git a/arm_compute/runtime/NEON/functions/NEPReluLayer.h b/arm_compute/runtime/NEON/functions/NEPReluLayer.h
new file mode 100644
index 0000000000..52db4279cd
--- /dev/null
+++ b/arm_compute/runtime/NEON/functions/NEPReluLayer.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_NEPRELULAYER_H__
+#define __ARM_COMPUTE_NEPRELULAYER_H__
+
+#include "arm_compute/core/Types.h"
+#include "arm_compute/runtime/NEON/INESimpleFunction.h"
+
+namespace arm_compute
+{
+class ITensor;
+
+/** Basic function to run @ref NEArithmeticOperationKernel for PRELU
+ *
+ * @note The function implements an activation layer with the PRELU activation function.
+ */
+class NEPReluLayer : public INESimpleFunction
+{
+public:
+ /** Set the input and output tensor.
+ *
+ * @param[in] input Source tensor. Data types supported: QASYMM8/F16/F32.
+ * @param[in] alpha Source alpha tensor. Data types supported: same of @p input.
+ * @param[out] output Destination tensor. Data type supported: same as @p input
+ */
+ void configure(const ITensor *input, const ITensor *alpha, ITensor *output);
+ /** Static function to check if given info will lead to a valid configuration of @ref NEPReluLayer
+ *
+ * @param[in] input Source tensor info. Data types supported: QASYMM8/F16/F32.
+ * @param[in] alpha Source alpha tensor info. Data types supported: same of @p input.
+ * @param[in] output Destination tensor info. Data type supported: same as @p input
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo *input, const ITensorInfo *alpha, const ITensorInfo *output);
+};
+} // namespace arm_compute
+#endif /* __ARM_COMPUTE_NEPRELULAYER_H__ */
diff --git a/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp b/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp
index 0fe05d2044..8bd37d5913 100644
--- a/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp
+++ b/src/core/NEON/kernels/NEElementwiseOperationKernel.cpp
@@ -125,6 +125,11 @@ inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const Scalar
res = (a - b) * (a - b);
break;
}
+ case ArithmeticOperation::PRELU:
+ {
+ res = (a > 0 ? a : a * b);
+ break;
+ }
case ArithmeticOperation::DIV:
{
res = a / b;
@@ -147,10 +152,14 @@ inline uint8_t elementwise_arithm_op_quantized_scalar(const float &a, const floa
return quantize_qasymm8(elementwise_arithm_op_scalar<op>(a, b), qinfo);
}
-template <ArithmeticOperation op, typename VectorType>
-inline VectorType elementwise_arithm_op(const VectorType &a, const VectorType &b)
+template <ArithmeticOperation op, typename VectorType>
+inline typename VectorType::type elementwise_arithm_op(const typename VectorType::type &a, const typename VectorType::type &b)
{
- VectorType res = { 0, 0, 0, 0 };
+ using vec_type = typename VectorType::type;
+ using scalar_type = typename VectorType::scalar_type;
+ using tag_type = typename VectorType::tag_type;
+
+ vec_type res = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
switch(op)
{
@@ -162,10 +171,20 @@ inline VectorType elementwise_arithm_op(const VectorType &a, const VectorType &b
break;
case ArithmeticOperation::SQUARED_DIFF:
{
- const VectorType tmp = wrapper::vsub(a, b);
- res = wrapper::vmul(tmp, tmp);
+ const vec_type tmp = wrapper::vsub(a, b);
+ res = wrapper::vmul(tmp, tmp);
+ break;
+ }
+ case ArithmeticOperation::PRELU:
+ {
+ const vec_type zero = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{});
+ const vec_type tmp = wrapper::vmul(a, b);
+ const auto gt = wrapper::vcgt(a, zero);
+
+ res = wrapper::vbsl(gt, a, tmp);
break;
}
+
default:
ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
}
@@ -174,26 +193,26 @@ inline VectorType elementwise_arithm_op(const VectorType &a, const VectorType &b
}
template <>
-inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, float32x4_t>(const float32x4_t &a, const float32x4_t &b)
+inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
return wrapper::vdiv(a, b);
}
template <>
-inline float32x4_t elementwise_arithm_op<ArithmeticOperation::POWER, float32x4_t>(const float32x4_t &a, const float32x4_t &b)
+inline float32x4_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b)
{
return wrapper::vpow(a, b);
}
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
-inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, float16x8_t>(const float16x8_t &a, const float16x8_t &b)
+inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector_t<float16_t, 4>>(const float16x8_t &a, const float16x8_t &b)
{
return wrapper::vdiv(a, b);
}
template <>
-inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, float16x8_t>(const float16x8_t &a, const float16x8_t &b)
+inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector_t<float16_t, 4>>(const float16x8_t &a, const float16x8_t &b)
{
return wrapper::vpow(a, b);
}
@@ -202,23 +221,27 @@ inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, float16x8_t
template <ArithmeticOperation op>
inline float32x4x4_t elementwise_arithm_op(const float32x4x4_t &a, const float32x4x4_t &b)
{
+ using neon_vector_float = wrapper::traits::neon_vector<float, 4>;
float32x4x4_t out =
{
{
- elementwise_arithm_op<op>(a.val[0], b.val[0]),
- elementwise_arithm_op<op>(a.val[1], b.val[1]),
- elementwise_arithm_op<op>(a.val[2], b.val[2]),
- elementwise_arithm_op<op>(a.val[3], b.val[3]),
+ elementwise_arithm_op<op, neon_vector_float>(a.val[0], b.val[0]),
+ elementwise_arithm_op<op, neon_vector_float>(a.val[1], b.val[1]),
+ elementwise_arithm_op<op, neon_vector_float>(a.val[2], b.val[2]),
+ elementwise_arithm_op<op, neon_vector_float>(a.val[3], b.val[3]),
}
};
return out;
}
-template <ArithmeticOperation op, typename ScalarType, typename VectorType>
-inline VectorType elementwise_arithm_op_broadcast(const VectorType &a, const ScalarType &broadcast_value, const bool reorder)
+template <ArithmeticOperation op, typename ScalarType, typename VectorType>
+inline typename VectorType::type elementwise_arithm_op_broadcast(const typename VectorType::type &a, const ScalarType &broadcast_value, const bool reorder)
{
- VectorType broadcast_vector = wrapper::vdup_n(broadcast_value, wrapper::traits::vector_128_tag());
- return elementwise_arithm_op<op>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
+ using tag_type = typename VectorType::tag_type;
+ using vec_type = typename VectorType::type;
+
+ vec_type broadcast_vector = wrapper::vdup_n(broadcast_value, tag_type{});
+ return elementwise_arithm_op<op, VectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector);
}
template <ComparisonOperation op, typename InputScalarType>
@@ -322,7 +345,7 @@ inline int elementwise_arithm_op_loop(int window_start_x, int window_end_x, int
{
const auto a = wrapper::vloadq(input1_ptr + x);
const auto b = wrapper::vloadq(input2_ptr + x);
- wrapper::vstore(output_ptr + x, elementwise_arithm_op<op>(a, b));
+ wrapper::vstore(output_ptr + x, elementwise_arithm_op<op, VectorType>(a, b));
}
return x;
}
@@ -353,7 +376,7 @@ inline int elementwise_arithm_op_broadcast_loop(int window_start_x, int window_e
for(; x <= (window_end_x - window_step_x); x += window_step_x)
{
const auto a = wrapper::vloadq((non_broadcast_input_ptr + x));
- wrapper::vstore(output_ptr + x, elementwise_arithm_op_broadcast<op>(a, broadcast_value, reorder));
+ wrapper::vstore(output_ptr + x, elementwise_arithm_op_broadcast<op, ScalarType, VectorType>(a, broadcast_value, reorder));
}
return x;
}
@@ -692,13 +715,15 @@ void elementwise_comp_op_32(const ITensor *in1, const ITensor *in2, ITensor *out
&elementwise_comp_op_32_loop<op, InputScalarType, InputVectorType>);
}
-template <ArithmeticOperation op, typename ScalarType, typename VectorType>
+template <ArithmeticOperation op, typename VectorType>
void elementwise_arithm_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- elementwise_op<ScalarType, ScalarType, VectorType>(in1, in2, out, window,
- &elementwise_arithm_op_scalar<op, ScalarType>,
- &elementwise_arithm_op_broadcast_loop<op, ScalarType, VectorType>,
- &elementwise_arithm_op_loop<op, ScalarType, VectorType>);
+ using scalar_type = typename VectorType::scalar_type;
+
+ elementwise_op<scalar_type, scalar_type, VectorType>(in1, in2, out, window,
+ &elementwise_arithm_op_scalar<op, scalar_type>,
+ &elementwise_arithm_op_broadcast_loop<op, scalar_type, VectorType>,
+ &elementwise_arithm_op_loop<op, scalar_type, VectorType>);
}
template <ArithmeticOperation op>
@@ -745,13 +770,13 @@ configure_arithm_func(const ITensor *input1, const ITensor *input2, ITensor *out
{
static std::map<std::string, NEElementwiseOperationKernel::ElementwiseFunction *> map_function =
{
- { "op_F32_F32_F32", &elementwise_arithm_op<op, float, float32x4_t> },
- { "op_S16_S16_S16", &elementwise_arithm_op<op, int16_t, int16x8_t> },
- { "op_S32_S32_S32", &elementwise_arithm_op<op, int32_t, int32x4_t> },
+ { "op_F32_F32_F32", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<float, 4>> },
+ { "op_S16_S16_S16", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<int16_t, 8>> },
+ { "op_S32_S32_S32", &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<int32_t, 4>> },
{ "op_QASYMM8_QASYMM8_QASYMM8", &elementwise_arithm_op_quantized<op> }
};
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- map_function["op_F16_F16_F16"] = &elementwise_arithm_op<op, float16_t, float16x8_t>;
+ map_function["op_F16_F16_F16"] = &elementwise_arithm_op<op, typename wrapper::traits::neon_vector<float16_t, 8>>;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
return configure_func(input1, input2, output, map_function);
@@ -849,6 +874,9 @@ void NEArithmeticOperationKernel::configure(ArithmeticOperation op, const ITenso
case ArithmeticOperation::SQUARED_DIFF:
_function = configure_arithm_func<ArithmeticOperation::SQUARED_DIFF>(input1, input2, output);
break;
+ case ArithmeticOperation::PRELU:
+ _function = configure_arithm_func<ArithmeticOperation::PRELU>(input1, input2, output);
+ break;
default:
ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
}
diff --git a/src/runtime/NEON/functions/NEPReluLayer.cpp b/src/runtime/NEON/functions/NEPReluLayer.cpp
new file mode 100644
index 0000000000..b386fed575
--- /dev/null
+++ b/src/runtime/NEON/functions/NEPReluLayer.cpp
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/NEON/functions/NEPReluLayer.h"
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/NEON/kernels/NEElementwiseOperationKernel.h"
+#include "support/ToolchainSupport.h"
+
+namespace arm_compute
+{
+void NEPReluLayer::configure(const ITensor *input, const ITensor *alpha, ITensor *output)
+{
+ auto k = arm_compute::support::cpp14::make_unique<NEArithmeticOperationKernel>();
+ k->configure(ArithmeticOperation::PRELU, input, alpha, output);
+ _kernel = std::move(k);
+}
+
+Status NEPReluLayer::validate(const ITensorInfo *input, const ITensorInfo *alpha, const ITensorInfo *output)
+{
+ return NEArithmeticOperationKernel::validate(ArithmeticOperation::PRELU, input, alpha, output);
+}
+} // namespace arm_compute
diff --git a/tests/validation/NEON/PReluLayer.cpp b/tests/validation/NEON/PReluLayer.cpp
new file mode 100644
index 0000000000..95dbf33393
--- /dev/null
+++ b/tests/validation/NEON/PReluLayer.cpp
@@ -0,0 +1,218 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/Types.h"
+#include "arm_compute/runtime/NEON/functions/NEPReluLayer.h"
+#include "arm_compute/runtime/Tensor.h"
+#include "arm_compute/runtime/TensorAllocator.h"
+#include "tests/NEON/Accessor.h"
+#include "tests/PaddingCalculator.h"
+#include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Macros.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/validation/Validation.h"
+#include "tests/validation/fixtures/ElementwiseOperationsFixture.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace
+{
+RelativeTolerance<float> tolerance_fp32(0.000001f);
+
+constexpr unsigned int num_elems_processed_per_iteration = 16;
+/** Input data sets **/
+const auto PReluLayerQASYMM8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8), framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("DataType",
+ DataType::QASYMM8));
+const auto PReluLayerFP32Dataset = combine(combine(framework::dataset::make("DataType", DataType::F32), framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataType", DataType::F32));
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+RelativeTolerance<float> tolerance_fp16(0.001f);
+
+const auto PReluLayerFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("DataType", DataType::F16));
+
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+} // namespace
+
+TEST_SUITE(NEON)
+TEST_SUITE(PReluLayer)
+
+// *INDENT-OFF*
+// clang-format off
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
+ framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
+ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), // Window shrink
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
+ }),
+ framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
+ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
+ TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
+ })),
+ framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
+ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
+ })),
+ framework::dataset::make("Expected", { true, true, false, false, false})),
+ input1_info, input2_info, output_info, expected)
+{
+ ARM_COMPUTE_EXPECT(bool(NEPReluLayer::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false))) == expected, framework::LogLevel::ERRORS);
+}
+// clang-format on
+// *INDENT-ON*
+
+template <typename T>
+using NEPReluLayerFixture = PReluLayerValidationFixture<Tensor, Accessor, NEPReluLayer, T>;
+
+template <typename T>
+using NEPReluLayerQuantizedFixture = PReluLayerValidationQuantizedFixture<Tensor, Accessor, NEPReluLayer, T>;
+
+TEST_SUITE(Quantized)
+TEST_SUITE(QASYMM8)
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, datasets::SmallShapes(),
+ shape)
+{
+ // Create tensors
+ Tensor ref_src1 = create_tensor<Tensor>(shape, DataType::QASYMM8);
+ Tensor ref_src2 = create_tensor<Tensor>(shape, DataType::QASYMM8);
+ Tensor dst = create_tensor<Tensor>(shape, DataType::QASYMM8);
+
+ // Create and Configure function
+ NEPReluLayer prelu;
+ prelu.configure(&ref_src1, &ref_src2, &dst);
+
+ // Validate valid region
+ const ValidRegion valid_region = shape_to_valid_region(shape);
+ validate(dst.info()->valid_region(), valid_region);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPReluLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(),
+ PReluLayerQASYMM8Dataset),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(5.f / 255.f, 20) })),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(1.f / 255.f, 5) }))
+
+ )
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_fp32, 0.01);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPReluLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(),
+ PReluLayerQASYMM8Dataset),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(5.f / 255.f, 20) })),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(1.f / 255.f, 5) }))
+
+ )
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_fp32, 0.01);
+}
+TEST_SUITE_END()
+TEST_SUITE_END()
+
+TEST_SUITE(Float)
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPReluLayerFixture<half>, framework::DatasetMode::ALL, combine(datasets::SmallShapes(), PReluLayerFP16Dataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_fp16, 0.01);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPReluLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeShapes(), PReluLayerFP16Dataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_fp16, 0.01);
+}
+TEST_SUITE_END() // FP16
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+
+TEST_SUITE(FP32)
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, datasets::SmallShapes(),
+ shape)
+{
+ // Create tensors
+ Tensor ref_src1 = create_tensor<Tensor>(shape, DataType::F32);
+ Tensor ref_src2 = create_tensor<Tensor>(shape, DataType::F32);
+ Tensor dst = create_tensor<Tensor>(shape, DataType::F32);
+
+ // Create and Configure function
+ NEPReluLayer prelu;
+ prelu.configure(&ref_src1, &ref_src2, &dst);
+
+ // Validate valid region
+ const ValidRegion valid_region = shape_to_valid_region(shape);
+ validate(dst.info()->valid_region(), valid_region);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPReluLayerFixture<float>, framework::DatasetMode::ALL, combine(datasets::SmallShapes(), PReluLayerFP32Dataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_fp32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPReluLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeShapes(), PReluLayerFP32Dataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_fp32);
+}
+
+template <typename T>
+using NEPReluLayerBroadcastFixture = PReluLayerBroadcastValidationFixture<Tensor, Accessor, NEPReluLayer, T>;
+
+FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEPReluLayerBroadcastFixture<float>, framework::DatasetMode::ALL, combine(datasets::SmallShapesBroadcast(),
+ PReluLayerFP32Dataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_fp32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLargeBroadcast, NEPReluLayerBroadcastFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeShapesBroadcast(),
+ PReluLayerFP32Dataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_fp32);
+}
+TEST_SUITE_END() // FP32
+TEST_SUITE_END() // Float
+
+TEST_SUITE_END() // PReluLayer
+TEST_SUITE_END() // NEON
+} // namespace validation
+} // namespace test
+} // namespace arm_compute