aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/nhwc/winograd_output_transform.cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/nhwc/winograd_output_transform.cl')
-rw-r--r-- src/core/CL/cl_kernels/nhwc/winograd_output_transform.cl 46
1 files changed, 29 insertions, 17 deletions
diff --git a/src/core/CL/cl_kernels/nhwc/winograd_output_transform.cl b/src/core/CL/cl_kernels/nhwc/winograd_output_transform.cl
index 0883cd99c8..9eb995fbb2 100644
--- a/src/core/CL/cl_kernels/nhwc/winograd_output_transform.cl
+++ b/src/core/CL/cl_kernels/nhwc/winograd_output_transform.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2022 Arm Limited.
+ * Copyright (c) 2018-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -75,7 +75,11 @@ __kernel void winograd_output_transform_2x2_7x7_nhwc(
{
const int cout = GET_SPATIAL_IDX(0, N0, 0); // OFM
const int mout = GET_SPATIAL_IDX(1, 1, 0); // WINOGRAD OUTPUT TILES
- const int bout = GET_SPATIAL_IDX(2, 1, 0); // BATCH SIZE IDX
+#if defined(IS_BATCHED)
+ const int bout = GET_SPATIAL_IDX(2, 1, 0); // BATCH SIZE IDX
+#else // defined(IS_BATCHED)
+ const int bout = 0; // BATCH SIZE IDX
+#endif // defined(IS_BATCHED)
int x_out = (mout % NUM_TILES_X) * OUTPUT_TILE_W;
int y_out = (mout / NUM_TILES_X) * OUTPUT_TILE_H;
@@ -103,7 +107,7 @@ __kernel void winograd_output_transform_2x2_7x7_nhwc(
// Compute out0 and out01
out[0].v = in[0].v + in[1].v + in[2].v + in[3].v + in[4].v + in[5].v + in[6].v;
- out[1].v = -in[1].v + in[2].v - 2.f * in[3].v + 2.0f * in[4].v - 3.0f * in[5].v + 3.0f * in[6].v + in[7].v;
+ out[1].v = -in[1].v + in[2].v - (DATA_TYPE)2.f * in[3].v + (DATA_TYPE)2.0f * in[4].v - (DATA_TYPE)3.0f * in[5].v + (DATA_TYPE)3.0f * in[6].v + in[7].v;
#if defined(HAS_BIAS)
// Add bias
@@ -161,14 +165,14 @@ __kernel void winograd_output_transform_2x2_7x7_nhwc(
LOOP_UNROLLING(int, i, 0, 1, 8,
{
tmp[i * 2].v = in[0 + i].v + in[8 + i].v + in[16 + i].v + in[24 + i].v + in[32 + i].v + in[40 + i].v + in[48 + i].v;
- tmp[i * 2 + 1].v = -in[8 + i].v + in[16 + i].v - 2 * in[24 + i].v + 2 * in[32 + i].v + -3 * in[40 + i].v + 3 * in[48 + i].v + in[56 + i].v;
+ tmp[i * 2 + 1].v = -in[8 + i].v + in[16 + i].v - (DATA_TYPE)2 * in[24 + i].v + (DATA_TYPE)2 * in[32 + i].v + (DATA_TYPE) - 3 * in[40 + i].v + (DATA_TYPE)3 * in[48 + i].v + in[56 + i].v;
})
// Compute the 2x2 output tile
LOOP_UNROLLING(int, i, 0, 1, 2,
{
out[i * 2].v = tmp[0 + i].v + tmp[2 + i].v + tmp[4 + i].v + tmp[6 + i].v + tmp[8 + i].v + tmp[10 + i].v + tmp[12 + i].v;
- out[i * 2 + 1].v = -tmp[2 + i].v + tmp[4 + i].v - 2 * tmp[6 + i].v + 2 * tmp[8 + i].v - 3 * tmp[10 + i].v + 3 * tmp[12 + i].v + tmp[14 + i].v;
+ out[i * 2 + 1].v = -tmp[2 + i].v + tmp[4 + i].v - (DATA_TYPE)2 * tmp[6 + i].v + (DATA_TYPE)2 * tmp[8 + i].v - (DATA_TYPE)3 * tmp[10 + i].v + (DATA_TYPE)3 * tmp[12 + i].v + tmp[14 + i].v;
})
#if defined(HAS_BIAS)
@@ -252,7 +256,11 @@ __kernel void winograd_output_transform_4x4_3x3_nhwc(
{
const int cout = GET_SPATIAL_IDX(0, N0, 0); // OFM
const int mout = GET_SPATIAL_IDX(1, 1, 0); // WINOGRAD OUTPUT TILES
- const int bout = GET_SPATIAL_IDX(2, 1, 0); // BATCH SIZE IDX
+#if defined(IS_BATCHED)
+ const int bout = GET_SPATIAL_IDX(2, 1, 0); // BATCH SIZE IDX
+#else // defined(IS_BATCHED)
+ const int bout = 0; // BATCH SIZE IDX
+#endif // defined(IS_BATCHED)
#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
@@ -277,9 +285,9 @@ __kernel void winograd_output_transform_4x4_3x3_nhwc(
// Compute out00, out01, out02 and out03
out[0].v = in[0].v + in[1].v + in[2].v + in[3].v + in[4].v;
- out[1].v = in[1].v - in[2].v + 2.0f * in[3].v - 2.0f * in[4].v;
- out[2].v = in[1].v + in[2].v + 4.0f * in[3].v + 4.0f * in[4].v;
- out[3].v = in[1].v - in[2].v + 8.0f * in[3].v - 8.0f * in[4].v + in[5].v;
+ out[1].v = in[1].v - in[2].v + (DATA_TYPE)2.0f * in[3].v - (DATA_TYPE)2.0f * in[4].v;
+ out[2].v = in[1].v + in[2].v + (DATA_TYPE)4.0f * in[3].v + (DATA_TYPE)4.0f * in[4].v;
+ out[3].v = in[1].v - in[2].v + (DATA_TYPE)8.0f * in[3].v - (DATA_TYPE)8.0f * in[4].v + in[5].v;
#if defined(HAS_BIAS)
TILE(DATA_TYPE, 1, N0, b);
@@ -449,7 +457,11 @@ __kernel void winograd_output_transform_4x4_5x5_nhwc(
{
const int cout = GET_SPATIAL_IDX(0, N0, 0); // OFM
const int mout = GET_SPATIAL_IDX(1, 1, 0); // WINOGRAD OUTPUT TILES
- const int bout = GET_SPATIAL_IDX(2, 1, 0); // BATCH SIZE IDX
+#if defined(IS_BATCHED)
+ const int bout = GET_SPATIAL_IDX(2, 1, 0); // BATCH SIZE IDX
+#else // defined(IS_BATCHED)
+ const int bout = 0; // BATCH SIZE IDX
+#endif // defined(IS_BATCHED)
#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
TILE(DATA_TYPE, 8, N0, in);
@@ -474,13 +486,13 @@ __kernel void winograd_output_transform_4x4_5x5_nhwc(
// A^T * in, and in this degenerate case out consists of 1 column/row
tmp[0].v = in[1].v - in[2].v;
- tmp[1].v = 2.0f * (in[3].v - in[4].v);
- tmp[2].v = 2.0f * (in[5].v + in[6].v);
+ tmp[1].v = (DATA_TYPE)2.0f * (in[3].v - in[4].v);
+ tmp[2].v = (DATA_TYPE)2.0f * (in[5].v + in[6].v);
tmp[3].v = in[3].v + in[4].v;
- out[0].v = in[0].v + in[1].v + in[2].v + tmp[3].v + 4.0f * tmp[2].v;
- out[1].v = tmp[0].v + tmp[1].v + 4.0f * (in[5].v - in[6].v);
- out[2].v = in[1].v + in[2].v + 4.0f * tmp[3].v + tmp[2].v;
- out[3].v = tmp[0].v + 4.0f * tmp[1].v + in[5].v - in[6].v + in[7].v;
+ out[0].v = in[0].v + in[1].v + in[2].v + tmp[3].v + (DATA_TYPE)4.0f * tmp[2].v;
+ out[1].v = tmp[0].v + tmp[1].v + (DATA_TYPE)4.0f * (in[5].v - in[6].v);
+ out[2].v = in[1].v + in[2].v + (DATA_TYPE)4.0f * tmp[3].v + tmp[2].v;
+ out[3].v = tmp[0].v + (DATA_TYPE)4.0f * tmp[1].v + in[5].v - in[6].v + in[7].v;
#if defined(HAS_BIAS)
TILE(DATA_TYPE, 1, N0, b);
@@ -1094,4 +1106,4 @@ __kernel void winograd_output_transform_1x4_1x5_nhwc(
#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_1X4_1X5_NHWC)
#endif // defined(VEC_SIZE) && VEC_SIZE == 4
#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
-#endif // defined(NUM_TILES_X) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
\ No newline at end of file
+#endif // defined(NUM_TILES_X) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)