diff options
Diffstat (limited to 'src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp')
-rw-r--r-- | src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp | 19 |
1 file changed, 13 insertions, 6 deletions
diff --git a/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp index 483e5c110b..45210d7976 100644 --- a/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp +++ b/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp @@ -82,6 +82,7 @@ void Transform::process_tile( const int n_channels, const float* const matrix_base, const int matrix_stride, + const float* const biases, float* const output, const int output_row_stride, const int output_col_stride @@ -100,6 +101,7 @@ void Transform::process_tile( } } const float *inptr = matrix_base; + const float *bptr = biases; // For each channel of the output int channels_remaining = n_channels; @@ -107,7 +109,7 @@ void Transform::process_tile( for (; channels_remaining >= 4; channels_remaining -= 4) { // Matrices used and computed during this transform - float32x4_t F[6][6], FZ[6][4], f[4][4]; + float32x4_t F[6][6], FZ[6][4], f[4][4], b; // Read a 6x6 tile in the Winograd domain for (int i = 0, m = 0; i < 6; i++) @@ -152,11 +154,13 @@ void Transform::process_tile( } // Write out the output tile + b = vld1q_f32(bptr); + bptr += 4; for (int i = 0; i < cells_i; i++) { for (int j = 0; j < cells_j; j++) { - vst1q_f32(outptrs[i][j], f[i][j]); + vst1q_f32(outptrs[i][j], vaddq_f32(f[i][j], b)); outptrs[i][j] += 4; } } @@ -166,7 +170,7 @@ void Transform::process_tile( for (; channels_remaining >= 2; channels_remaining -= 2) { // Matrices used and computed during this transform - float32x2_t F[6][6], FZ[6][4], f[4][4]; + float32x2_t F[6][6], FZ[6][4], f[4][4], b; // Read a 6x6 tile in the Winograd domain for (int i = 0, m = 0; i < 6; i++) @@ -211,11 +215,13 @@ void Transform::process_tile( } // Write out the output tile + b = vld1_f32(bptr); + bptr += 2; for (int i = 0; i < cells_i; i++) { for (int j = 0; j < cells_j; j++) { - vst1_f32(outptrs[i][j], f[i][j]); + vst1_f32(outptrs[i][j], vadd_f32(f[i][j], b)); outptrs[i][j] += 2; } 
} @@ -224,7 +230,7 @@ void Transform::process_tile( for (; channels_remaining; channels_remaining--) { // Matrices used and computed during this transform - float F[6][6], FZ[6][4], f[4][4]; + float F[6][6], FZ[6][4], f[4][4], b; // Read a 6x6 tile in the Winograd domain for (int i = 0, m = 0; i < 6; i++) @@ -255,11 +261,12 @@ void Transform::process_tile( } // Write out the output tile + b = *(bptr++); for (int i = 0; i < cells_i; i++) { for (int j = 0; j < cells_j; j++) { - *(outptrs[i][j]++) = f[i][j]; + *(outptrs[i][j]++) = f[i][j] + b; } } } |