diff options
Diffstat (limited to 'tests/validation/reference/ConvolutionLayer.cpp')
-rw-r--r-- | tests/validation/reference/ConvolutionLayer.cpp | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/tests/validation/reference/ConvolutionLayer.cpp b/tests/validation/reference/ConvolutionLayer.cpp index 69090117fe..4d2c1acb6f 100644 --- a/tests/validation/reference/ConvolutionLayer.cpp +++ b/tests/validation/reference/ConvolutionLayer.cpp @@ -45,8 +45,8 @@ namespace { } // namespace -template <typename T, typename TB> -SimpleTensor<T> convolution_layer_nchw(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, SimpleTensor<T> &dst, const PadStrideInfo &info, +template <typename T, typename TW, typename TB> +SimpleTensor<T> convolution_layer_nchw(const SimpleTensor<T> &src, const SimpleTensor<TW> &weights, const SimpleTensor<TB> &bias, SimpleTensor<T> &dst, const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups) { ARM_COMPUTE_ERROR_ON((src.shape()[2] / num_groups) != weights.shape()[2]); @@ -73,7 +73,6 @@ SimpleTensor<T> convolution_layer_nchw(const SimpleTensor<T> &src, const SimpleT const int end_xi = output_wh.first * stride_xi; const int end_yi = output_wh.second * stride_yi; const int num_batches = src.shape().total_size() / (width_in * height_in * depth_in); - for(int r = 0; r < num_batches; ++r) { for(int yi = start_yi; yi < start_yi + end_yi; yi += stride_yi) @@ -100,17 +99,16 @@ SimpleTensor<T> convolution_layer_nchw(const SimpleTensor<T> &src, const SimpleT offset_in, offset_w, offset_b, offset_out, xi, yi, width_in, height_in, (depth_in / num_groups), - width_weights, height_weights, dilation.x(), dilation.y()); + width_weights, height_weights, dilation.x(), dilation.y(), ofm); } } } } } - return dst; } -template <typename T, typename TB> -SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const TensorShape &output_shape, const PadStrideInfo &info, +template <typename T, typename TW, typename TB> +SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor<TW> &weights, const SimpleTensor<TB> &bias, const TensorShape &output_shape, const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups, QuantizationInfo out_quant_info) { // if no explicit quantization has been set you the same as src @@ -123,9 +121,9 @@ SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor if(src.data_layout() == DataLayout::NHWC) { - SimpleTensor<T> src_nchw = reference::permute<T>(src, PermutationVector(1U, 2U, 0U)); - SimpleTensor<T> weights_nchw = reference::permute<T>(weights, PermutationVector(1U, 2U, 0U)); - SimpleTensor<T> dst_nchw = reference::permute<T>(dst, PermutationVector(1U, 2U, 0U)); + SimpleTensor<T> src_nchw = reference::permute<T>(src, PermutationVector(1U, 2U, 0U)); + SimpleTensor<TW> weights_nchw = reference::permute<TW>(weights, PermutationVector(1U, 2U, 0U)); + SimpleTensor<T> dst_nchw = reference::permute<T>(dst, PermutationVector(1U, 2U, 0U)); return reference::permute<T>(convolution_layer_nchw(src_nchw, weights_nchw, bias, dst_nchw, info, dilation, num_groups), PermutationVector(2U, 0U, 1U)); } @@ -141,6 +139,8 @@ template SimpleTensor<half> convolution_layer(const SimpleTensor<half> &src, con const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups, QuantizationInfo out_quant_info); template SimpleTensor<uint8_t> convolution_layer(const SimpleTensor<uint8_t> &src, const SimpleTensor<uint8_t> &weights, const SimpleTensor<int32_t> &bias, const TensorShape &output_shape, const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups, QuantizationInfo out_quant_info); +template SimpleTensor<uint8_t> convolution_layer(const SimpleTensor<uint8_t> &src, const SimpleTensor<int8_t> &weights, const SimpleTensor<int32_t> &bias, const TensorShape &output_shape, + const PadStrideInfo &info, const Size2D &dilation, unsigned int num_groups, QuantizationInfo out_quant_info); } // namespace reference } // namespace validation } // namespace test |