aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/SpaceToBatch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/SpaceToBatch.cpp')
-rw-r--r--tests/validation/reference/SpaceToBatch.cpp44
1 files changed, 29 insertions, 15 deletions
diff --git a/tests/validation/reference/SpaceToBatch.cpp b/tests/validation/reference/SpaceToBatch.cpp
index 979ab94b33..c635d4abfd 100644
--- a/tests/validation/reference/SpaceToBatch.cpp
+++ b/tests/validation/reference/SpaceToBatch.cpp
@@ -39,38 +39,52 @@ SimpleTensor<T> space_to_batch(const SimpleTensor<T> &src, const SimpleTensor<in
{
SimpleTensor<T> result(dst_shape, src.data_type());
- auto width_out = static_cast<int>(dst_shape[0]);
- auto height_out = static_cast<int>(dst_shape[1]);
- auto z_out = static_cast<int>(dst_shape[2]);
+ const auto width_out = static_cast<int>(dst_shape[0]);
+ const auto height_out = static_cast<int>(dst_shape[1]);
+ const auto batch_out = static_cast<int>(dst_shape[3]);
+
+ const auto width_in = static_cast<int>(src.shape()[0]);
+ const auto height_in = static_cast<int>(src.shape()[1]);
+ const auto batch_in = static_cast<int>(src.shape()[3]);
+
+ const auto channel = static_cast<int>(src.shape()[2]);
+
+ const auto block_width = block_shape[0];
+ const auto block_height = block_shape[1];
+
+ const auto padding_left = paddings[0];
+ const auto padding_top = paddings[2];
int out_pos = 0;
- for(int batch = 0; batch < static_cast<int>(dst_shape[3]); ++batch)
+ for(int outB = 0; outB < batch_out; ++outB)
{
- for(int z = 0; z < z_out; ++z)
+ unsigned int inB = outB % batch_in;
+
+ int shift_w = (outB / batch_in) % block_width;
+ int shift_h = (outB / batch_in) / block_width;
+
+ for(int c = 0; c < channel; ++c)
{
- for(int y = 0; y < height_out; ++y)
+ for(int outH = 0; outH < height_out; ++outH)
{
- for(int x = 0; x < width_out; ++x)
+ for(int outW = 0; outW < width_out; ++outW)
{
- if(x < paddings[0] || x > width_out - paddings[1] - 1
- || y < paddings[2] || y > height_out - paddings[3] - 1)
+ const auto in_pos = ((inB * channel + c) * height_in + ((outH * block_height + shift_h) - padding_top)) * width_in + (outW * block_width + shift_w) - padding_left;
+
+ if(outH * block_height + shift_h < padding_top || outH * block_height + shift_h >= padding_top + height_in || outW * block_width + shift_w < padding_left
+ || outW * block_width + shift_w >= padding_left + width_in)
{
result[out_pos] = 0;
}
else
{
- const int r = dst_shape[3] / (block_shape[0] * block_shape[1]);
- const int in_x = (block_shape[0] * (x - paddings[0]) + (batch / r) % block_shape[0]);
- const int in_y = (block_shape[1] * (y - paddings[2]) + (batch / r) / block_shape[0]);
- int in_pos = in_x + src.shape()[0] * in_y + z * src.shape()[0] * src.shape()[1] + (batch % r) * src.shape()[0] * src.shape()[1] * src.shape()[2];
- result[out_pos] = src[in_pos];
+ result[out_pos] = src[in_pos];
}
++out_pos;
}
}
}
}
-
return result;
}