about summary refs log tree commit diff
path: root/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp')
-rw-r--r-- src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp 6
1 files changed, 4 insertions, 2 deletions
diff --git a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
index 38e367dfb7..e8882b9daf 100644
--- a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
+++ b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
@@ -130,7 +130,8 @@ void CLBatchNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *out
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("batchnormalization_layer", build_opts));
// Set kernel static arguments
- unsigned int idx = 2 * num_arguments_per_3D_tensor() + 4 * num_arguments_per_1D_tensor(); // Skip the input and output parameters
+ unsigned int include_output = (output != nullptr) ? 1 : 0;
+ unsigned int idx = (1 + include_output) * num_arguments_per_3D_tensor() + 4 * num_arguments_per_1D_tensor(); // Skip the input and output parameters
_kernel.setArg<cl_float>(idx++, _epsilon);
// Configure kernel window
@@ -160,7 +161,8 @@ void CLBatchNormalizationLayerKernel::run(const Window &window, cl::CommandQueue
Window vector_slice = window.first_slice_window_1D();
vector_slice.set(Window::DimX, Window::Dimension(0, 0, 0));
- unsigned int idx = 2 * num_arguments_per_3D_tensor();
+ unsigned int include_output = (_output != nullptr) ? 1 : 0;
+ unsigned int idx = (1 + include_output) * num_arguments_per_3D_tensor();
add_1D_tensor_argument(idx, _mean, vector_slice);
add_1D_tensor_argument(idx, _var, vector_slice);
add_1D_tensor_argument(idx, _beta, vector_slice);