diff options
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/reference/ConvolutionLayer.cpp | 17 | ||||
-rw-r--r-- | tests/validation/reference/ConvolutionLayer.h | 4 |
2 files changed, 13 insertions, 8 deletions
diff --git a/tests/validation/reference/ConvolutionLayer.cpp b/tests/validation/reference/ConvolutionLayer.cpp index f41a6fc8c4..69090117fe 100644 --- a/tests/validation/reference/ConvolutionLayer.cpp +++ b/tests/validation/reference/ConvolutionLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -111,10 +111,15 @@ SimpleTensor<T> convolution_layer_nchw(const SimpleTensor<T> &src, const SimpleT } template <typename T, typename TB> SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const TensorShape &output_shape, const PadStrideInfo &info, - const Size2D &dilation, unsigned int num_groups) + const Size2D &dilation, unsigned int num_groups, QuantizationInfo out_quant_info) { + // if no explicit quantization has been set you the same as src + if(out_quant_info == QuantizationInfo()) + { + out_quant_info = src.quantization_info(); + } // Create reference - SimpleTensor<T> dst{ output_shape, src.data_type(), 1, src.quantization_info() }; + SimpleTensor<T> dst{ output_shape, src.data_type(), 1, out_quant_info }; if(src.data_layout() == DataLayout::NHWC) { @@ -131,11 +136,11 @@ SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor } template SimpleTensor<float> convolution_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &bias, const TensorShape &output_shape, - const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups); + const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups, QuantizationInfo out_quant_info); template SimpleTensor<half> convolution_layer(const SimpleTensor<half> &src, const SimpleTensor<half> &weights, const SimpleTensor<half> &bias, const TensorShape &output_shape, - const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups); + const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups, QuantizationInfo out_quant_info); template SimpleTensor<uint8_t> convolution_layer(const SimpleTensor<uint8_t> &src, const SimpleTensor<uint8_t> &weights, const SimpleTensor<int32_t> &bias, const TensorShape &output_shape, - const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups); + const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups, QuantizationInfo out_quant_info); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/ConvolutionLayer.h b/tests/validation/reference/ConvolutionLayer.h index ccce53a209..c51a9b3ad7 100644 --- a/tests/validation/reference/ConvolutionLayer.h +++ b/tests/validation/reference/ConvolutionLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -37,7 +37,7 @@ namespace reference { template <typename T, typename TB> SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const TensorShape &output_shape, const PadStrideInfo &info, - const Size2D &dilation = Size2D(1U, 1U), unsigned int num_groups = 1); + const Size2D &dilation = Size2D(1U, 1U), unsigned int num_groups = 1, QuantizationInfo out_quant_info = QuantizationInfo()); } // namespace reference } // namespace validation } // namespace test |