diff options
-rw-r--r-- | tests/validation/reference/Convolution.cpp | 30 |
1 files changed, 16 insertions, 14 deletions
diff --git a/tests/validation/reference/Convolution.cpp b/tests/validation/reference/Convolution.cpp index 308c0b5a87..5d0cf1ea93 100644 --- a/tests/validation/reference/Convolution.cpp +++ b/tests/validation/reference/Convolution.cpp @@ -43,22 +43,25 @@ SimpleTensor<T> convolution(const SimpleTensor<uint8_t> &src, DataType output_da SimpleTensor<T> dst(src.shape(), output_data_type); - for(int element_idx = 0; element_idx < src.num_elements(); ++element_idx) + switch(output_data_type) { - const Coordinates id = index2coord(src.shape(), element_idx); - - switch(output_data_type) + case DataType::S16: { - case DataType::S16: + SimpleTensor<int16_t> sum(src.shape(), output_data_type); + for(int element_idx = 0; element_idx < src.num_elements(); ++element_idx) { - SimpleTensor<int16_t> sum(src.shape(), output_data_type); + const Coordinates id = index2coord(src.shape(), element_idx); apply_2d_spatial_filter(id, src, sum, TensorShape(width, height), conv, 1 / static_cast<double>(scale), border_mode, constant_border_value); dst[element_idx] = tensor_elem_at<int16_t>(sum, id, border_mode, constant_border_value); } - break; - case DataType::U8: + } + break; + case DataType::U8: + { + SimpleTensor<int32_t> sum(src.shape(), output_data_type); + for(int element_idx = 0; element_idx < src.num_elements(); ++element_idx) { - SimpleTensor<int32_t> sum(src.shape(), output_data_type); + const Coordinates id = index2coord(src.shape(), element_idx); apply_2d_spatial_filter(id, src, sum, TensorShape(width, height), conv, 1, border_mode, constant_border_value); if(tensor_elem_at<int32_t>(sum, id, border_mode, constant_border_value) < 0) { @@ -73,13 +76,12 @@ SimpleTensor<T> convolution(const SimpleTensor<uint8_t> &src, DataType output_da dst[element_idx] = tensor_elem_at<int32_t>(sum, id, border_mode, constant_border_value) / scale; } } - break; - default: - ARM_COMPUTE_ERROR("Not supported DataType"); - break; } + break; + default: + ARM_COMPUTE_ERROR("Not supported DataType"); + break; } - return dst; } |