about | summary | refs | log | tree | commit | diff
path: root/src/core/CL/cl_kernels/batchnormalization_layer.cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/batchnormalization_layer.cl')
-rw-r--r-- src/core/CL/cl_kernels/batchnormalization_layer.cl | 29
1 files changed, 15 insertions, 14 deletions
diff --git a/src/core/CL/cl_kernels/batchnormalization_layer.cl b/src/core/CL/cl_kernels/batchnormalization_layer.cl
index dfd16e0da3..60307bc9a7 100644
--- a/src/core/CL/cl_kernels/batchnormalization_layer.cl
+++ b/src/core/CL/cl_kernels/batchnormalization_layer.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -341,22 +341,10 @@ __kernel void fuse_batchnormalization_layer(TENSOR4D_DECLARATION(conv_w),
Vector bn_mean = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bn_mean);
Vector bn_var = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bn_var);
- // In-place ops
-#ifdef IN_PLACE_W
- Tensor4D fused_w = conv_w;
-#else /* IN_PLACE_W */
- Tensor4D fused_w = CONVERT_TO_TENSOR4D_STRUCT(fused_w, NUM_CHANNELS);
-#endif /* IN_PLACE */
-#ifdef IN_PLACE_B
- Vector fused_b = conv_b;
-#else /* IN_PLACE_W */
- Vector fused_b = CONVERT_TO_VECTOR_STRUCT_NO_STEP(fused_b);
-#endif /* IN_PLACE */
-
// Conditional ops
#ifdef HAS_BIAS
Vector conv_b = CONVERT_TO_VECTOR_STRUCT_NO_STEP(conv_b);
-#endif /* USE_DEFAULT_BETA */
+#endif /* HAS_BIAS */
#ifndef USE_DEFAULT_BETA
Vector bn_beta = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bn_beta);
#endif /* USE_DEFAULT_BETA */
@@ -364,6 +352,19 @@ __kernel void fuse_batchnormalization_layer(TENSOR4D_DECLARATION(conv_w),
Vector bn_gamma = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bn_gamma);
#endif /* USE_DEFAULT_GAMMA */
+ // In-place ops
+#ifdef IN_PLACE_W
+ Tensor4D fused_w = conv_w;
+ uint fused_w_stride_x = conv_w_stride_x;
+#else /* IN_PLACE_W */
+ Tensor4D fused_w = CONVERT_TO_TENSOR4D_STRUCT(fused_w, NUM_CHANNELS);
+#endif /* IN_PLACE_W */
+#ifdef IN_PLACE_B
+ Vector fused_b = conv_b;
+#else /* IN_PLACE_B */
+ Vector fused_b = CONVERT_TO_VECTOR_STRUCT_NO_STEP(fused_b);
+#endif /* IN_PLACE_B */
+
const int current_slice = get_global_id(2) / NUM_CHANNELS;
#if defined(VEC_SIZE) && defined(LAST_ACCESSED_X)