aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/Convolution.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/Convolution.cpp')
-rw-r--r--tests/validation/reference/Convolution.cpp55
1 files changed, 39 insertions, 16 deletions
diff --git a/tests/validation/reference/Convolution.cpp b/tests/validation/reference/Convolution.cpp
index 777e2df400..308c0b5a87 100644
--- a/tests/validation/reference/Convolution.cpp
+++ b/tests/validation/reference/Convolution.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017, 2018 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -35,34 +35,57 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> convolution(const SimpleTensor<T> &src, const int16_t *conv, uint32_t scale, BorderMode border_mode, T constant_border_value, const unsigned int width, const unsigned int height)
+SimpleTensor<T> convolution(const SimpleTensor<uint8_t> &src, DataType output_data_type, const int16_t *conv, uint32_t scale, BorderMode border_mode, uint8_t constant_border_value,
+ const unsigned int width,
+ const unsigned int height)
{
- SimpleTensor<T> dst(src.shape(), src.data_type());
- SimpleTensor<int32_t> sum(src.shape(), src.data_type());
+ ARM_COMPUTE_ERROR_ON(0 == scale);
+
+ SimpleTensor<T> dst(src.shape(), output_data_type);
for(int element_idx = 0; element_idx < src.num_elements(); ++element_idx)
{
const Coordinates id = index2coord(src.shape(), element_idx);
- apply_2d_spatial_filter(id, src, sum, TensorShape(width, height), conv, 1, border_mode, constant_border_value);
- if(tensor_elem_at<int32_t>(sum, id, border_mode, constant_border_value) < 0)
- {
- dst[element_idx] = 0;
- }
- else if((tensor_elem_at<int32_t>(sum, id, border_mode, constant_border_value) / scale) > 255)
+ switch(output_data_type)
{
- dst[element_idx] = 255;
- }
- else
- {
- dst[element_idx] = tensor_elem_at<int32_t>(sum, id, border_mode, constant_border_value) / scale;
+ case DataType::S16:
+ {
+ SimpleTensor<int16_t> sum(src.shape(), output_data_type);
+ apply_2d_spatial_filter(id, src, sum, TensorShape(width, height), conv, 1 / static_cast<double>(scale), border_mode, constant_border_value);
+ dst[element_idx] = tensor_elem_at<int16_t>(sum, id, border_mode, constant_border_value);
+ }
+ break;
+ case DataType::U8:
+ {
+ SimpleTensor<int32_t> sum(src.shape(), output_data_type);
+ apply_2d_spatial_filter(id, src, sum, TensorShape(width, height), conv, 1, border_mode, constant_border_value);
+ if(tensor_elem_at<int32_t>(sum, id, border_mode, constant_border_value) < 0)
+ {
+ dst[element_idx] = 0;
+ }
+ else if((tensor_elem_at<int32_t>(sum, id, border_mode, constant_border_value) / scale) > 255)
+ {
+ dst[element_idx] = 255;
+ }
+ else
+ {
+ dst[element_idx] = tensor_elem_at<int32_t>(sum, id, border_mode, constant_border_value) / scale;
+ }
+ }
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not supported DataType");
+ break;
}
}
return dst;
}
-template SimpleTensor<uint8_t> convolution(const SimpleTensor<uint8_t> &src, const int16_t *conv, uint32_t scale, BorderMode border_mode, uint8_t constant_border_value,
+template SimpleTensor<uint8_t> convolution(const SimpleTensor<uint8_t> &src, DataType output_data_type, const int16_t *conv, uint32_t scale, BorderMode border_mode, uint8_t constant_border_value,
+ const unsigned int widht, const unsigned int height);
+template SimpleTensor<int16_t> convolution(const SimpleTensor<uint8_t> &src, DataType output_data_type, const int16_t *conv, uint32_t scale, BorderMode border_mode, uint8_t constant_border_value,
const unsigned int widht, const unsigned int height);
} // namespace reference
} // namespace validation