-rw-r--r--  Android.bp                                                     2
-rw-r--r--  filelist.json                                                  7
-rw-r--r--  src/BUILD.bazel                                                2
-rw-r--r--  src/CMakeLists.txt                                             2
-rw-r--r--  src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp   138
-rw-r--r--  src/core/NEON/kernels/NEBatchNormalizationLayerKernel.h      36
-rw-r--r--  src/core/NEON/kernels/batchnormalization/impl/list.h         25
-rw-r--r--  src/cpu/kernels/CpuPool2dKernel.cpp                            2
-rw-r--r--  src/cpu/kernels/fuse_batch_normalization/generic/impl.h     118
-rw-r--r--  src/cpu/kernels/fuse_batch_normalization/nchw/neon/fp16.cpp   96
-rw-r--r--  src/cpu/kernels/fuse_batch_normalization/nchw/neon/fp32.cpp   93
11 files changed, 361 insertions(+), 160 deletions(-)
diff --git a/Android.bp b/Android.bp
index 31ec9b2716..23b264a3a9 100644
--- a/Android.bp
+++ b/Android.bp
@@ -515,6 +515,8 @@ cc_library_static {
"src/cpu/kernels/fuse_batch_normalization/generic/fp16.cpp",
"src/cpu/kernels/fuse_batch_normalization/generic/fp32.cpp",
"src/cpu/kernels/fuse_batch_normalization/nchw/all.cpp",
+ "src/cpu/kernels/fuse_batch_normalization/nchw/neon/fp16.cpp",
+ "src/cpu/kernels/fuse_batch_normalization/nchw/neon/fp32.cpp",
"src/cpu/kernels/fuse_batch_normalization/nhwc/neon/fp16.cpp",
"src/cpu/kernels/fuse_batch_normalization/nhwc/neon/fp32.cpp",
"src/cpu/kernels/gemm_matrix_add/generic/neon/fp16.cpp",
diff --git a/filelist.json b/filelist.json
index a84db7188d..c34eff2ff9 100644
--- a/filelist.json
+++ b/filelist.json
@@ -982,12 +982,15 @@
"fp16": [
"src/cpu/kernels/fuse_batch_normalization/generic/fp16.cpp",
"src/core/NEON/kernels/batchnormalization/impl/NEON/fp16.cpp",
- "src/cpu/kernels/fuse_batch_normalization/nhwc/neon/fp16.cpp"
+ "src/cpu/kernels/fuse_batch_normalization/nhwc/neon/fp16.cpp",
+ "src/cpu/kernels/fuse_batch_normalization/nchw/neon/fp16.cpp"
+
],
"fp32": [
"src/cpu/kernels/fuse_batch_normalization/generic/fp32.cpp",
"src/core/NEON/kernels/batchnormalization/impl/NEON/fp32.cpp",
- "src/cpu/kernels/fuse_batch_normalization/nhwc/neon/fp32.cpp"
+ "src/cpu/kernels/fuse_batch_normalization/nhwc/neon/fp32.cpp",
+ "src/cpu/kernels/fuse_batch_normalization/nchw/neon/fp32.cpp"
]
},
"sve": {
diff --git a/src/BUILD.bazel b/src/BUILD.bazel
index 42841fe28d..f281b6a4d5 100644
--- a/src/BUILD.bazel
+++ b/src/BUILD.bazel
@@ -766,6 +766,8 @@ filegroup(
"cpu/kernels/fuse_batch_normalization/generic/fp16.cpp",
"cpu/kernels/fuse_batch_normalization/generic/fp32.cpp",
"cpu/kernels/fuse_batch_normalization/nchw/all.cpp",
+ "cpu/kernels/fuse_batch_normalization/nchw/neon/fp16.cpp",
+ "cpu/kernels/fuse_batch_normalization/nchw/neon/fp32.cpp",
"cpu/kernels/fuse_batch_normalization/nhwc/neon/fp16.cpp",
"cpu/kernels/fuse_batch_normalization/nhwc/neon/fp32.cpp",
"cpu/kernels/gemm_matrix_add/generic/neon/fp16.cpp",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 1de9e63737..3229ffa8c2 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -757,6 +757,8 @@ target_sources(
cpu/kernels/fuse_batch_normalization/generic/fp16.cpp
cpu/kernels/fuse_batch_normalization/generic/fp32.cpp
cpu/kernels/fuse_batch_normalization/nchw/all.cpp
+ cpu/kernels/fuse_batch_normalization/nchw/neon/fp16.cpp
+ cpu/kernels/fuse_batch_normalization/nchw/neon/fp32.cpp
cpu/kernels/fuse_batch_normalization/nhwc/neon/fp16.cpp
cpu/kernels/fuse_batch_normalization/nhwc/neon/fp32.cpp
cpu/kernels/gemm_matrix_add/generic/neon/fp16.cpp
diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
index deb89996a9..717fd11485 100644
--- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -151,128 +151,15 @@ Status validate_arguments(const ITensorInfo *input,
}
} //namespace
-template <typename T, bool fused_activation, typename F>
-void NEBatchNormalizationLayerKernel::batch_normalization_nchw(const Window &window)
-{
- /** SIMD vector tag type. */
- using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
-
- const int window_step_x = 16 / sizeof(T);
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- Window win_to_use = window;
- win_to_use.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator input(_input, win_to_use);
- Iterator output(_output, win_to_use);
-
- F activation_functor(_act_info);
-
- // Hold information about the current feature map we are iterating.
- // Only compute denominator and constants once per feature map.
- int slice = -1;
-
- const auto input_mean = reinterpret_cast<const T *>(_mean->ptr_to_element(Coordinates(0, 0)));
- const auto input_var = reinterpret_cast<const T *>(_var->ptr_to_element(Coordinates(0, 0)));
- const auto input_gamma =
- (_gamma != nullptr) ? reinterpret_cast<const T *>(_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
- const auto input_beta =
- (_beta != nullptr) ? reinterpret_cast<const T *>(_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;
-
- T mean = static_cast<T>(0);
- T var = static_cast<T>(0);
- T gamma = static_cast<T>(1);
- T beta = static_cast<T>(0);
- T denominator = static_cast<T>(0);
-
- auto mean_vec = wrapper::vdup_n(mean, ExactTagType{});
- auto var_vec = wrapper::vdup_n(var, ExactTagType{});
- auto gamma_vec = wrapper::vdup_n(gamma, ExactTagType{});
- auto beta_vec = wrapper::vdup_n(beta, ExactTagType{});
- auto denominator_vec = wrapper::vdup_n(denominator, ExactTagType{});
- const auto epsilon_vec = wrapper::vdup_n(static_cast<T>(_epsilon), ExactTagType{});
- execute_window_loop(
- win_to_use,
- [&](const Coordinates &id)
- {
- const auto input_ptr = reinterpret_cast<const T *>(input.ptr());
- const auto output_ptr = reinterpret_cast<T *>(output.ptr());
-
- if (slice != id.z())
- {
- mean = input_mean[id.z()];
- var = input_var[id.z()];
- mean_vec = wrapper::vdup_n(mean, ExactTagType{});
- var_vec = wrapper::vdup_n(var, ExactTagType{});
- if (input_gamma != nullptr)
- {
- gamma = input_gamma[id.z()];
- gamma_vec = wrapper::vdup_n(gamma, ExactTagType{});
- }
- if (input_beta != nullptr)
- {
- beta = input_beta[id.z()];
- beta_vec = wrapper::vdup_n(beta, ExactTagType{});
- }
-
- // Calculate denominator
- denominator_vec = wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec));
- denominator = wrapper::vgetlane(denominator_vec, 0);
- slice = id.z();
- }
-
- // Perform core calculations using vector operations
- int x = window_start_x;
- for (; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- // Calculate x bar
- const auto numerator = wrapper::vsub(wrapper::vloadq(input_ptr + x), mean_vec);
- const auto x_bar = wrapper::vmul(numerator, denominator_vec);
- auto res = wrapper::vmla(beta_vec, x_bar, gamma_vec);
-
- // Perform fused activation
- if (fused_activation)
- {
- activation_functor(res);
- }
-
- // Store results
- wrapper::vstore(output_ptr + x, res);
- }
-
- // Compute left-over elements
- for (; x < window_end_x; ++x)
- {
- const T numerator = input_ptr[x] - mean;
- const T x_bar = numerator * denominator;
- T res = beta + x_bar * gamma;
-
- // Perform fused activation
- if (fused_activation)
- {
- activation_functor(res);
- }
-
- // Store results
- *(output_ptr + x) = res;
- }
- },
- input, output);
-}
-
void NEBatchNormalizationLayerKernel::configure_non_fused()
{
switch (_input->info()->data_type())
{
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- _func = &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float16_t, false,
- detail::dummy<float16_t, 8>>;
+ _func = REGISTER_FP16_NEON(cpu::fp16_batch_normalization_nchw_non_fused);
break;
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F32:
- _func = &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float, false, detail::dummy<float, 4>>;
+ _func = REGISTER_FP32_NEON(cpu::fp32_batch_normalization_nchw_non_fused);
break;
default:
ARM_COMPUTE_ERROR("Element size not supported");
@@ -285,29 +172,26 @@ void NEBatchNormalizationLayerKernel::configure_fused()
// NCHW Fused Batched Normalization with activation functions : FP32
static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f32_nchw = {
{ActivationLayerInfo::ActivationFunction::RELU,
- &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float, true, detail::relu<float, 4>>},
+ REGISTER_FP32_NEON(cpu::fp32_batch_normalization_nchw_non_fused_relu)},
{ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
- &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float, true, detail::brelu<float, 4>>},
+ REGISTER_FP32_NEON(cpu::fp32_batch_normalization_nchw_non_fused_brelu)},
{ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
- &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float, true, detail::lubrelu<float, 4>>}};
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ REGISTER_FP32_NEON(cpu::fp32_batch_normalization_nchw_non_fused_lubrelu)}};
+
// NCHW Fused Batched Normalization with activation functions : FP16
static std::map<ActivationLayerInfo::ActivationFunction, BatchNormFunctionPtr> bn_fused_map_f16_nchw = {
{ActivationLayerInfo::ActivationFunction::RELU,
- &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float16_t, true, detail::relu<float16_t, 8>>},
+ REGISTER_FP16_NEON(cpu::fp16_batch_normalization_nchw_non_fused_relu)},
{ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
- &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float16_t, true, detail::brelu<float16_t, 8>>},
+ REGISTER_FP16_NEON(cpu::fp16_batch_normalization_nchw_non_fused_brelu)},
{ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
- &NEBatchNormalizationLayerKernel::batch_normalization_nchw<float16_t, true, detail::lubrelu<float16_t, 8>>}};
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ REGISTER_FP16_NEON(cpu::fp16_batch_normalization_nchw_non_fused_lubrelu)}};
switch (_input->info()->data_type())
{
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
_func = bn_fused_map_f16_nchw[_act_info.activation()];
break;
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F32:
_func = bn_fused_map_f32_nchw[_act_info.activation()];
break;
@@ -409,7 +293,7 @@ void NEBatchNormalizationLayerKernel::run(const Window &window, const ThreadInfo
const bool is_nchw = _input->info()->data_layout() == DataLayout::NCHW;
if (is_nchw)
{
- (this->*_func)(window);
+ (*_func)(window, _input, _output, _mean, _var, _beta, _gamma, _epsilon, _act_info);
}
else
{
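The run() hunk above is the visible end of the refactor: the kernel body no longer lives on the class, so the stored pointer changes from a member-function pointer invoked through this to a free-function pointer that receives all tensor state explicitly. A minimal standalone sketch of that change (hypothetical Window/ITensor stand-ins, not ACL code):

#include <iostream>

// Hypothetical stand-ins for the ACL types, just for the sketch.
struct Window  { int start = 0, end = 16; };
struct ITensor { };

// Before: the kernel is a member function; the stored pointer is a
// pointer-to-member and must be invoked through `this`.
struct KernelOld
{
    using BatchNormFunctionPtr = void (KernelOld::*)(const Window &);
    BatchNormFunctionPtr _func{&KernelOld::body};
    void body(const Window &w) { std::cout << "member kernel, " << (w.end - w.start) << " elements\n"; }
    void run(const Window &w) { (this->*_func)(w); }
};

// After: the kernel is a free function; all state travels through the
// parameter list, so the body can live in a per-ISA translation unit.
void free_kernel(const Window &w, ITensor *, ITensor *, float epsilon)
{
    std::cout << "free kernel, " << (w.end - w.start) << " elements, eps=" << epsilon << "\n";
}

struct KernelNew
{
    using BatchNormFunctionPtr = void (*)(const Window &, ITensor *, ITensor *, float);
    BatchNormFunctionPtr _func{&free_kernel};
    ITensor *_input{};
    ITensor *_output{};
    float    _epsilon{1e-5f};
    void run(const Window &w) { (*_func)(w, _input, _output, _epsilon); }
};

int main()
{
    Window w;
    KernelOld{}.run(w);
    KernelNew{}.run(w);
}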
diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.h b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.h
index 2e8ff0dc9a..679ade0fae 100644
--- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.h
+++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_NEBATCHNORMALIZATIONLAYERKERNEL_H
-#define ARM_COMPUTE_NEBATCHNORMALIZATIONLAYERKERNEL_H
+#ifndef ACL_SRC_CORE_NEON_KERNELS_NEBATCHNORMALIZATIONLAYERKERNEL_H
+#define ACL_SRC_CORE_NEON_KERNELS_NEBATCHNORMALIZATIONLAYERKERNEL_H
#include "arm_compute/function_info/ActivationLayerInfo.h"
@@ -110,31 +110,19 @@ private:
/** Configure execution function in case of fused activation **/
void configure_fused();
- /** Template function to run batch normalization on fp32
- *
- * @tparam T Specialization data type
- * @tparam fused_activation Boolean that flags if its a fused activation or not
- * @tparam F Activation function functor to run
- *
- * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()).
- */
- template <typename T, bool fused_activation, typename F>
- void batch_normalization_nchw(const Window &window);
- /** Template function to run batch normalization on fp32 on tensors with NHWC format
- *
- * @tparam T Specialization data type
- * @tparam fused_activation Boolean that flags if its a fused activation or not
- * @tparam F Activation function functor to run
- *
- * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()).
- */
- template <typename T, bool fused_activation, typename F>
- void batch_normalization_nhwc(const Window &window);
/** Common signature for all the batch normalization functions
*
* @param[in] window Region on which to execute the kernel.
*/
- using BatchNormFunctionPtr = void (NEBatchNormalizationLayerKernel::*)(const Window &window);
+ using BatchNormFunctionPtr = void (*)(const Window &window,
+ ITensor *input,
+ ITensor *output,
+ const ITensor *mean,
+ const ITensor *var,
+ const ITensor *beta,
+ const ITensor *gamma,
+ float epsilon,
+ ActivationLayerInfo act_info);
private:
BatchNormFunctionPtr _func;
@@ -148,4 +136,4 @@ private:
ActivationLayerInfo _act_info;
};
} // namespace arm_compute
-#endif /*ARM_COMPUTE_NEBATCHNORMALIZATIONLAYERKERNEL_H */
+#endif // ACL_SRC_CORE_NEON_KERNELS_NEBATCHNORMALIZATIONLAYERKERNEL_H
diff --git a/src/core/NEON/kernels/batchnormalization/impl/list.h b/src/core/NEON/kernels/batchnormalization/impl/list.h
index cbf540bd71..c619788125 100644
--- a/src/core/NEON/kernels/batchnormalization/impl/list.h
+++ b/src/core/NEON/kernels/batchnormalization/impl/list.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020 Arm Limited.
+ * Copyright (c) 2020, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_CORE_NEON_KERNELS_BATCH_NORMALIZATION_LIST_H
-#define SRC_CORE_NEON_KERNELS_BATCH_NORMALIZATION_LIST_H
+#ifndef ACL_SRC_CORE_NEON_KERNELS_BATCHNORMALIZATION_IMPL_LIST_H
+#define ACL_SRC_CORE_NEON_KERNELS_BATCHNORMALIZATION_IMPL_LIST_H
namespace arm_compute
{
@@ -37,8 +37,23 @@ DECLARE_BATCH_NORMALIZATION_KERNEL(fp16_sve_batch_normalization);
DECLARE_BATCH_NORMALIZATION_KERNEL(fp32_neon_batch_normalization);
DECLARE_BATCH_NORMALIZATION_KERNEL(fp32_sve_batch_normalization);
-#undef DECLARE_ACTIVATION_KERNEL
+#define DECLARE_BATCH_NORMALIZATION_NCHW_KERNEL(func_name) \
+ void func_name(const Window &window, ITensor *input, ITensor *output, const ITensor *mean, const ITensor *var, \
+ const ITensor *beta, const ITensor *gamma, float epsilon, ActivationLayerInfo act_info)
+
+DECLARE_BATCH_NORMALIZATION_NCHW_KERNEL(fp16_batch_normalization_nchw_non_fused);
+DECLARE_BATCH_NORMALIZATION_NCHW_KERNEL(fp32_batch_normalization_nchw_non_fused);
+DECLARE_BATCH_NORMALIZATION_NCHW_KERNEL(fp16_batch_normalization_nchw_non_fused_relu);
+DECLARE_BATCH_NORMALIZATION_NCHW_KERNEL(fp16_batch_normalization_nchw_non_fused_brelu);
+DECLARE_BATCH_NORMALIZATION_NCHW_KERNEL(fp16_batch_normalization_nchw_non_fused_lubrelu);
+DECLARE_BATCH_NORMALIZATION_NCHW_KERNEL(fp32_batch_normalization_nchw_non_fused_relu);
+DECLARE_BATCH_NORMALIZATION_NCHW_KERNEL(fp32_batch_normalization_nchw_non_fused_brelu);
+DECLARE_BATCH_NORMALIZATION_NCHW_KERNEL(fp32_batch_normalization_nchw_non_fused_lubrelu);
+
+#undef DECLARE_BATCH_NORMALIZATION_KERNEL
+#undef DECLARE_BATCH_NORMALIZATION_NCHW_KERNEL
+
} // namespace cpu
} // namespace arm_compute
-#endif /* SRC_CORE_NEON_KERNELS_BATCH_NORMALIZATION_LIST_H */
+#endif // ACL_SRC_CORE_NEON_KERNELS_BATCHNORMALIZATION_IMPL_LIST_H
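Since the declaration macro is #undef'd at the end of the header, it exists only to stamp out prototypes with the shared NCHW signature. For example, DECLARE_BATCH_NORMALIZATION_NCHW_KERNEL(fp32_batch_normalization_nchw_non_fused) expands (modulo whitespace) to:

void fp32_batch_normalization_nchw_non_fused(const Window &window, ITensor *input, ITensor *output,
                                             const ITensor *mean, const ITensor *var, const ITensor *beta,
                                             const ITensor *gamma, float epsilon, ActivationLayerInfo act_info);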
diff --git a/src/cpu/kernels/CpuPool2dKernel.cpp b/src/cpu/kernels/CpuPool2dKernel.cpp
index 9308d860d1..2c9627bdee 100644
--- a/src/cpu/kernels/CpuPool2dKernel.cpp
+++ b/src/cpu/kernels/CpuPool2dKernel.cpp
@@ -271,11 +271,9 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *
break;
}
break;
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
num_elems_processed_per_iteration = 1;
break;
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
case DataType::F32:
num_elems_processed_per_iteration = 1;
break;
diff --git a/src/cpu/kernels/fuse_batch_normalization/generic/impl.h b/src/cpu/kernels/fuse_batch_normalization/generic/impl.h
index d807148e37..0c90abccb1 100644
--- a/src/cpu/kernels/fuse_batch_normalization/generic/impl.h
+++ b/src/cpu/kernels/fuse_batch_normalization/generic/impl.h
@@ -32,6 +32,124 @@ namespace arm_compute
{
namespace cpu
{
+template <typename T, bool fused_activation, typename F>
+void batch_normalization_nchw(const Window &window,
+ ITensor *in,
+ ITensor *out,
+ const ITensor *in_mean,
+ const ITensor *in_var,
+ const ITensor *in_beta,
+ const ITensor *in_gamma,
+ float epsilon,
+ ActivationLayerInfo act_info)
+{
+ /** SIMD vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
+ const int window_step_x = 16 / sizeof(T);
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+ Window win_to_use = window;
+ win_to_use.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator input(in, win_to_use);
+ Iterator output(out, win_to_use);
+
+ F activation_functor(act_info);
+
+ // Hold information about the current feature map we are iterating.
+ // Only compute denominator and constants once per feature map.
+ int slice = -1;
+
+ const auto input_mean = reinterpret_cast<const T *>(in_mean->ptr_to_element(Coordinates(0, 0)));
+ const auto input_var = reinterpret_cast<const T *>(in_var->ptr_to_element(Coordinates(0, 0)));
+ const auto input_gamma =
+ (in_gamma != nullptr) ? reinterpret_cast<const T *>(in_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
+ const auto input_beta =
+ (in_beta != nullptr) ? reinterpret_cast<const T *>(in_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;
+
+ T mean = static_cast<T>(0);
+ T var = static_cast<T>(0);
+ T gamma = static_cast<T>(1);
+ T beta = static_cast<T>(0);
+ T denominator = static_cast<T>(0);
+
+ auto mean_vec = wrapper::vdup_n(mean, ExactTagType{});
+ auto var_vec = wrapper::vdup_n(var, ExactTagType{});
+ auto gamma_vec = wrapper::vdup_n(gamma, ExactTagType{});
+ auto beta_vec = wrapper::vdup_n(beta, ExactTagType{});
+ auto denominator_vec = wrapper::vdup_n(denominator, ExactTagType{});
+ const auto epsilon_vec = wrapper::vdup_n(static_cast<T>(epsilon), ExactTagType{});
+ execute_window_loop(
+ win_to_use,
+ [&](const Coordinates &id)
+ {
+ const auto input_ptr = reinterpret_cast<const T *>(input.ptr());
+ const auto output_ptr = reinterpret_cast<T *>(output.ptr());
+
+ if (slice != id.z())
+ {
+ mean = input_mean[id.z()];
+ var = input_var[id.z()];
+ mean_vec = wrapper::vdup_n(mean, ExactTagType{});
+ var_vec = wrapper::vdup_n(var, ExactTagType{});
+ if (input_gamma != nullptr)
+ {
+ gamma = input_gamma[id.z()];
+ gamma_vec = wrapper::vdup_n(gamma, ExactTagType{});
+ }
+ if (input_beta != nullptr)
+ {
+ beta = input_beta[id.z()];
+ beta_vec = wrapper::vdup_n(beta, ExactTagType{});
+ }
+
+ // Calculate denominator
+ denominator_vec = wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec));
+ denominator = wrapper::vgetlane(denominator_vec, 0);
+ slice = id.z();
+ }
+
+ // Perform core calculations using vector operations
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Calculate x bar
+ const auto numerator = wrapper::vsub(wrapper::vloadq(input_ptr + x), mean_vec);
+ const auto x_bar = wrapper::vmul(numerator, denominator_vec);
+ auto res = wrapper::vmla(beta_vec, x_bar, gamma_vec);
+
+ // Perform fused activation
+ if (fused_activation)
+ {
+ activation_functor(res);
+ }
+
+ // Store results
+ wrapper::vstore(output_ptr + x, res);
+ }
+
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
+ {
+ const T numerator = input_ptr[x] - mean;
+ const T x_bar = numerator * denominator;
+ T res = beta + x_bar * gamma;
+
+ // Perform fused activation
+ if (fused_activation)
+ {
+ activation_functor(res);
+ }
+
+ // Store results
+ *(output_ptr + x) = res;
+ }
+ },
+ input, output);
+}
+
template <typename T>
void fused_batch_normalization_conv(const ITensor *conv_weights,
const ITensor *conv_bias,
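The template above is the textbook inference-time batch normalization, out = gamma * (x - mean) / sqrt(var + epsilon) + beta, applied per NCHW feature map, with the reciprocal square root recomputed only when the loop crosses into a new slice. A minimal scalar reference of the per-slice math (a verification sketch with no windows or NEON, matching the left-over-elements path):

#include <cmath>
#include <cstddef>
#include <cstdio>

// Scalar reference for one NCHW feature map: same per-element math as
// batch_normalization_nchw, minus vectorization and window handling.
void batch_norm_slice_ref(const float *in, float *out, std::size_t n,
                          float mean, float var, float gamma, float beta, float epsilon)
{
    const float denominator = 1.0f / std::sqrt(var + epsilon); // vinvsqrt(var + eps)
    for (std::size_t i = 0; i < n; ++i)
    {
        const float x_bar = (in[i] - mean) * denominator; // normalize
        out[i]            = beta + x_bar * gamma;         // scale and shift
    }
}

int main()
{
    const float in[4] = {1.f, 2.f, 3.f, 4.f};
    float       out[4];
    batch_norm_slice_ref(in, out, 4, 2.5f, 1.25f, 1.f, 0.f, 1e-5f);
    for (float v : out)
        std::printf("%f\n", v);
}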
diff --git a/src/cpu/kernels/fuse_batch_normalization/nchw/neon/fp16.cpp b/src/cpu/kernels/fuse_batch_normalization/nchw/neon/fp16.cpp
new file mode 100644
index 0000000000..ae4c7e5736
--- /dev/null
+++ b/src/cpu/kernels/fuse_batch_normalization/nchw/neon/fp16.cpp
@@ -0,0 +1,96 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/TensorInfo.h"
+
+#include "src/core/CPP/Validate.h"
+#include "src/core/NEON/kernels/detail/NEActivationFunctionDetail.h"
+#include "src/core/NEON/wrapper/wrapper.h"
+#include "src/cpu/CpuTypes.h"
+#include "src/cpu/kernels/fuse_batch_normalization/generic/impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+void fp16_batch_normalization_nchw_non_fused(const Window &window,
+ ITensor *input,
+ ITensor *output,
+ const ITensor *mean,
+ const ITensor *var,
+ const ITensor *beta,
+ const ITensor *gamma,
+ float epsilon,
+ ActivationLayerInfo act_info)
+{
+ batch_normalization_nchw<float16_t, false, detail::dummy<float16_t, 8>>(window, input, output, mean, var, beta,
+ gamma, epsilon, act_info);
+}
+
+void fp16_batch_normalization_nchw_non_fused_relu(const Window &window,
+ ITensor *input,
+ ITensor *output,
+ const ITensor *mean,
+ const ITensor *var,
+ const ITensor *beta,
+ const ITensor *gamma,
+ float epsilon,
+ ActivationLayerInfo act_info)
+{
+ batch_normalization_nchw<float16_t, true, detail::relu<float16_t, 8>>(window, input, output, mean, var, beta, gamma,
+ epsilon, act_info);
+}
+
+void fp16_batch_normalization_nchw_non_fused_brelu(const Window &window,
+ ITensor *input,
+ ITensor *output,
+ const ITensor *mean,
+ const ITensor *var,
+ const ITensor *beta,
+ const ITensor *gamma,
+ float epsilon,
+ ActivationLayerInfo act_info)
+{
+ batch_normalization_nchw<float16_t, true, detail::brelu<float16_t, 8>>(window, input, output, mean, var, beta,
+ gamma, epsilon, act_info);
+}
+
+void fp16_batch_normalization_nchw_non_fused_lubrelu(const Window &window,
+ ITensor *input,
+ ITensor *output,
+ const ITensor *mean,
+ const ITensor *var,
+ const ITensor *beta,
+ const ITensor *gamma,
+ float epsilon,
+ ActivationLayerInfo act_info)
+{
+ batch_normalization_nchw<float16_t, true, detail::lubrelu<float16_t, 8>>(window, input, output, mean, var, beta,
+ gamma, epsilon, act_info);
+}
+} // namespace cpu
+} // namespace arm_compute
+#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */
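Note that the whole FP16 translation unit is fenced by __ARM_FEATURE_FP16_VECTOR_ARITHMETIC and ENABLE_FP16_KERNELS, so on targets without FP16 support these symbols simply do not exist; the REGISTER_FP16_NEON macro used at the configure sites is then expected to yield a null pointer instead of a function address. A sketch of that registrar pattern (assumed shape; the real macros live in src/core/common/Registrars.h):

// Assumed shape of the registration macro (sketch, not verbatim ACL code):
// when FP16 kernels are compiled in, registration takes the function's
// address; otherwise it degrades to nullptr and the kernel is unavailable.
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
#define REGISTER_FP16_NEON(func_name) &(func_name)
#else
#define REGISTER_FP16_NEON(func_name) nullptr
#endif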
diff --git a/src/cpu/kernels/fuse_batch_normalization/nchw/neon/fp32.cpp b/src/cpu/kernels/fuse_batch_normalization/nchw/neon/fp32.cpp
new file mode 100644
index 0000000000..ae2db1ac66
--- /dev/null
+++ b/src/cpu/kernels/fuse_batch_normalization/nchw/neon/fp32.cpp
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/TensorInfo.h"
+
+#include "src/core/CPP/Validate.h"
+#include "src/core/NEON/kernels/detail/NEActivationFunctionDetail.h"
+#include "src/core/NEON/wrapper/wrapper.h"
+#include "src/cpu/CpuTypes.h"
+#include "src/cpu/kernels/fuse_batch_normalization/generic/impl.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+void fp32_batch_normalization_nchw_non_fused(const Window &window,
+ ITensor *input,
+ ITensor *output,
+ const ITensor *mean,
+ const ITensor *var,
+ const ITensor *beta,
+ const ITensor *gamma,
+ float epsilon,
+ ActivationLayerInfo act_info)
+{
+ batch_normalization_nchw<float, false, detail::dummy<float, 4>>(window, input, output, mean, var, beta, gamma,
+ epsilon, act_info);
+}
+
+void fp32_batch_normalization_nchw_non_fused_relu(const Window &window,
+ ITensor *input,
+ ITensor *output,
+ const ITensor *mean,
+ const ITensor *var,
+ const ITensor *beta,
+ const ITensor *gamma,
+ float epsilon,
+ ActivationLayerInfo act_info)
+{
+ batch_normalization_nchw<float, true, detail::relu<float, 4>>(window, input, output, mean, var, beta, gamma,
+ epsilon, act_info);
+}
+
+void fp32_batch_normalization_nchw_non_fused_brelu(const Window &window,
+ ITensor *input,
+ ITensor *output,
+ const ITensor *mean,
+ const ITensor *var,
+ const ITensor *beta,
+ const ITensor *gamma,
+ float epsilon,
+ ActivationLayerInfo act_info)
+{
+ batch_normalization_nchw<float, true, detail::brelu<float, 4>>(window, input, output, mean, var, beta, gamma,
+ epsilon, act_info);
+}
+
+void fp32_batch_normalization_nchw_non_fused_lubrelu(const Window &window,
+ ITensor *input,
+ ITensor *output,
+ const ITensor *mean,
+ const ITensor *var,
+ const ITensor *beta,
+ const ITensor *gamma,
+ float epsilon,
+ ActivationLayerInfo act_info)
+{
+ batch_normalization_nchw<float, true, detail::lubrelu<float, 4>>(window, input, output, mean, var, beta, gamma,
+ epsilon, act_info);
+}
+} // namespace cpu
+} // namespace arm_compute
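Taken together, the new files plus the maps in configure_fused() form a two-level dispatch: the data type selects the map, the activation function selects the entry. A condensed, self-contained sketch of that lookup (hypothetical enum and kernel names, mirroring the shape of the diff):

#include <cstdio>
#include <map>

// Hypothetical stand-ins mirroring the dispatch in configure_fused().
enum class Activation { RELU, BOUNDED_RELU, LU_BOUNDED_RELU };
using KernelFn = void (*)();

void relu_kernel()    { std::puts("relu kernel"); }
void brelu_kernel()   { std::puts("brelu kernel"); }
void lubrelu_kernel() { std::puts("lubrelu kernel"); }

int main()
{
    // One activation -> kernel map per data type in the real code.
    static const std::map<Activation, KernelFn> bn_fused_map = {
        {Activation::RELU, &relu_kernel},
        {Activation::BOUNDED_RELU, &brelu_kernel},
        {Activation::LU_BOUNDED_RELU, &lubrelu_kernel},
    };
    bn_fused_map.at(Activation::BOUNDED_RELU)(); // prints "brelu kernel"
}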