aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/fuse_batch_normalization/nhwc
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/fuse_batch_normalization/nhwc')
-rw-r--r--src/cpu/kernels/fuse_batch_normalization/nhwc/neon/fp16.cpp16
-rw-r--r--src/cpu/kernels/fuse_batch_normalization/nhwc/neon/fp32.cpp16
-rw-r--r--src/cpu/kernels/fuse_batch_normalization/nhwc/neon/impl.h143
3 files changed, 105 insertions, 70 deletions
diff --git a/src/cpu/kernels/fuse_batch_normalization/nhwc/neon/fp16.cpp b/src/cpu/kernels/fuse_batch_normalization/nhwc/neon/fp16.cpp
index 275211ff38..1d88d3b494 100644
--- a/src/cpu/kernels/fuse_batch_normalization/nhwc/neon/fp16.cpp
+++ b/src/cpu/kernels/fuse_batch_normalization/nhwc/neon/fp16.cpp
@@ -30,11 +30,19 @@ namespace arm_compute
{
namespace cpu
{
-void fused_batch_normalization_dwc_nhwc_f16(const ITensor *dwc_weights, const ITensor *dwc_bias, ITensor *fused_weights, ITensor *fused_bias,
- const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
+void fused_batch_normalization_dwc_nhwc_f16(const ITensor *dwc_weights,
+ const ITensor *dwc_bias,
+ ITensor *fused_weights,
+ ITensor *fused_bias,
+ const ITensor *bn_mean,
+ const ITensor *bn_var,
+ const ITensor *bn_beta,
+ const ITensor *bn_gamma,
+ float epsilon,
+ const Window &window)
{
- return fused_batch_normalization_dwc_nhwc<float16_t>(dwc_weights, dwc_bias, fused_weights, fused_bias,
- bn_mean, bn_var, bn_beta, bn_gamma, epsilon, window);
+ return fused_batch_normalization_dwc_nhwc<float16_t>(dwc_weights, dwc_bias, fused_weights, fused_bias, bn_mean,
+ bn_var, bn_beta, bn_gamma, epsilon, window);
}
} // namespace cpu
diff --git a/src/cpu/kernels/fuse_batch_normalization/nhwc/neon/fp32.cpp b/src/cpu/kernels/fuse_batch_normalization/nhwc/neon/fp32.cpp
index 67169c5325..1f336bb196 100644
--- a/src/cpu/kernels/fuse_batch_normalization/nhwc/neon/fp32.cpp
+++ b/src/cpu/kernels/fuse_batch_normalization/nhwc/neon/fp32.cpp
@@ -29,11 +29,19 @@ namespace arm_compute
{
namespace cpu
{
-void fused_batch_normalization_dwc_nhwc_f32(const ITensor *dwc_weights, const ITensor *dwc_bias, ITensor *fused_weights, ITensor *fused_bias,
- const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
+void fused_batch_normalization_dwc_nhwc_f32(const ITensor *dwc_weights,
+ const ITensor *dwc_bias,
+ ITensor *fused_weights,
+ ITensor *fused_bias,
+ const ITensor *bn_mean,
+ const ITensor *bn_var,
+ const ITensor *bn_beta,
+ const ITensor *bn_gamma,
+ float epsilon,
+ const Window &window)
{
- return fused_batch_normalization_dwc_nhwc<float32_t>(dwc_weights, dwc_bias, fused_weights, fused_bias,
- bn_mean, bn_var, bn_beta, bn_gamma, epsilon, window);
+ return fused_batch_normalization_dwc_nhwc<float32_t>(dwc_weights, dwc_bias, fused_weights, fused_bias, bn_mean,
+ bn_var, bn_beta, bn_gamma, epsilon, window);
}
} // namespace cpu
diff --git a/src/cpu/kernels/fuse_batch_normalization/nhwc/neon/impl.h b/src/cpu/kernels/fuse_batch_normalization/nhwc/neon/impl.h
index 6f0386276f..5b74a7aef6 100644
--- a/src/cpu/kernels/fuse_batch_normalization/nhwc/neon/impl.h
+++ b/src/cpu/kernels/fuse_batch_normalization/nhwc/neon/impl.h
@@ -25,6 +25,7 @@
#define SRC_CORE_NEON_KERNELS_FUSE_BATCH_NORMALIZATION_IMPL_H
#include "arm_compute/core/Helpers.h"
+
#include "src/core/NEON/wrapper/wrapper.h"
namespace arm_compute
@@ -32,8 +33,16 @@ namespace arm_compute
namespace cpu
{
template <typename T>
-void fused_batch_normalization_dwc_nhwc(const ITensor *dwc_weights, const ITensor *dwc_bias, ITensor *fused_weights, ITensor *fused_bias,
- const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
+void fused_batch_normalization_dwc_nhwc(const ITensor *dwc_weights,
+ const ITensor *dwc_bias,
+ ITensor *fused_weights,
+ ITensor *fused_bias,
+ const ITensor *bn_mean,
+ const ITensor *bn_var,
+ const ITensor *bn_beta,
+ const ITensor *bn_gamma,
+ float epsilon,
+ const Window &window)
{
using ScalarType = T;
const int size = 16 / dwc_weights->info()->element_size();
@@ -53,13 +62,20 @@ void fused_batch_normalization_dwc_nhwc(const ITensor *dwc_weights, const ITenso
Iterator dwc_w_in(dwc_weights, win);
Iterator dwc_w_out(run_in_place_weights ? dwc_weights : fused_weights, win);
- const auto dwc_bias_in = (dwc_bias != nullptr ? reinterpret_cast<ScalarType *>(dwc_bias->ptr_to_element(Coordinates(0, 0))) : nullptr);
- auto dwc_bias_out = (run_in_place_bias ? dwc_bias_in : reinterpret_cast<ScalarType *>(fused_bias->ptr_to_element(Coordinates(0, 0))));
+ const auto dwc_bias_in =
+ (dwc_bias != nullptr ? reinterpret_cast<ScalarType *>(dwc_bias->ptr_to_element(Coordinates(0, 0))) : nullptr);
+ auto dwc_bias_out =
+ (run_in_place_bias ? dwc_bias_in
+ : reinterpret_cast<ScalarType *>(fused_bias->ptr_to_element(Coordinates(0, 0))));
const auto input_mean = reinterpret_cast<const ScalarType *>(bn_mean->ptr_to_element(Coordinates(0, 0)));
const auto input_var = reinterpret_cast<const ScalarType *>(bn_var->ptr_to_element(Coordinates(0, 0)));
- const auto input_gamma = (bn_gamma != nullptr) ? reinterpret_cast<const ScalarType *>(bn_gamma->ptr_to_element(Coordinates(0, 0))) : nullptr;
- const auto input_beta = (bn_beta != nullptr) ? reinterpret_cast<const ScalarType *>(bn_beta->ptr_to_element(Coordinates(0, 0))) : nullptr;
+ const auto input_gamma = (bn_gamma != nullptr)
+ ? reinterpret_cast<const ScalarType *>(bn_gamma->ptr_to_element(Coordinates(0, 0)))
+ : nullptr;
+ const auto input_beta = (bn_beta != nullptr)
+ ? reinterpret_cast<const ScalarType *>(bn_beta->ptr_to_element(Coordinates(0, 0)))
+ : nullptr;
auto mean_vec = wrapper::vdup_n(ScalarType(0), ExactTagType{});
auto var_vec = wrapper::vdup_n(ScalarType(0), ExactTagType{});
@@ -73,81 +89,84 @@ void fused_batch_normalization_dwc_nhwc(const ITensor *dwc_weights, const ITenso
auto beta = ScalarType(0.0);
auto dwc_bias_in_scalar = ScalarType(0);
- execute_window_loop(win, [&](const Coordinates & id)
- {
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ execute_window_loop(
+ win,
+ [&](const Coordinates &id)
{
- var_vec = wrapper::vloadq(input_var + x);
- if(input_gamma != nullptr)
- {
- gamma_vec = wrapper::vloadq(input_gamma + x);
- }
-
- if((id[2] == 0) && (id[1] == 0))
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
{
- mean_vec = wrapper::vloadq(input_mean + x);
-
- // Construct vectors
- if(input_beta != nullptr)
+ var_vec = wrapper::vloadq(input_var + x);
+ if (input_gamma != nullptr)
{
- beta_vec = wrapper::vloadq(input_beta + x);
+ gamma_vec = wrapper::vloadq(input_gamma + x);
}
- if(dwc_bias_in != nullptr)
+ if ((id[2] == 0) && (id[1] == 0))
{
- dwc_bias_vec = wrapper::vloadq(dwc_bias_in + x);
+ mean_vec = wrapper::vloadq(input_mean + x);
+
+ // Construct vectors
+ if (input_beta != nullptr)
+ {
+ beta_vec = wrapper::vloadq(input_beta + x);
+ }
+
+ if (dwc_bias_in != nullptr)
+ {
+ dwc_bias_vec = wrapper::vloadq(dwc_bias_in + x);
+ }
+
+ auto dwc_bias_tmp_vec = wrapper::vmul(wrapper::vsub(dwc_bias_vec, mean_vec),
+ wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec)));
+ dwc_bias_tmp_vec = wrapper::vadd(wrapper::vmul(dwc_bias_tmp_vec, gamma_vec), beta_vec);
+ wrapper::vstore(dwc_bias_out + x, dwc_bias_tmp_vec);
}
- auto dwc_bias_tmp_vec = wrapper::vmul(wrapper::vsub(dwc_bias_vec, mean_vec), wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec)));
- dwc_bias_tmp_vec = wrapper::vadd(wrapper::vmul(dwc_bias_tmp_vec, gamma_vec), beta_vec);
- wrapper::vstore(dwc_bias_out + x, dwc_bias_tmp_vec);
- }
-
- auto dwc_w_in_ptr = reinterpret_cast<const ScalarType *>(dwc_w_in.ptr());
- auto dwc_w_out_ptr = reinterpret_cast<ScalarType *>(dwc_w_out.ptr());
+ auto dwc_w_in_ptr = reinterpret_cast<const ScalarType *>(dwc_w_in.ptr());
+ auto dwc_w_out_ptr = reinterpret_cast<ScalarType *>(dwc_w_out.ptr());
- auto wn = wrapper::vloadq(dwc_w_in_ptr + x);
- rvar_vec = wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec));
- wn = wrapper::vmul(wn, rvar_vec);
- wn = wrapper::vmul(wn, gamma_vec);
+ auto wn = wrapper::vloadq(dwc_w_in_ptr + x);
+ rvar_vec = wrapper::vinvsqrt(wrapper::vadd(var_vec, epsilon_vec));
+ wn = wrapper::vmul(wn, rvar_vec);
+ wn = wrapper::vmul(wn, gamma_vec);
- // Store results
- wrapper::vstore(dwc_w_out_ptr + x, wn);
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- auto var = input_var[x];
- if(input_gamma != nullptr)
- {
- gamma = input_gamma[x];
+ // Store results
+ wrapper::vstore(dwc_w_out_ptr + x, wn);
}
- if(id[2] == 0 && id[1] == 0)
+ // Compute left-over elements
+ for (; x < window_end_x; ++x)
{
- auto mean = input_mean[x];
- if(input_beta != nullptr)
+ auto var = input_var[x];
+ if (input_gamma != nullptr)
{
- beta = input_beta[x];
+ gamma = input_gamma[x];
}
- if(dwc_bias_in != nullptr)
+
+ if (id[2] == 0 && id[1] == 0)
{
- dwc_bias_in_scalar = dwc_bias_in[x];
+ auto mean = input_mean[x];
+ if (input_beta != nullptr)
+ {
+ beta = input_beta[x];
+ }
+ if (dwc_bias_in != nullptr)
+ {
+ dwc_bias_in_scalar = dwc_bias_in[x];
+ }
+
+ auto dwc_bias_tmp_scalar = (dwc_bias_in_scalar - mean) / std::sqrt(var + ScalarType(epsilon));
+ dwc_bias_out[x] = (dwc_bias_tmp_scalar * gamma) + beta;
}
- auto dwc_bias_tmp_scalar = (dwc_bias_in_scalar - mean) / std::sqrt(var + ScalarType(epsilon));
- dwc_bias_out[x] = (dwc_bias_tmp_scalar * gamma) + beta;
- }
-
- const auto dwc_w_in_ptr = reinterpret_cast<const ScalarType *>(dwc_w_in.ptr());
- auto dwc_w_out_ptr = reinterpret_cast<ScalarType *>(dwc_w_out.ptr());
+ const auto dwc_w_in_ptr = reinterpret_cast<const ScalarType *>(dwc_w_in.ptr());
+ auto dwc_w_out_ptr = reinterpret_cast<ScalarType *>(dwc_w_out.ptr());
- *(dwc_w_out_ptr + x) = *(dwc_w_in_ptr + x) / std::sqrt(var + ScalarType(epsilon)) * gamma;
- }
- },
- dwc_w_in, dwc_w_out);
+ *(dwc_w_out_ptr + x) = *(dwc_w_in_ptr + x) / std::sqrt(var + ScalarType(epsilon)) * gamma;
+ }
+ },
+ dwc_w_in, dwc_w_out);
}
} // namespace cpu
} // namespace arm_compute