diff options
Diffstat (limited to 'tests/validation/reference/ConvolutionLayer.cpp')
-rw-r--r-- | tests/validation/reference/ConvolutionLayer.cpp | 28 |
1 files changed, 23 insertions, 5 deletions
diff --git a/tests/validation/reference/ConvolutionLayer.cpp b/tests/validation/reference/ConvolutionLayer.cpp index 617e85c8c2..fe558ba4af 100644 --- a/tests/validation/reference/ConvolutionLayer.cpp +++ b/tests/validation/reference/ConvolutionLayer.cpp @@ -26,6 +26,7 @@ #include "tests/validation/FixedPoint.h" #include "tests/validation/Helpers.h" #include "tests/validation/reference/Convolution3d.h" +#include "tests/validation/reference/Permute.h" #include "tests/validation/reference/Utils.h" #include "tests/validation/reference/UtilsQuantizedAsymm.h" @@ -46,12 +47,9 @@ namespace } // namespace template <typename T, typename TB> -SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const TensorShape &output_shape, const PadStrideInfo &info, - const Size2D &dilation) +SimpleTensor<T> convolution_layer_nchw(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, SimpleTensor<T> &dst, const PadStrideInfo &info, + const Size2D &dilation) { - // Create reference - SimpleTensor<T> dst{ output_shape, src.data_type(), 1, src.fixed_point_position(), src.quantization_info() }; - // Compute reference const int width_in = src.shape().x(); const int height_in = src.shape().y(); @@ -105,6 +103,26 @@ SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor return dst; } +template <typename T, typename TB> +SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const TensorShape &output_shape, const PadStrideInfo &info, + const Size2D &dilation) +{ + // Create reference + SimpleTensor<T> dst{ output_shape, src.data_type(), 1, src.fixed_point_position(), src.quantization_info() }; + + if(src.data_layout() == DataLayout::NHWC) + { + SimpleTensor<T> src_nchw = reference::permute<T>(src, PermutationVector(1U, 2U, 0U)); + SimpleTensor<T> weights_nchw = reference::permute<T>(weights, PermutationVector(1U, 2U, 0U)); + SimpleTensor<T> dst_nchw = reference::permute<T>(dst, PermutationVector(1U, 2U, 0U)); + + return reference::permute<T>(convolution_layer_nchw(src_nchw, weights_nchw, bias, dst_nchw, info, dilation), PermutationVector(2U, 0U, 1U)); + } + else + { + return convolution_layer_nchw(src, weights, bias, dst, info, dilation); + } +} template SimpleTensor<float> convolution_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &bias, const TensorShape &output_shape, const PadStrideInfo &info, const Size2D &dilation); |