diff options
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/NEON/Im2Col.cpp | 49 | ||||
-rw-r--r-- | tests/validation/fixtures/Im2ColFixture.h | 91 |
2 files changed, 118 insertions, 22 deletions
diff --git a/tests/validation/NEON/Im2Col.cpp b/tests/validation/NEON/Im2Col.cpp index 156957a601..f338675346 100644 --- a/tests/validation/NEON/Im2Col.cpp +++ b/tests/validation/NEON/Im2Col.cpp @@ -22,7 +22,7 @@ * SOFTWARE. */ #include "arm_compute/core/Types.h" -#include "src/core/NEON/kernels/NEIm2ColKernel.h" +#include "src/core/cpu/kernels/CpuIm2ColKernel.h" #include "tests/NEON/Accessor.h" #include "tests/NEON/Helper.h" #include "tests/datasets/ShapeDatasets.h" @@ -57,7 +57,7 @@ const auto conv_args_small = combine(combine(combine(combine(conv_filter TEST_SUITE(NEON) TEST_SUITE(Im2Col) -using NEIm2Col = NESynthetizeFunction<NEIm2ColKernel>; +using CpuIm2Col = NESynthetizeFunctionWithZeroConstantKernelBorder<cpu::kernels::CpuIm2ColKernel>; // *INDENT-OFF* // clang-format off @@ -78,26 +78,26 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( framework::dataset::make("Expected", { false, false, false, false, true })), input_info, output_info, has_bias, expected) { - bool status = bool(NEIm2Col::validate(&input_info, &output_info, Size2D(3U, 3U), PadStrideInfo(), has_bias)); + bool status = bool(cpu::kernels::CpuIm2ColKernel::validate(&input_info, &output_info, Size2D(3U, 3U), PadStrideInfo(), has_bias)); ARM_COMPUTE_EXPECT(status == expected, framework::LogLevel::ERRORS); } // clang-format on // *INDENT-ON* template <typename T> -using NEIm2ColFixture = Im2ColValidationFixture<Tensor, Accessor, NEIm2Col, T, false>; +using CpuIm2ColFixture = Im2ColOpValidationFixture<Tensor, Accessor, CpuIm2Col, T, false>; TEST_SUITE(Float) TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(im2col_shapes, framework::dataset::make("DataType", DataType::F32)), - conv_args_small)) +FIXTURE_DATA_TEST_CASE(RunSmall, CpuIm2ColFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(im2col_shapes, framework::dataset::make("DataType", DataType::F32)), + conv_args_small)) { // Validate output validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(concat(im2col_shapes, datasets::LargeShapes()), framework::dataset::make("DataType", - DataType::F32)), - conv_args)) +FIXTURE_DATA_TEST_CASE(RunLarge, CpuIm2ColFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(concat(im2col_shapes, datasets::LargeShapes()), framework::dataset::make("DataType", + DataType::F32)), + conv_args)) { // Validate output validate(Accessor(_target), _reference); @@ -107,15 +107,15 @@ TEST_SUITE_END() // FP32 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(im2col_shapes, framework::dataset::make("DataType", DataType::F16)), - conv_args_small)) +FIXTURE_DATA_TEST_CASE(RunSmall, CpuIm2ColFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(im2col_shapes, framework::dataset::make("DataType", DataType::F16)), + conv_args_small)) { // Validate output validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(concat(im2col_shapes, datasets::LargeShapes()), framework::dataset::make("DataType", - DataType::F16)), - conv_args)) +FIXTURE_DATA_TEST_CASE(RunLarge, CpuIm2ColFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(concat(im2col_shapes, datasets::LargeShapes()), framework::dataset::make("DataType", + DataType::F16)), + conv_args)) { // Validate output validate(Accessor(_target), _reference); @@ -127,15 +127,15 @@ TEST_SUITE_END() // FP16 TEST_SUITE_END() // Float TEST_SUITE(QASYMM8) -FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(im2col_shapes, framework::dataset::make("DataType", DataType::QASYMM8)), - conv_args_small)) +FIXTURE_DATA_TEST_CASE(RunSmall, CpuIm2ColFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(im2col_shapes, framework::dataset::make("DataType", DataType::QASYMM8)), + conv_args_small)) { // Validate output validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(concat(im2col_shapes, datasets::LargeShapes()), - framework::dataset::make("DataType", DataType::QASYMM8)), - conv_args)) +FIXTURE_DATA_TEST_CASE(RunLarge, CpuIm2ColFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(concat(im2col_shapes, datasets::LargeShapes()), + framework::dataset::make("DataType", DataType::QASYMM8)), + conv_args)) { // Validate output validate(Accessor(_target), _reference); @@ -165,8 +165,8 @@ TEST_CASE(PaddedChannelNHWC, framework::DatasetMode::PRECOMMIT) Tensor dst_target = create_tensor<Tensor>(dst_shape, data_type, 1, qinfo); // Configure target function - NEIm2Col im2col_func; - im2col_func.configure(&src_target, &dst_target, spatial_kernel, conv_info, has_bias); + CpuIm2Col im2col_func; + im2col_func.configure(src_target.info(), dst_target.info(), spatial_kernel, conv_info, has_bias); // Extend padding src_target.info()->extend_padding(PaddingSize(3, 5, 9, 1)); @@ -185,8 +185,13 @@ TEST_CASE(PaddedChannelNHWC, framework::DatasetMode::PRECOMMIT) // Fill target source library->fill_tensor_uniform(Accessor(src_target), 0); + ITensorPack pack = + { + { TensorType::ACL_SRC, &src_target }, + { TensorType::ACL_DST, &dst_target } + }; // Run target function - im2col_func.run(); + im2col_func.run(pack); // Calculate Reference SimpleTensor<float> src_ref{ src_shape, data_type, 1, qinfo, data_layout }; diff --git a/tests/validation/fixtures/Im2ColFixture.h b/tests/validation/fixtures/Im2ColFixture.h index b1fbd76eb2..38970116f6 100644 --- a/tests/validation/fixtures/Im2ColFixture.h +++ b/tests/validation/fixtures/Im2ColFixture.h @@ -45,6 +45,97 @@ namespace validation using namespace arm_compute::misc::shape_calculator; template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool batch_size_on_z> +class Im2ColOpValidationFixture : public framework::Fixture +{ +public: + template <typename...> + void setup(TensorShape input_shape, DataType data_type, const Size2D &kernel_dims, const PadStrideInfo &conv_info, const QuantizationInfo &quant_info, const DataLayout &data_layout, + unsigned int num_groups) + { + _kernel_dims = kernel_dims; + _conv_info = conv_info; + _quant_info = quant_info; + _data_layout = data_layout; + _has_bias = data_type != DataType::QASYMM8; + _num_groups = num_groups; + + if(_data_layout == DataLayout::NHWC) + { + permute(input_shape, PermutationVector(2U, 0U, 1U)); + } + + TensorInfo input_info(input_shape, 1, data_type); + input_info.set_data_layout(_data_layout); + + const TensorShape output_shape = compute_im2col_conv_shape(&input_info, _kernel_dims, _conv_info, _has_bias, Size2D(1U, 1U), batch_size_on_z && _num_groups == 1, _num_groups); + _target = compute_target(input_shape, output_shape, data_type); + + compute_reference(input_shape, output_shape, data_type); + } + +protected: + template <typename U> + void fill(U &&tensor) + { + library->fill_tensor_uniform(tensor, 0); + } + + TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type) + { + // Create tensors + TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, _quant_info, _data_layout); + TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, _quant_info); + + // Create and configure function + FunctionType im2col_func; + im2col_func.configure(src.info(), dst.info(), _kernel_dims, _conv_info, _has_bias, Size2D(1U, 1U), _num_groups); + + ARM_COMPUTE_ASSERT(src.info()->is_resizable()); + ARM_COMPUTE_ASSERT(dst.info()->is_resizable()); + + // Allocate tensors + src.allocator()->allocate(); + dst.allocator()->allocate(); + + ARM_COMPUTE_ASSERT(!src.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); + + // Fill tensors + fill(AccessorType(src)); + + arm_compute::ITensorPack pack = + { + { arm_compute::TensorType::ACL_SRC, &src }, + { arm_compute::TensorType::ACL_DST, &dst } + }; + // Compute function + im2col_func.run(pack); + + return dst; + } + + void compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type) + { + // Create reference + SimpleTensor<T> src{ input_shape, data_type, 1, _quant_info, _data_layout }; + _reference = SimpleTensor<T>(output_shape, data_type, 1, _quant_info, DataLayout::NCHW); + + // Fill reference + fill(src); + + reference::im2col<T>(src, _reference, _kernel_dims, _conv_info, _has_bias, _num_groups); + } + TensorType _target{}; + SimpleTensor<T> _reference{}; + Size2D _kernel_dims{}; + PadStrideInfo _conv_info{}; + DataLayout _data_layout{}; + QuantizationInfo _quant_info{}; + bool _has_bias{}; + unsigned int _num_groups{}; +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool batch_size_on_z> class Im2ColValidationFixture : public framework::Fixture { public: |