aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--tests/validation/reference/Convolution.cpp30
1 files changed, 16 insertions, 14 deletions
diff --git a/tests/validation/reference/Convolution.cpp b/tests/validation/reference/Convolution.cpp
index 308c0b5a87..5d0cf1ea93 100644
--- a/tests/validation/reference/Convolution.cpp
+++ b/tests/validation/reference/Convolution.cpp
@@ -43,22 +43,25 @@ SimpleTensor<T> convolution(const SimpleTensor<uint8_t> &src, DataType output_da
SimpleTensor<T> dst(src.shape(), output_data_type);
- for(int element_idx = 0; element_idx < src.num_elements(); ++element_idx)
+ switch(output_data_type)
{
- const Coordinates id = index2coord(src.shape(), element_idx);
-
- switch(output_data_type)
+ case DataType::S16:
{
- case DataType::S16:
+ SimpleTensor<int16_t> sum(src.shape(), output_data_type);
+ for(int element_idx = 0; element_idx < src.num_elements(); ++element_idx)
{
- SimpleTensor<int16_t> sum(src.shape(), output_data_type);
+ const Coordinates id = index2coord(src.shape(), element_idx);
apply_2d_spatial_filter(id, src, sum, TensorShape(width, height), conv, 1 / static_cast<double>(scale), border_mode, constant_border_value);
dst[element_idx] = tensor_elem_at<int16_t>(sum, id, border_mode, constant_border_value);
}
- break;
- case DataType::U8:
+ }
+ break;
+ case DataType::U8:
+ {
+ SimpleTensor<int32_t> sum(src.shape(), output_data_type);
+ for(int element_idx = 0; element_idx < src.num_elements(); ++element_idx)
{
- SimpleTensor<int32_t> sum(src.shape(), output_data_type);
+ const Coordinates id = index2coord(src.shape(), element_idx);
apply_2d_spatial_filter(id, src, sum, TensorShape(width, height), conv, 1, border_mode, constant_border_value);
if(tensor_elem_at<int32_t>(sum, id, border_mode, constant_border_value) < 0)
{
@@ -73,13 +76,12 @@ SimpleTensor<T> convolution(const SimpleTensor<uint8_t> &src, DataType output_da
dst[element_idx] = tensor_elem_at<int32_t>(sum, id, border_mode, constant_border_value) / scale;
}
}
- break;
- default:
- ARM_COMPUTE_ERROR("Not supported DataType");
- break;
}
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not supported DataType");
+ break;
}
-
return dst;
}