diff options
Diffstat (limited to 'tests/validation/fixtures/Im2ColFixture.h')
-rw-r--r-- | tests/validation/fixtures/Im2ColFixture.h | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/tests/validation/fixtures/Im2ColFixture.h b/tests/validation/fixtures/Im2ColFixture.h index b1fbd76eb2..38970116f6 100644 --- a/tests/validation/fixtures/Im2ColFixture.h +++ b/tests/validation/fixtures/Im2ColFixture.h @@ -45,6 +45,97 @@ namespace validation using namespace arm_compute::misc::shape_calculator; template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool batch_size_on_z> +class Im2ColOpValidationFixture : public framework::Fixture +{ +public: + template <typename...> + void setup(TensorShape input_shape, DataType data_type, const Size2D &kernel_dims, const PadStrideInfo &conv_info, const QuantizationInfo &quant_info, const DataLayout &data_layout, + unsigned int num_groups) + { + _kernel_dims = kernel_dims; + _conv_info = conv_info; + _quant_info = quant_info; + _data_layout = data_layout; + _has_bias = data_type != DataType::QASYMM8; + _num_groups = num_groups; + + if(_data_layout == DataLayout::NHWC) + { + permute(input_shape, PermutationVector(2U, 0U, 1U)); + } + + TensorInfo input_info(input_shape, 1, data_type); + input_info.set_data_layout(_data_layout); + + const TensorShape output_shape = compute_im2col_conv_shape(&input_info, _kernel_dims, _conv_info, _has_bias, Size2D(1U, 1U), batch_size_on_z && _num_groups == 1, _num_groups); + _target = compute_target(input_shape, output_shape, data_type); + + compute_reference(input_shape, output_shape, data_type); + } + +protected: + template <typename U> + void fill(U &&tensor) + { + library->fill_tensor_uniform(tensor, 0); + } + + TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type) + { + // Create tensors + TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, _quant_info, _data_layout); + TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, _quant_info); + + // Create and configure function + FunctionType im2col_func; + im2col_func.configure(src.info(), dst.info(), _kernel_dims, _conv_info, _has_bias, Size2D(1U, 1U), _num_groups); + + ARM_COMPUTE_ASSERT(src.info()->is_resizable()); + ARM_COMPUTE_ASSERT(dst.info()->is_resizable()); + + // Allocate tensors + src.allocator()->allocate(); + dst.allocator()->allocate(); + + ARM_COMPUTE_ASSERT(!src.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); + + // Fill tensors + fill(AccessorType(src)); + + arm_compute::ITensorPack pack = + { + { arm_compute::TensorType::ACL_SRC, &src }, + { arm_compute::TensorType::ACL_DST, &dst } + }; + // Compute function + im2col_func.run(pack); + + return dst; + } + + void compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type) + { + // Create reference + SimpleTensor<T> src{ input_shape, data_type, 1, _quant_info, _data_layout }; + _reference = SimpleTensor<T>(output_shape, data_type, 1, _quant_info, DataLayout::NCHW); + + // Fill reference + fill(src); + + reference::im2col<T>(src, _reference, _kernel_dims, _conv_info, _has_bias, _num_groups); + } + TensorType _target{}; + SimpleTensor<T> _reference{}; + Size2D _kernel_dims{}; + PadStrideInfo _conv_info{}; + DataLayout _data_layout{}; + QuantizationInfo _quant_info{}; + bool _has_bias{}; + unsigned int _num_groups{}; +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool batch_size_on_z> class Im2ColValidationFixture : public framework::Fixture { public: |