diff options
author | Manuel Bottini <manuel.bottini@arm.com> | 2021-06-30 18:29:18 +0100 |
---|---|---|
committer | Manuel Bottini <manuel.bottini@arm.com> | 2021-07-06 11:03:31 +0000 |
commit | 900289936c458eff95499e0a0eaba989a27aaa4d (patch) | |
tree | 305853a38fd66842d19aa1a2d1cad88a70b946bc /tests/validation/fixtures | |
parent | 6132c7aeaf6230a4e8b074309327762a9e4be003 (diff) | |
download | ComputeLibrary-900289936c458eff95499e0a0eaba989a27aaa4d.tar.gz |
Port NEIm2ColKernel
Resolves: COMPMID-4510
Change-Id: Ia3e588f599449d975dabad4afafb2974dd44d0ad
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5899
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/Im2ColFixture.h | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/tests/validation/fixtures/Im2ColFixture.h b/tests/validation/fixtures/Im2ColFixture.h index b1fbd76eb2..38970116f6 100644 --- a/tests/validation/fixtures/Im2ColFixture.h +++ b/tests/validation/fixtures/Im2ColFixture.h @@ -45,6 +45,97 @@ namespace validation using namespace arm_compute::misc::shape_calculator; template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool batch_size_on_z> +class Im2ColOpValidationFixture : public framework::Fixture +{ +public: + template <typename...> + void setup(TensorShape input_shape, DataType data_type, const Size2D &kernel_dims, const PadStrideInfo &conv_info, const QuantizationInfo &quant_info, const DataLayout &data_layout, + unsigned int num_groups) + { + _kernel_dims = kernel_dims; + _conv_info = conv_info; + _quant_info = quant_info; + _data_layout = data_layout; + _has_bias = data_type != DataType::QASYMM8; + _num_groups = num_groups; + + if(_data_layout == DataLayout::NHWC) + { + permute(input_shape, PermutationVector(2U, 0U, 1U)); + } + + TensorInfo input_info(input_shape, 1, data_type); + input_info.set_data_layout(_data_layout); + + const TensorShape output_shape = compute_im2col_conv_shape(&input_info, _kernel_dims, _conv_info, _has_bias, Size2D(1U, 1U), batch_size_on_z && _num_groups == 1, _num_groups); + _target = compute_target(input_shape, output_shape, data_type); + + compute_reference(input_shape, output_shape, data_type); + } + +protected: + template <typename U> + void fill(U &&tensor) + { + library->fill_tensor_uniform(tensor, 0); + } + + TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type) + { + // Create tensors + TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, _quant_info, _data_layout); + TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, _quant_info); + + // Create and configure function + FunctionType im2col_func; + im2col_func.configure(src.info(), dst.info(), _kernel_dims, _conv_info, _has_bias, Size2D(1U, 1U), _num_groups); + + ARM_COMPUTE_ASSERT(src.info()->is_resizable()); + ARM_COMPUTE_ASSERT(dst.info()->is_resizable()); + + // Allocate tensors + src.allocator()->allocate(); + dst.allocator()->allocate(); + + ARM_COMPUTE_ASSERT(!src.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); + + // Fill tensors + fill(AccessorType(src)); + + arm_compute::ITensorPack pack = + { + { arm_compute::TensorType::ACL_SRC, &src }, + { arm_compute::TensorType::ACL_DST, &dst } + }; + // Compute function + im2col_func.run(pack); + + return dst; + } + + void compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type) + { + // Create reference + SimpleTensor<T> src{ input_shape, data_type, 1, _quant_info, _data_layout }; + _reference = SimpleTensor<T>(output_shape, data_type, 1, _quant_info, DataLayout::NCHW); + + // Fill reference + fill(src); + + reference::im2col<T>(src, _reference, _kernel_dims, _conv_info, _has_bias, _num_groups); + } + TensorType _target{}; + SimpleTensor<T> _reference{}; + Size2D _kernel_dims{}; + PadStrideInfo _conv_info{}; + DataLayout _data_layout{}; + QuantizationInfo _quant_info{}; + bool _has_bias{}; + unsigned int _num_groups{}; +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool batch_size_on_z> class Im2ColValidationFixture : public framework::Fixture { public: |