aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDana Zlotnik <dana.zlotnik@arm.com>2022-01-03 10:59:41 +0200
committerDana Zlotnik <dana.zlotnik@arm.com>2022-01-12 12:31:14 +0000
commitd7e2ec51239e2075f931e0a9364e0a68534676f1 (patch)
treefa2e37fdd125caba983d3b9cdabdb8e03921f6a6
parent5ae8d804d67f57fbfa793800ddcc21a5aff954dd (diff)
downloadComputeLibrary-d7e2ec51239e2075f931e0a9364e0a68534676f1.tar.gz
Decouple NEInstanceNormalizationLayerKernel
Resolves COMPMID-4620 Signed-off-by: Dana Zlotnik <dana.zlotnik@arm.com> Change-Id: I22c285339840493c9cfd4c1abfbc3768ad4db824 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6871 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--Android.bp3
-rw-r--r--filelist.json7
-rw-r--r--src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp190
-rw-r--r--src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h15
-rw-r--r--src/cpu/kernels/instancenorm/generic/neon/fp16.cpp43
-rw-r--r--src/cpu/kernels/instancenorm/generic/neon/fp32.cpp35
-rw-r--r--src/cpu/kernels/instancenorm/generic/neon/impl.cpp172
-rw-r--r--src/cpu/kernels/instancenorm/generic/neon/impl.h50
-rw-r--r--src/cpu/kernels/instancenorm/list.h37
9 files changed, 398 insertions, 154 deletions
diff --git a/Android.bp b/Android.bp
index 600035ee08..3c9a99b8de 100644
--- a/Android.bp
+++ b/Android.bp
@@ -469,6 +469,9 @@ cc_library_static {
"src/cpu/kernels/genproposals/generic/neon/fp32.cpp",
"src/cpu/kernels/genproposals/generic/neon/impl.cpp",
"src/cpu/kernels/genproposals/generic/neon/qsymm16.cpp",
+ "src/cpu/kernels/instancenorm/generic/neon/fp16.cpp",
+ "src/cpu/kernels/instancenorm/generic/neon/fp32.cpp",
+ "src/cpu/kernels/instancenorm/generic/neon/impl.cpp",
"src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp",
"src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp",
"src/cpu/kernels/maxunpool/generic/neon/fp16.cpp",
diff --git a/filelist.json b/filelist.json
index a093f60bc5..58c2a9a632 100644
--- a/filelist.json
+++ b/filelist.json
@@ -1530,7 +1530,12 @@
"common": [
"src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp",
"src/runtime/NEON/functions/NEInstanceNormalizationLayer.cpp"
- ]
+ ],
+ "neon":{
+ "common":["src/cpu/kernels/instancenorm/generic/neon/impl.cpp"],
+ "fp16":["src/cpu/kernels/instancenorm/generic/neon/fp16.cpp"],
+ "fp32":["src/cpu/kernels/instancenorm/generic/neon/fp32.cpp"]
+ }
}
},
"L2Normalize": {
diff --git a/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp
index d33431a8d2..71641404bf 100644
--- a/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,8 +34,10 @@
#include "src/core/CPP/Validate.h"
#include "src/core/NEON/NEMath.h"
#include "src/core/NEON/wrapper/wrapper.h"
+#include "src/core/common/Registrars.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
+#include "src/cpu/kernels/instancenorm/list.h"
#include <arm_neon.h>
@@ -43,137 +45,53 @@ namespace arm_compute
{
namespace
{
-template <typename InputType, typename AccType = InputType>
-void vector_float_sum(AccType &result, AccType &result_square, const InputType &inputs)
+struct InstanceNormSelectorData
{
- result = wrapper::vadd(result, inputs);
- result_square = wrapper::vadd(result_square, wrapper::vmul(inputs, inputs));
-}
+ DataType dt;
+};
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-template <>
-inline void vector_float_sum(float32x4_t &result, float32x4_t &result_square, const float16x8_t &inputs)
-{
- vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgetlow(inputs)));
- vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgethigh(inputs)));
-}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+using InstanceNormSelctorPtr = std::add_pointer<bool(const InstanceNormSelectorData &data)>::type;
+using InstanceNormUKernelPtr = std::add_pointer<void(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, bool use_mixed_precision, const Window &window)>::type;
-template <typename InputType, typename AccType = InputType>
-InputType vector_float_norm(const InputType &inputs, const AccType &vec_mean, const AccType &vec_multip, const AccType &vec_beta)
+struct InstanceNormKernel
{
- return wrapper::vadd(wrapper::vmul(wrapper::vsub(inputs, vec_mean), vec_multip), vec_beta);
-}
+ const char *name;
+ const InstanceNormSelctorPtr is_selected;
+ InstanceNormUKernelPtr ukernel;
+};
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-template <>
-inline float16x8_t vector_float_norm(const float16x8_t &inputs, const float32x4_t &vec_mean, const float32x4_t &vec_multip, const float32x4_t &vec_beta)
+static const InstanceNormKernel available_kernels[] =
{
- const auto input_low = wrapper::vcvt<float>(wrapper::vgetlow(inputs));
- const auto input_high = wrapper::vcvt<float>(wrapper::vgethigh(inputs));
- const auto result_low = wrapper::vcvt<float16_t>(vector_float_norm(input_low, vec_mean, vec_multip, vec_beta));
- const auto result_high = wrapper::vcvt<float16_t>(vector_float_norm(input_high, vec_mean, vec_multip, vec_beta));
- float16x8_t result = wrapper::vcombine(result_low, result_high);
-
- return result;
-}
+ {
+ "fp32_neon_instancenorm",
+ [](const InstanceNormSelectorData & data) { return data.dt == DataType::F32; },
+ REGISTER_FP32_NEON(arm_compute::cpu::neon_fp32_instancenorm)
+ },
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ {
+ "fp16_neon_instancenorm",
+ [](const InstanceNormSelectorData & data) { return data.dt == DataType::F16; },
+ REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_instancenorm)
+ },
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+};
-template <typename T, typename AccType = T>
-void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window)
+/** Micro-kernel selector
+ *
+ * @param[in] data Selection data passed to help pick the appropriate micro-kernel
+ *
+ * @return A matching micro-kernel else nullptr
+ */
+const InstanceNormKernel *get_implementation(const InstanceNormSelectorData &data)
{
- /** SIMD vector tag type. */
- using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
-
- // Clear X/Y dimensions on execution window as we handle the planes manually
- Window win = window;
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
- win.set(Window::DimY, Window::Dimension(0, 1, 1));
-
- constexpr int window_step_x = 16 / sizeof(T);
- const unsigned int elements_plane = input->info()->dimension(0) * output->info()->dimension(1);
-
- Iterator input_it(input, win);
- execute_window_loop(win, [&](const Coordinates & id)
+ for(const auto &uk : available_kernels)
{
- Window win_plane = window;
- win_plane.set(Window::DimX, Window::Dimension(0, 1, 1));
- win_plane.set(Window::DimZ, Window::Dimension(id[2], id[2] + 1, 1));
- win_plane.set(3, Window::Dimension(id[3], id[3] + 1, 1));
-
- Iterator input_plane_it(input, win_plane);
- Iterator output_plane_it(output, win_plane);
-
- auto sum_h_w = static_cast<AccType>(0.f);
- auto sum_squares_h_w = static_cast<AccType>(0.f);
-
- execute_window_loop(win_plane, [&](const Coordinates &)
+ if(uk.is_selected(data))
{
- const auto input_ptr = reinterpret_cast<const T *>(input_plane_it.ptr());
-
- auto vec_sum_h_w = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
- auto vec_sum_squares_h_w = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
-
- // Compute S elements per iteration
- int x = window.x().start();
- for(; x <= (window.x().end() - window_step_x); x += window_step_x)
- {
- auto vec_input_val = wrapper::vloadq(input_ptr + x);
- vector_float_sum(vec_sum_h_w, vec_sum_squares_h_w, vec_input_val);
- }
-
- auto vec2_sum_h_w = wrapper::vpadd(wrapper::vgethigh(vec_sum_h_w), wrapper::vgetlow(vec_sum_h_w));
- auto vec2_sum_squares_h_w = wrapper::vpadd(wrapper::vgethigh(vec_sum_squares_h_w), wrapper::vgetlow(vec_sum_squares_h_w));
-
- vec2_sum_h_w = wrapper::vpadd(vec2_sum_h_w, vec2_sum_h_w);
- vec2_sum_squares_h_w = wrapper::vpadd(vec2_sum_squares_h_w, vec2_sum_squares_h_w);
-
- sum_h_w += wrapper::vgetlane(vec2_sum_h_w, 0);
- sum_squares_h_w += wrapper::vgetlane(vec2_sum_squares_h_w, 0);
-
- // Compute left-over elements
- for(; x < window.x().end(); ++x)
- {
- const auto value = static_cast<AccType>(*(input_ptr + x));
- sum_h_w += value;
- sum_squares_h_w += value * value;
- }
- },
- input_plane_it, output_plane_it);
-
- const auto mean_h_w = sum_h_w / elements_plane;
- const auto var_h_w = sum_squares_h_w / elements_plane - mean_h_w * mean_h_w;
-
- const auto multip_h_w = gamma / std::sqrt(var_h_w + epsilon);
- const auto vec_mean_h_w = wrapper::vdup_n(static_cast<AccType>(mean_h_w), ExactTagType{});
- const auto vec_multip_h_w = wrapper::vdup_n(static_cast<AccType>(multip_h_w), ExactTagType{});
- const auto vec_beta = wrapper::vdup_n(static_cast<AccType>(beta), ExactTagType{});
-
- execute_window_loop(win_plane, [&](const Coordinates &)
- {
- auto input_ptr = reinterpret_cast<T *>(input_plane_it.ptr());
- auto output_ptr = reinterpret_cast<T *>(output_plane_it.ptr());
-
- // Compute S elements per iteration
- int x = window.x().start();
- //auto vec_val = wrapper::vdup_n(static_cast<T>(0.0f), ExactTagType{});
- for(; x <= (window.x().end() - window_step_x); x += window_step_x)
- {
- const auto vec_val = wrapper::vloadq(input_ptr + x);
- const auto normalized_vec = vector_float_norm(vec_val, vec_mean_h_w, vec_multip_h_w, vec_beta);
- wrapper::vstore(output_ptr + x, normalized_vec);
- }
-
- // Compute left-over elements
- for(; x < window.x().end(); ++x)
- {
- const auto val = static_cast<AccType>(*(input_ptr + x));
- *(output_ptr + x) = static_cast<T>((val - mean_h_w) * multip_h_w + beta);
- }
- },
- input_plane_it, output_plane_it);
- },
- input_it);
+ return &uk;
+ }
+ }
+ return nullptr;
}
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, float gamma, float beta, float epsilon)
@@ -210,7 +128,7 @@ std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITe
} // namespace
NEInstanceNormalizationLayerKernel::NEInstanceNormalizationLayerKernel()
- : _func(nullptr), _input(nullptr), _output(nullptr), _gamma(1), _beta(0), _epsilon(1e-12)
+ : _input(nullptr), _output(nullptr), _gamma(1), _beta(0), _epsilon(1e-12)
{
}
@@ -227,28 +145,6 @@ void NEInstanceNormalizationLayerKernel::configure(ITensor *input, ITensor *outp
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _output->info(), _gamma, _beta, _epsilon));
- if(_input->info()->data_type() == DataType::F32)
- {
- _func = &instance_normalization_nchw<float>;
- }
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- else if(_input->info()->data_type() == DataType::F16)
- {
- if(_use_mixed_precision)
- {
- _func = &instance_normalization_nchw<float16_t, float>;
- }
- else
- {
- _func = &instance_normalization_nchw<float16_t>;
- }
- }
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- else
- {
- ARM_COMPUTE_ERROR("Unsupported data type");
- }
-
// Configure kernel window
auto win_config = validate_and_configure_window(_input->info(), _output->info());
ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
@@ -268,6 +164,10 @@ void NEInstanceNormalizationLayerKernel::run(const Window &window, const ThreadI
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
- (*_func)(_input, _output, _gamma, _beta, _epsilon, window);
+
+ const auto *uk = get_implementation(InstanceNormSelectorData{ _input->info()->data_type() });
+ ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
+
+ uk->ukernel(_input, _output, _gamma, _beta, _epsilon, _use_mixed_precision, window);
}
} // namespace arm_compute
diff --git a/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h b/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h
index 96c0119719..f166ce2058 100644
--- a/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h
+++ b/src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -84,13 +84,12 @@ private:
*/
using NormalizationFunction = void(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window);
- NormalizationFunction *_func;
- ITensor *_input;
- ITensor *_output;
- float _gamma;
- float _beta;
- float _epsilon;
- bool _use_mixed_precision{ true };
+ ITensor *_input;
+ ITensor *_output;
+ float _gamma;
+ float _beta;
+ float _epsilon;
+ bool _use_mixed_precision{ true };
};
} // namespace arm_compute
#endif /*ARM_COMPUTE_NEINSTANCENORMALIZATIONLAYERKERNEL_H */
diff --git a/src/cpu/kernels/instancenorm/generic/neon/fp16.cpp b/src/cpu/kernels/instancenorm/generic/neon/fp16.cpp
new file mode 100644
index 0000000000..e9fcc84b35
--- /dev/null
+++ b/src/cpu/kernels/instancenorm/generic/neon/fp16.cpp
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+#include "src/cpu/kernels/instancenorm/generic/neon/impl.h"
+namespace arm_compute
+{
+namespace cpu
+{
+void neon_fp16_instancenorm(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, bool use_mixed_precision, const Window &window)
+{
+ if(use_mixed_precision)
+ {
+ return instance_normalization_nchw<float16_t, float>(input, output, gamma, beta, epsilon, window);
+ }
+ else
+ {
+ return instance_normalization_nchw<float16_t>(input, output, gamma, beta, epsilon, window);
+ }
+}
+} // namespace cpu
+} // namespace arm_compute
+#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */
diff --git a/src/cpu/kernels/instancenorm/generic/neon/fp32.cpp b/src/cpu/kernels/instancenorm/generic/neon/fp32.cpp
new file mode 100644
index 0000000000..061dd9585c
--- /dev/null
+++ b/src/cpu/kernels/instancenorm/generic/neon/fp32.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/cpu/kernels/instancenorm/generic/neon/impl.h"
+namespace arm_compute
+{
+namespace cpu
+{
+void neon_fp32_instancenorm(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, bool use_mixed_precision, const Window &window)
+{
+ ARM_COMPUTE_UNUSED(use_mixed_precision);
+ return instance_normalization_nchw<float>(input, output, gamma, beta, epsilon, window);
+}
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/instancenorm/generic/neon/impl.cpp b/src/cpu/kernels/instancenorm/generic/neon/impl.cpp
new file mode 100644
index 0000000000..e35cf97608
--- /dev/null
+++ b/src/cpu/kernels/instancenorm/generic/neon/impl.cpp
@@ -0,0 +1,172 @@
+/*
+ * Copyright (c) 2019-2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/cpu/kernels/instancenorm/generic/neon/impl.h"
+#include "src/core/NEON/wrapper/wrapper.h"
+
+namespace arm_compute
+{
+class ITensor;
+class Window;
+namespace cpu
+{
+template <typename InputType, typename AccType>
+void vector_float_sum(AccType &result, AccType &result_square, const InputType &inputs)
+{
+ result = wrapper::vadd(result, inputs);
+ result_square = wrapper::vadd(result_square, wrapper::vmul(inputs, inputs));
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <>
+inline void vector_float_sum(float32x4_t &result, float32x4_t &result_square, const float16x8_t &inputs)
+{
+ vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgetlow(inputs)));
+ vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgethigh(inputs)));
+}
+template <>
+inline float16x8_t vector_float_norm(const float16x8_t &inputs, const float32x4_t &vec_mean, const float32x4_t &vec_multip, const float32x4_t &vec_beta)
+{
+ const auto input_low = wrapper::vcvt<float>(wrapper::vgetlow(inputs));
+ const auto input_high = wrapper::vcvt<float>(wrapper::vgethigh(inputs));
+ const auto result_low = wrapper::vcvt<float16_t>(vector_float_norm(input_low, vec_mean, vec_multip, vec_beta));
+ const auto result_high = wrapper::vcvt<float16_t>(vector_float_norm(input_high, vec_mean, vec_multip, vec_beta));
+ float16x8_t result = wrapper::vcombine(result_low, result_high);
+
+ return result;
+}
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+template <typename InputType, typename AccType>
+InputType vector_float_norm(const InputType &inputs, const AccType &vec_mean, const AccType &vec_multip, const AccType &vec_beta)
+{
+ return wrapper::vadd(wrapper::vmul(wrapper::vsub(inputs, vec_mean), vec_multip), vec_beta);
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template <typename T, typename AccType>
+void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window)
+{
+ /** SIMD vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
+ // Clear X/Y dimensions on execution window as we handle the planes manually
+ Window win = window;
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+ win.set(Window::DimY, Window::Dimension(0, 1, 1));
+
+ constexpr int window_step_x = 16 / sizeof(T);
+ const unsigned int elements_plane = input->info()->dimension(0) * output->info()->dimension(1);
+
+ Iterator input_it(input, win);
+ execute_window_loop(win, [&](const Coordinates & id)
+ {
+ Window win_plane = window;
+ win_plane.set(Window::DimX, Window::Dimension(0, 1, 1));
+ win_plane.set(Window::DimZ, Window::Dimension(id[2], id[2] + 1, 1));
+ win_plane.set(3, Window::Dimension(id[3], id[3] + 1, 1));
+
+ Iterator input_plane_it(input, win_plane);
+ Iterator output_plane_it(output, win_plane);
+
+ auto sum_h_w = static_cast<AccType>(0.f);
+ auto sum_squares_h_w = static_cast<AccType>(0.f);
+
+ execute_window_loop(win_plane, [&](const Coordinates &)
+ {
+ const auto input_ptr = reinterpret_cast<const T *>(input_plane_it.ptr());
+
+ auto vec_sum_h_w = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
+ auto vec_sum_squares_h_w = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
+
+ // Compute S elements per iteration
+ int x = window.x().start();
+ for(; x <= (window.x().end() - window_step_x); x += window_step_x)
+ {
+ auto vec_input_val = wrapper::vloadq(input_ptr + x);
+ vector_float_sum(vec_sum_h_w, vec_sum_squares_h_w, vec_input_val);
+ }
+
+ auto vec2_sum_h_w = wrapper::vpadd(wrapper::vgethigh(vec_sum_h_w), wrapper::vgetlow(vec_sum_h_w));
+ auto vec2_sum_squares_h_w = wrapper::vpadd(wrapper::vgethigh(vec_sum_squares_h_w), wrapper::vgetlow(vec_sum_squares_h_w));
+
+ vec2_sum_h_w = wrapper::vpadd(vec2_sum_h_w, vec2_sum_h_w);
+ vec2_sum_squares_h_w = wrapper::vpadd(vec2_sum_squares_h_w, vec2_sum_squares_h_w);
+
+ sum_h_w += wrapper::vgetlane(vec2_sum_h_w, 0);
+ sum_squares_h_w += wrapper::vgetlane(vec2_sum_squares_h_w, 0);
+
+ // Compute left-over elements
+ for(; x < window.x().end(); ++x)
+ {
+ const auto value = static_cast<AccType>(*(input_ptr + x));
+ sum_h_w += value;
+ sum_squares_h_w += value * value;
+ }
+ },
+ input_plane_it, output_plane_it);
+
+ const auto mean_h_w = sum_h_w / elements_plane;
+ const auto var_h_w = sum_squares_h_w / elements_plane - mean_h_w * mean_h_w;
+
+ const auto multip_h_w = gamma / std::sqrt(var_h_w + epsilon);
+ const auto vec_mean_h_w = wrapper::vdup_n(static_cast<AccType>(mean_h_w), ExactTagType{});
+ const auto vec_multip_h_w = wrapper::vdup_n(static_cast<AccType>(multip_h_w), ExactTagType{});
+ const auto vec_beta = wrapper::vdup_n(static_cast<AccType>(beta), ExactTagType{});
+
+ execute_window_loop(win_plane, [&](const Coordinates &)
+ {
+ auto input_ptr = reinterpret_cast<T *>(input_plane_it.ptr());
+ auto output_ptr = reinterpret_cast<T *>(output_plane_it.ptr());
+
+ // Compute S elements per iteration
+ int x = window.x().start();
+ //auto vec_val = wrapper::vdup_n(static_cast<T>(0.0f), ExactTagType{});
+ for(; x <= (window.x().end() - window_step_x); x += window_step_x)
+ {
+ const auto vec_val = wrapper::vloadq(input_ptr + x);
+ const auto normalized_vec = vector_float_norm(vec_val, vec_mean_h_w, vec_multip_h_w, vec_beta);
+ wrapper::vstore(output_ptr + x, normalized_vec);
+ }
+
+ // Compute left-over elements
+ for(; x < window.x().end(); ++x)
+ {
+ const auto val = static_cast<AccType>(*(input_ptr + x));
+ *(output_ptr + x) = static_cast<T>((val - mean_h_w) * multip_h_w + beta);
+ }
+ },
+ input_plane_it, output_plane_it);
+ },
+ input_it);
+}
+
+template void instance_normalization_nchw<float>(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window);
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+template void instance_normalization_nchw<float16_t, float>(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window);
+template void instance_normalization_nchw<float16_t>(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window);
+#endif //defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/instancenorm/generic/neon/impl.h b/src/cpu/kernels/instancenorm/generic/neon/impl.h
new file mode 100644
index 0000000000..fa4b4b656c
--- /dev/null
+++ b/src/cpu/kernels/instancenorm/generic/neon/impl.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef SRC_CORE_SVE_KERNELS_INSTANCENORM_IMPL_H
+#define SRC_CORE_SVE_KERNELS_INSTANCENORM_IMPL_H
+#include "arm_compute/core/Helpers.h"
+namespace arm_compute
+{
+namespace cpu
+{
+template <typename T, typename AccType = T>
+void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window);
+
+template <typename InputType, typename AccType = InputType>
+void vector_float_sum(AccType &result, AccType &result_square, const InputType &inputs);
+
+template <typename InputType, typename AccType = InputType>
+InputType vector_float_norm(const InputType &inputs, const AccType &vec_mean, const AccType &vec_multip, const AccType &vec_beta);
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+template <>
+inline void vector_float_sum(float32x4_t &result, float32x4_t &result_square, const float16x8_t &inputs);
+
+template <>
+inline float16x8_t vector_float_norm(const float16x8_t &inputs, const float32x4_t &vec_mean, const float32x4_t &vec_multip, const float32x4_t &vec_beta);
+#endif //defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+
+} // namespace cpu
+} // namespace arm_compute
+#endif //define SRC_CORE_SVE_KERNELS_INSTANCENORM_IMPL_H
diff --git a/src/cpu/kernels/instancenorm/list.h b/src/cpu/kernels/instancenorm/list.h
new file mode 100644
index 0000000000..54f1d3213f
--- /dev/null
+++ b/src/cpu/kernels/instancenorm/list.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef SRC_CORE_NEON_KERNELS_INSTANCENORM_LIST_H
+#define SRC_CORE_NEON_KERNELS_INSTANCENORM_LIST_H
+namespace arm_compute
+{
+namespace cpu
+{
+#define DECLARE_INSTANCENORM_KERNEL(func_name) \
+ void func_name(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, bool use_mixed_precision, const Window &window)
+DECLARE_INSTANCENORM_KERNEL(neon_fp32_instancenorm);
+DECLARE_INSTANCENORM_KERNEL(neon_fp16_instancenorm);
+#undef DECLARE_INSTANCENORM_KERNEL
+} // namespace cpu
+} // namespace arm_compute
+#endif //SRC_CORE_NEON_KERNELS_INSTANCENORM_LIST_H