diff options
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/ConvolutionLayer.cpp | 28 | ||||
-rw-r--r-- | tests/validation/reference/Permute.cpp | 4 |
2 files changed, 25 insertions, 7 deletions
diff --git a/tests/validation/reference/ConvolutionLayer.cpp b/tests/validation/reference/ConvolutionLayer.cpp index 617e85c8c2..fe558ba4af 100644 --- a/tests/validation/reference/ConvolutionLayer.cpp +++ b/tests/validation/reference/ConvolutionLayer.cpp @@ -26,6 +26,7 @@ #include "tests/validation/FixedPoint.h" #include "tests/validation/Helpers.h" #include "tests/validation/reference/Convolution3d.h" +#include "tests/validation/reference/Permute.h" #include "tests/validation/reference/Utils.h" #include "tests/validation/reference/UtilsQuantizedAsymm.h" @@ -46,12 +47,9 @@ namespace } // namespace template <typename T, typename TB> -SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const TensorShape &output_shape, const PadStrideInfo &info, - const Size2D &dilation) +SimpleTensor<T> convolution_layer_nchw(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, SimpleTensor<T> &dst, const PadStrideInfo &info, + const Size2D &dilation) { - // Create reference - SimpleTensor<T> dst{ output_shape, src.data_type(), 1, src.fixed_point_position(), src.quantization_info() }; - // Compute reference const int width_in = src.shape().x(); const int height_in = src.shape().y(); @@ -105,6 +103,26 @@ SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor return dst; } +template <typename T, typename TB> +SimpleTensor<T> convolution_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const TensorShape &output_shape, const PadStrideInfo &info, + const Size2D &dilation) +{ + // Create reference + SimpleTensor<T> dst{ output_shape, src.data_type(), 1, src.fixed_point_position(), src.quantization_info() }; + + if(src.data_layout() == DataLayout::NHWC) + { + SimpleTensor<T> src_nchw = reference::permute<T>(src, PermutationVector(1U, 2U, 0U)); + SimpleTensor<T> weights_nchw = reference::permute<T>(weights, PermutationVector(1U, 2U, 0U)); + SimpleTensor<T> dst_nchw = reference::permute<T>(dst, PermutationVector(1U, 2U, 0U)); + + return reference::permute<T>(convolution_layer_nchw(src_nchw, weights_nchw, bias, dst_nchw, info, dilation), PermutationVector(2U, 0U, 1U)); + } + else + { + return convolution_layer_nchw(src, weights, bias, dst, info, dilation); + } +} template SimpleTensor<float> convolution_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &bias, const TensorShape &output_shape, const PadStrideInfo &info, const Size2D &dilation); diff --git a/tests/validation/reference/Permute.cpp b/tests/validation/reference/Permute.cpp index c670c3ea6e..bbb2e8d4d7 100644 --- a/tests/validation/reference/Permute.cpp +++ b/tests/validation/reference/Permute.cpp @@ -57,11 +57,11 @@ SimpleTensor<T> permute(const SimpleTensor<T> &src, PermutationVector perm) return dst; } +template SimpleTensor<int8_t> permute(const SimpleTensor<int8_t> &src, PermutationVector perm); template SimpleTensor<uint8_t> permute(const SimpleTensor<uint8_t> &src, PermutationVector perm); +template SimpleTensor<int16_t> permute(const SimpleTensor<int16_t> &src, PermutationVector perm); template SimpleTensor<uint16_t> permute(const SimpleTensor<uint16_t> &src, PermutationVector perm); template SimpleTensor<uint32_t> permute(const SimpleTensor<uint32_t> &src, PermutationVector perm); -template SimpleTensor<int8_t> permute(const SimpleTensor<int8_t> &src, PermutationVector perm); -template SimpleTensor<int16_t> permute(const SimpleTensor<int16_t> &src, PermutationVector perm); template SimpleTensor<float> permute(const SimpleTensor<float> &src, PermutationVector perm); template SimpleTensor<half> permute(const SimpleTensor<half> &src, PermutationVector perm); } // namespace reference |