From 75bde5e21cfbf5e699a3a89655d97fec7c0892e7 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 7 Jun 2019 11:52:01 +0100 Subject: COMPMID-2336: Account for padding in NEIm2ColKernel for NHWC. Change-Id: I494c4acc95cb431b1718ae62c1504522a115ba10 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/1312 Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins Reviewed-by: Giuseppe Rossini Tested-by: Arm Jenkins --- tests/validation/NEON/Im2Col.cpp | 62 +++++++++++++++++++++++++++++++++++ tests/validation/reference/Im2Col.cpp | 2 +- tests/validation/reference/Im2Col.h | 4 +-- 3 files changed, 65 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/validation/NEON/Im2Col.cpp b/tests/validation/NEON/Im2Col.cpp index 0d00c0a167..f4b2cc7835 100644 --- a/tests/validation/NEON/Im2Col.cpp +++ b/tests/validation/NEON/Im2Col.cpp @@ -137,6 +137,68 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture, framework::DatasetMod } TEST_SUITE_END() // QASYMM8 +TEST_SUITE(SpecialCases) +TEST_CASE(PaddedChannelNHWC, framework::DatasetMode::PRECOMMIT) +{ + // Const data + const TensorShape src_shape = TensorShape(7U, 27U, 13U); + const DataType data_type = DataType::F32; + const DataLayout data_layout = DataLayout::NHWC; + const bool has_bias = false; + const unsigned int num_groups = 1; + const Size2D spatial_kernel(3, 3); + const QuantizationInfo qinfo{}; + const PadStrideInfo conv_info(1U, 1U, 0U, 0U); + + // Calculate destination shape + TensorInfo src_info(src_shape, 1, data_type); + src_info.set_data_layout(data_layout); + const TensorShape dst_shape = compute_im2col_conv_shape(&src_info, spatial_kernel, conv_info, has_bias, Size2D(1U, 1U), false, num_groups); + + // Compute target + Tensor src_target = create_tensor(src_shape, data_type, 1, qinfo, data_layout); + Tensor dst_target = create_tensor(dst_shape, data_type, 1, qinfo); + + // Configure target function + NEIm2Col im2col_func; + im2col_func.configure(&src_target, &dst_target, spatial_kernel, conv_info, has_bias); + + // Extend padding + src_target.info()->extend_padding(PaddingSize(3, 5, 9, 1)); + dst_target.info()->extend_padding(PaddingSize(8, 1, 1, 3)); + + // Validate and allocate tensors + ARM_COMPUTE_EXPECT(src_target.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(dst_target.info()->is_resizable(), framework::LogLevel::ERRORS); + + src_target.allocator()->allocate(); + dst_target.allocator()->allocate(); + + ARM_COMPUTE_EXPECT(!src_target.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!dst_target.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Fill target source + library->fill_tensor_uniform(Accessor(src_target), 0); + + // Run target function + im2col_func.run(); + + // Calculate Reference + SimpleTensor src_ref{ src_shape, data_type, 1, qinfo, data_layout }; + SimpleTensor dst_ref{ dst_shape, data_type, 1, qinfo, DataLayout::NCHW }; + + // Fill reference source + library->fill_tensor_uniform(src_ref, 0); + +#ifndef DOXYGEN_SKIP_THIS + // Run reference function + reference::im2col(src_ref, dst_ref, spatial_kernel, conv_info, has_bias, num_groups); +#endif // DOXYGEN_SKIP_THIS + + // Validate + validate(Accessor(dst_target), dst_ref); +} +TEST_SUITE_END() // Special Cases TEST_SUITE_END() // Im2Col TEST_SUITE_END() // NEON } // namespace validation diff --git a/tests/validation/reference/Im2Col.cpp b/tests/validation/reference/Im2Col.cpp index 4d63696e67..4b41cdb70b 100644 --- a/tests/validation/reference/Im2Col.cpp +++ b/tests/validation/reference/Im2Col.cpp @@ -139,7 +139,7 @@ void im2col_nhwc(const SimpleTensor &src, SimpleTensor &dst, const Size2D } template -void im2col(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const unsigned int num_groups) +void im2col(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int num_groups) { switch(src.data_layout()) { diff --git a/tests/validation/reference/Im2Col.h b/tests/validation/reference/Im2Col.h index f519d0e602..34b8476a46 100644 --- a/tests/validation/reference/Im2Col.h +++ b/tests/validation/reference/Im2Col.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -35,7 +35,7 @@ namespace validation namespace reference { template -void im2col(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const unsigned int num_groups); +void im2col(const SimpleTensor &src, SimpleTensor &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int num_groups); } // namespace reference } // namespace validation } // namespace test -- cgit v1.2.1