From 900289936c458eff95499e0a0eaba989a27aaa4d Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Wed, 30 Jun 2021 18:29:18 +0100 Subject: Port NEIm2ColKernel Resolves: COMPMID-4510 Change-Id: Ia3e588f599449d975dabad4afafb2974dd44d0ad Signed-off-by: Manuel Bottini Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5899 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins --- tests/validation/fixtures/Im2ColFixture.h | 91 +++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) (limited to 'tests/validation/fixtures/Im2ColFixture.h') diff --git a/tests/validation/fixtures/Im2ColFixture.h b/tests/validation/fixtures/Im2ColFixture.h index b1fbd76eb2..38970116f6 100644 --- a/tests/validation/fixtures/Im2ColFixture.h +++ b/tests/validation/fixtures/Im2ColFixture.h @@ -44,6 +44,97 @@ namespace validation { using namespace arm_compute::misc::shape_calculator; +template +class Im2ColOpValidationFixture : public framework::Fixture +{ +public: + template + void setup(TensorShape input_shape, DataType data_type, const Size2D &kernel_dims, const PadStrideInfo &conv_info, const QuantizationInfo &quant_info, const DataLayout &data_layout, + unsigned int num_groups) + { + _kernel_dims = kernel_dims; + _conv_info = conv_info; + _quant_info = quant_info; + _data_layout = data_layout; + _has_bias = data_type != DataType::QASYMM8; + _num_groups = num_groups; + + if(_data_layout == DataLayout::NHWC) + { + permute(input_shape, PermutationVector(2U, 0U, 1U)); + } + + TensorInfo input_info(input_shape, 1, data_type); + input_info.set_data_layout(_data_layout); + + const TensorShape output_shape = compute_im2col_conv_shape(&input_info, _kernel_dims, _conv_info, _has_bias, Size2D(1U, 1U), batch_size_on_z && _num_groups == 1, _num_groups); + _target = compute_target(input_shape, output_shape, data_type); + + compute_reference(input_shape, output_shape, data_type); + } + +protected: + template + void fill(U &&tensor) + { + library->fill_tensor_uniform(tensor, 0); + } + + TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type) + { + // Create tensors + TensorType src = create_tensor(input_shape, data_type, 1, _quant_info, _data_layout); + TensorType dst = create_tensor(output_shape, data_type, 1, _quant_info); + + // Create and configure function + FunctionType im2col_func; + im2col_func.configure(src.info(), dst.info(), _kernel_dims, _conv_info, _has_bias, Size2D(1U, 1U), _num_groups); + + ARM_COMPUTE_ASSERT(src.info()->is_resizable()); + ARM_COMPUTE_ASSERT(dst.info()->is_resizable()); + + // Allocate tensors + src.allocator()->allocate(); + dst.allocator()->allocate(); + + ARM_COMPUTE_ASSERT(!src.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); + + // Fill tensors + fill(AccessorType(src)); + + arm_compute::ITensorPack pack = + { + { arm_compute::TensorType::ACL_SRC, &src }, + { arm_compute::TensorType::ACL_DST, &dst } + }; + // Compute function + im2col_func.run(pack); + + return dst; + } + + void compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type) + { + // Create reference + SimpleTensor src{ input_shape, data_type, 1, _quant_info, _data_layout }; + _reference = SimpleTensor(output_shape, data_type, 1, _quant_info, DataLayout::NCHW); + + // Fill reference + fill(src); + + reference::im2col(src, _reference, _kernel_dims, _conv_info, _has_bias, _num_groups); + } + TensorType _target{}; + SimpleTensor _reference{}; + Size2D _kernel_dims{}; + PadStrideInfo _conv_info{}; + DataLayout _data_layout{}; + QuantizationInfo _quant_info{}; + bool _has_bias{}; + unsigned int _num_groups{}; +}; + template class Im2ColValidationFixture : public framework::Fixture { -- cgit v1.2.1