aboutsummaryrefslogtreecommitdiff
path: root/src/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.cpp')
-rw-r--r-- src/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.cpp 14
1 files changed, 9 insertions, 5 deletions
diff --git a/src/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.cpp
index 982143f0b2..dee2a5579b 100644
--- a/src/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCBatchNormalizationLayerKernel.cpp
@@ -64,7 +64,11 @@ void GCBatchNormalizationLayerKernel::configure(const IGCTensor *input, IGCTenso
_gamma = gamma;
_epsilon = epsilon;
- const unsigned int num_elems_processed_per_iteration = 4 / input->info()->element_size();
+ unsigned int num_elems_processed_per_iteration = 1;
+ if(input->info()->data_type() == DataType::F16)
+ {
+ num_elems_processed_per_iteration = 4;
+ }
// Set build options
std::set<std::string> build_opts;
@@ -83,10 +87,10 @@ void GCBatchNormalizationLayerKernel::configure(const IGCTensor *input, IGCTenso
AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);
AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
- AccessWindowStatic mean_access(mean->info(), 0, 0, mean->info()->dimension(0) + 1, mean->info()->dimension(1));
- AccessWindowStatic var_access(var->info(), 0, 0, var->info()->dimension(0) + 1, var->info()->dimension(1));
- AccessWindowStatic beta_access(beta->info(), 0, 0, beta->info()->dimension(0) + 1, beta->info()->dimension(1));
- AccessWindowStatic gamma_access(gamma->info(), 0, 0, gamma->info()->dimension(0) + 1, gamma->info()->dimension(1));
+ AccessWindowStatic mean_access(mean->info(), 0, 0, mean->info()->dimension(0) + 3, mean->info()->dimension(1));
+ AccessWindowStatic var_access(var->info(), 0, 0, var->info()->dimension(0) + 3, var->info()->dimension(1));
+ AccessWindowStatic beta_access(beta->info(), 0, 0, beta->info()->dimension(0) + 3, beta->info()->dimension(1));
+ AccessWindowStatic gamma_access(gamma->info(), 0, 0, gamma->info()->dimension(0) + 3, gamma->info()->dimension(1));
update_window_and_padding(win, input_access, output_access, mean_access, var_access, beta_access, gamma_access);
output_access.set_valid_region(win, input->info()->valid_region());