aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp')
-rw-r--r--src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp19
1 files changed, 13 insertions, 6 deletions
diff --git a/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp
index 483e5c110b..45210d7976 100644
--- a/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp
+++ b/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp
@@ -82,6 +82,7 @@ void Transform::process_tile(
const int n_channels,
const float* const matrix_base,
const int matrix_stride,
+ const float* const biases,
float* const output,
const int output_row_stride,
const int output_col_stride
@@ -100,6 +101,7 @@ void Transform::process_tile(
}
}
const float *inptr = matrix_base;
+ const float *bptr = biases;
// For each channel of the output
int channels_remaining = n_channels;
@@ -107,7 +109,7 @@ void Transform::process_tile(
for (; channels_remaining >= 4; channels_remaining -= 4)
{
// Matrices used and computed during this transform
- float32x4_t F[6][6], FZ[6][4], f[4][4];
+ float32x4_t F[6][6], FZ[6][4], f[4][4], b;
// Read a 6x6 tile in the Winograd domain
for (int i = 0, m = 0; i < 6; i++)
@@ -152,11 +154,13 @@ void Transform::process_tile(
}
// Write out the output tile
+ b = vld1q_f32(bptr);
+ bptr += 4;
for (int i = 0; i < cells_i; i++)
{
for (int j = 0; j < cells_j; j++)
{
- vst1q_f32(outptrs[i][j], f[i][j]);
+ vst1q_f32(outptrs[i][j], vaddq_f32(f[i][j], b));
outptrs[i][j] += 4;
}
}
@@ -166,7 +170,7 @@ void Transform::process_tile(
for (; channels_remaining >= 2; channels_remaining -= 2)
{
// Matrices used and computed during this transform
- float32x2_t F[6][6], FZ[6][4], f[4][4];
+ float32x2_t F[6][6], FZ[6][4], f[4][4], b;
// Read a 6x6 tile in the Winograd domain
for (int i = 0, m = 0; i < 6; i++)
@@ -211,11 +215,13 @@ void Transform::process_tile(
}
// Write out the output tile
+ b = vld1_f32(bptr);
+ bptr += 2;
for (int i = 0; i < cells_i; i++)
{
for (int j = 0; j < cells_j; j++)
{
- vst1_f32(outptrs[i][j], f[i][j]);
+ vst1_f32(outptrs[i][j], vadd_f32(f[i][j], b));
outptrs[i][j] += 2;
}
}
@@ -224,7 +230,7 @@ void Transform::process_tile(
for (; channels_remaining; channels_remaining--)
{
// Matrices used and computed during this transform
- float F[6][6], FZ[6][4], f[4][4];
+ float F[6][6], FZ[6][4], f[4][4], b;
// Read a 6x6 tile in the Winograd domain
for (int i = 0, m = 0; i < 6; i++)
@@ -255,11 +261,12 @@ void Transform::process_tile(
}
// Write out the output tile
+ b = *(bptr++);
for (int i = 0; i < cells_i; i++)
{
for (int j = 0; j < cells_j; j++)
{
- *(outptrs[i][j]++) = f[i][j];
+ *(outptrs[i][j]++) = f[i][j] + b;
}
}
}