diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2018-04-23 15:17:31 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:51:37 +0000 |
commit | e250389ed6d78153a55382fa5b3519c151bfd79f (patch) | |
tree | 80c63793769ad18fd0406e7f8b40840aed7ac3ce /tests/validation/reference/ConvolutionLayer.cpp | |
parent | 79ffadebd8dff7eaecbcfa3a28106736f240f1c5 (diff) | |
download | ComputeLibrary-e250389ed6d78153a55382fa5b3519c151bfd79f.tar.gz |
COMPMID-810 Add NHWC data format support for NEON convolution
Change-Id: I2a7b49a12da7f3bc3f04749243b1dc111160de6e
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129348
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'tests/validation/reference/ConvolutionLayer.cpp')
-rw-r--r-- | tests/validation/reference/ConvolutionLayer.cpp | 28 |
1 files changed, 23 insertions, 5 deletions
diff --git a/tests/validation/reference/ConvolutionLayer.cpp b/tests/validation/reference/ConvolutionLayer.cpp index 617e85c8c2..fe558ba4af 100644 --- a/tests/validation/reference/ConvolutionLayer.cpp +++ b/tests/validation/reference/ConvolutionLayer.cpp @@ -26,6 +26,7 @@ #include "tests/validation/FixedPoint.h" #include "tests/validation/Helpers.h" #include "tests/validation/reference/Convolution3d.h" +#include "tests/validation/reference/Permute.h" #include "tests/validation/reference/Utils.h" #include "tests/validation/reference/UtilsQuantizedAsymm.h" @@ -46,12 +47,9 @@ namespace } // namespace template <typename T, typename TB> -SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const TensorShape &output_shape, const PadStrideInfo &info, - const Size2D &dilation) +SimpleTensor<T> convolution_layer_nchw(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, SimpleTensor<T> &dst, const PadStrideInfo &info, + const Size2D &dilation) { - // Create reference - SimpleTensor<T> dst{ output_shape, src.data_type(), 1, src.fixed_point_position(), src.quantization_info() }; - // Compute reference const int width_in = src.shape().x(); const int height_in = src.shape().y(); @@ -105,6 +103,26 @@ SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor return dst; } +template <typename T, typename TB> +SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const TensorShape &output_shape, const PadStrideInfo &info, + const Size2D &dilation) +{ + // Create reference + SimpleTensor<T> dst{ output_shape, src.data_type(), 1, src.fixed_point_position(), src.quantization_info() }; + + if(src.data_layout() == DataLayout::NHWC) + { + SimpleTensor<T> src_nchw = reference::permute<T>(src, PermutationVector(1U, 2U, 0U)); + SimpleTensor<T> weights_nchw = reference::permute<T>(weights, PermutationVector(1U, 2U, 0U)); + SimpleTensor<T> dst_nchw = reference::permute<T>(dst, PermutationVector(1U, 2U, 0U)); + + return reference::permute<T>(convolution_layer_nchw(src_nchw, weights_nchw, bias, dst_nchw, info, dilation), PermutationVector(2U, 0U, 1U)); + } + else + { + return convolution_layer_nchw(src, weights, bias, dst, info, dilation); + } +} template SimpleTensor<float> convolution_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &bias, const TensorShape &output_shape, const PadStrideInfo &info, const Size2D &dilation); |