aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/batchnormalization_layer.cl
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2017-06-26 14:18:47 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-09-17 14:16:42 +0100
commit172e57028ef14f2f8d6c56edc53c5c85f97e07cd (patch)
treeb3fe8c05902f07fb2381cf6dfd893654c8ccb63f /src/core/CL/cl_kernels/batchnormalization_layer.cl
parent579c0498e161215be1a36080b0b454e5198a992a (diff)
downloadComputeLibrary-172e57028ef14f2f8d6c56edc53c5c85f97e07cd.tar.gz
COMPMID-425 Port CLBatchnormalization to support QS8/QS16
Change-Id: I46c93305f377666ea0915ff789b7dfdfff596087 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/78862 Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/batchnormalization_layer.cl')
-rw-r--r--src/core/CL/cl_kernels/batchnormalization_layer.cl69
1 files changed, 48 insertions, 21 deletions
diff --git a/src/core/CL/cl_kernels/batchnormalization_layer.cl b/src/core/CL/cl_kernels/batchnormalization_layer.cl
index 13e6702334..cb4d0c8947 100644
--- a/src/core/CL/cl_kernels/batchnormalization_layer.cl
+++ b/src/core/CL/cl_kernels/batchnormalization_layer.cl
@@ -21,11 +21,31 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
+
#include "helpers.h"
+#if defined(FIXED_POINT_POSITION)
+#include "fixed_point.h"
+
+#define ADD_OP(a, b) ADD_SAT_OP_EXPAND((a), (b), DATA_TYPE, VEC_SIZE)
+#define SUB_OP(a, b) SUB_SAT_OP_EXPAND((a), (b), DATA_TYPE, VEC_SIZE)
+#define MUL_OP(a, b) MUL_SAT_OP_EXPAND((a), (b), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION)
+#define INVSQRT_OP(a) INVSQRT_OP_EXPAND((a), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION)
+#define SQCVT_SAT(a) SQCVT_SAT_OP_EXPAND((a), DATA_TYPE, FIXED_POINT_POSITION)
+
+#else /* FIXED_POINT_POSITION */
+
+#define ADD_OP(a, b) ((a) + (b))
+#define SUB_OP(a, b) ((a) - (b))
+#define MUL_OP(a, b) ((a) * (b))
+#define INVSQRT_OP(a) rsqrt((a))
+#define SQCVT_SAT(a) (a)
+
+#endif /* FIXED_POINT_POSITION */
+
/** Apply batch normalization.
*
- * @param[in] input_ptr Pointer to the first source tensor. Supported data types: F32
+ * @param[in] input_ptr Pointer to the first source tensor. Supported data types: QS8/QS16/F32
* @param[in] input_stride_x Stride of the first source tensor in X dimension (in bytes)
* @param[in] input_step_x input_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] input_stride_y Stride of the first source tensor in Y dimension (in bytes)
@@ -33,7 +53,7 @@
* @param[in] input_stride_z Stride of the first source tensor in Z dimension (in bytes)
* @param[in] input_step_z input_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] input_offset_first_element_in_bytes The offset of the first element in the first source tensor
- * @param[out] output_ptr Pointer to the destination tensor. Supported data types: F32
+ * @param[out] output_ptr Pointer to the destination tensor. Supported data types: same as @p input_ptr
* @param[in] output_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] output_step_x output_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] output_stride_y Stride of the destination tensor in Y dimension (in bytes)
@@ -41,19 +61,19 @@
* @param[in] output_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] output_step_z output_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] output_offset_first_element_in_bytes The offset of the first element in the destination tensor
- * @param[in] mean_ptr Pointer to the mean source tensor. Supported data types: F32
+ * @param[in] mean_ptr Pointer to the mean source tensor. Supported data types: same as @p input_ptr
* @param[in] mean_stride_x Stride of the mean source tensor in X dimension (in bytes)
* @param[in] mean_step_x mean_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] mean_offset_first_element_in_bytes The offset of the first element in the mean source tensor
- * @param[in] var_ptr Pointer to the var tensor. Supported data types: F32
+ * @param[in] var_ptr Pointer to the var tensor. Supported data types: same as @p input_ptr
* @param[in] var_stride_x Stride of the var tensor in X dimension (in bytes)
* @param[in] var_step_x var_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] var_offset_first_element_in_bytes The offset of the first element in the var source tensor
- * @param[in] beta_ptr Pointer to the beta source tensor. Supported data types: F32
+ * @param[in] beta_ptr Pointer to the beta source tensor. Supported data types: same as @p input_ptr
* @param[in] beta_stride_x Stride of the beta source tensor in X dimension (in bytes)
* @param[in] beta_step_x beta_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] beta_offset_first_element_in_bytes The offset of the first element in the beta source tensor
- * @param[in] gamma_ptr Pointer to the gamma source tensor. Supported data types: F32
+ * @param[in] gamma_ptr Pointer to the gamma source tensor. Supported data types: same as @p input_ptr
* @param[in] gamma_stride_x Stride of the gamma source tensor in X dimension (in bytes)
* @param[in] gamma_step_x gamma_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] gamma_offset_first_element_in_bytes The offset of the first element in the gamma source tensor
@@ -74,26 +94,33 @@ __kernel void batchnormalization_layer(TENSOR3D_DECLARATION(input),
Vector beta = CONVERT_TO_VECTOR_STRUCT(beta);
Vector gamma = CONVERT_TO_VECTOR_STRUCT(gamma);
- float4 _in = 0;
- float4 denominator = 0;
- float4 numerator = 0;
- float4 x_bar = 0;
- float4 gamma_vec = 0;
- float4 beta_vec = 0;
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+ _in = 0;
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+ denominator = 0;
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+ numerator = 0;
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+ x_bar = 0;
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+ gamma_vec = 0;
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+ beta_vec = 0;
const int current_slice = get_global_id(2);
- _in = vload4(0, (__global float *)in.ptr);
- denominator = *((__global float *)(var.ptr + current_slice * var.stride_x));
- denominator = rsqrt(denominator + epsilon);
+ _in = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)in.ptr);
+ denominator = *((__global DATA_TYPE *)(var.ptr + current_slice * var.stride_x));
+ denominator = INVSQRT_OP(ADD_OP(denominator, SQCVT_SAT(epsilon)));
// Calculate x bar and store results
- numerator = *((__global float *)(mean.ptr + current_slice * mean.stride_x));
- numerator = _in - numerator;
- x_bar = numerator * denominator;
+ numerator = *((__global DATA_TYPE *)(mean.ptr + current_slice * mean.stride_x));
+ numerator = SUB_OP(_in, numerator);
+ x_bar = MUL_OP(numerator, denominator);
- gamma_vec = *((__global float *)(gamma.ptr + current_slice * beta.stride_x));
- beta_vec = *((__global float *)(beta.ptr + current_slice * beta.stride_x));
+ gamma_vec = *((__global DATA_TYPE *)(gamma.ptr + current_slice * beta.stride_x));
+ beta_vec = *((__global DATA_TYPE *)(beta.ptr + current_slice * beta.stride_x));
- vstore4(gamma_vec * x_bar + beta_vec, 0, (__global float *)out.ptr);
+ VSTORE(VEC_SIZE)
+ (ADD_OP(MUL_OP(gamma_vec, x_bar), beta_vec), 0, (__global DATA_TYPE *)out.ptr);
}