From 236bfe7033a313ab98ff436d85f38a58b0738ed1 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Thu, 23 Nov 2017 15:59:55 +0000 Subject: COMPMID-553: MobileNet use case. Change-Id: I1181abbd5785065f3d57e91844376a4b110938a9 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/110701 Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com Reviewed-by: Anthony Barbier --- src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp') diff --git a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp index 38e367dfb7..e8882b9daf 100644 --- a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp +++ b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp @@ -130,7 +130,8 @@ void CLBatchNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *out _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("batchnormalization_layer", build_opts)); // Set kernel static arguments - unsigned int idx = 2 * num_arguments_per_3D_tensor() + 4 * num_arguments_per_1D_tensor(); // Skip the input and output parameters + unsigned int include_output = (output != nullptr) ? 1 : 0; + unsigned int idx = (1 + include_output) * num_arguments_per_3D_tensor() + 4 * num_arguments_per_1D_tensor(); // Skip the input and output parameters _kernel.setArg(idx++, _epsilon); // Configure kernel window @@ -160,7 +161,8 @@ void CLBatchNormalizationLayerKernel::run(const Window &window, cl::CommandQueue Window vector_slice = window.first_slice_window_1D(); vector_slice.set(Window::DimX, Window::Dimension(0, 0, 0)); - unsigned int idx = 2 * num_arguments_per_3D_tensor(); + unsigned int include_output = (_output != nullptr) ? 
1 : 0; + unsigned int idx = (1 + include_output) * num_arguments_per_3D_tensor(); add_1D_tensor_argument(idx, _mean, vector_slice); add_1D_tensor_argument(idx, _var, vector_slice); add_1D_tensor_argument(idx, _beta, vector_slice); -- cgit v1.2.1