diff options
Diffstat (limited to 'tests/validation/reference/DepthwiseConvolutionLayer.cpp')
-rw-r--r-- | tests/validation/reference/DepthwiseConvolutionLayer.cpp | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/tests/validation/reference/DepthwiseConvolutionLayer.cpp b/tests/validation/reference/DepthwiseConvolutionLayer.cpp index f27610afb8..122dbd4d98 100644 --- a/tests/validation/reference/DepthwiseConvolutionLayer.cpp +++ b/tests/validation/reference/DepthwiseConvolutionLayer.cpp @@ -50,8 +50,10 @@ namespace reference */ template <typename T, typename TB> SimpleTensor<T> depthwise_convolution(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &biases, const TensorShape &dst_shape, const PadStrideInfo &conv_info, - unsigned int depth_multiplier, const Size2D &dilation) + unsigned int depth_multiplier, const Size2D &dilation, QuantizationInfo out_quant_info) { + ARM_COMPUTE_UNUSED(out_quant_info); + SimpleTensor<T> dst{ dst_shape, src.data_type(), 1 }; // Compute reference @@ -119,9 +121,14 @@ SimpleTensor<T> depthwise_convolution(const SimpleTensor<T> &src, const SimpleTe template <> SimpleTensor<uint8_t> depthwise_convolution(const SimpleTensor<uint8_t> &src, const SimpleTensor<uint8_t> &weights, const SimpleTensor<int32_t> &biases, const TensorShape &dst_shape, - const PadStrideInfo &conv_info, unsigned int depth_multiplier, const Size2D &dilation) + const PadStrideInfo &conv_info, unsigned int depth_multiplier, const Size2D &dilation, QuantizationInfo out_quant_info) { - SimpleTensor<uint8_t> dst{ dst_shape, src.data_type(), 1, src.quantization_info() }; + // if no explicit quantization has been set you the same as src + if(out_quant_info == QuantizationInfo(0.0f, 0)) + { + out_quant_info = src.quantization_info(); + } + SimpleTensor<uint8_t> dst{ dst_shape, src.data_type(), 1, out_quant_info }; // Create reference const int input_offset = -src.quantization_info().offset; @@ -206,10 +213,10 @@ SimpleTensor<uint8_t> depthwise_convolution(const SimpleTensor<uint8_t> &src, co } template SimpleTensor<float> depthwise_convolution(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &biases, const TensorShape &dst_shape, - const PadStrideInfo &conv_info, unsigned int depth_multiplier, const Size2D &dilation); + const PadStrideInfo &conv_info, unsigned int depth_multiplier, const Size2D &dilation, QuantizationInfo out_quant_info); template SimpleTensor<half> depthwise_convolution(const SimpleTensor<half> &src, const SimpleTensor<half> &weights, const SimpleTensor<half> &biases, const TensorShape &dst_shape, - const PadStrideInfo &conv_info, unsigned int depth_multiplier, const Size2D &dilation); + const PadStrideInfo &conv_info, unsigned int depth_multiplier, const Size2D &dilation, QuantizationInfo out_quant_info); } // namespace reference } // namespace validation } // namespace test |