aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/Im2ColFixture.h
diff options
context:
space:
mode:
authorPablo Tello <pablo.tello@arm.com>2018-04-23 16:11:45 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:49:54 +0000
commit764b1af5f41749624c3900fc65c9dad3b1adf98c (patch)
tree9d8397de05d86eb40959d24058939c46c36f0d39 /tests/validation/fixtures/Im2ColFixture.h
parent3ca9786fe8ed00ad03963cae6a9eef7bb2fe630e (diff)
downloadComputeLibrary-764b1af5f41749624c3900fc65c9dad3b1adf98c.tar.gz
COMPMID-1070: Rewrote im2col nhwc reference
The new reference computes directly the results without the need of using temporary tensors and perform permutations which is problematic for big 4k tensors as we get bad_alloc execptions. Change-Id: I1fb0a495c4e9ca3356de76c9298832db8ace794a Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/128683 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'tests/validation/fixtures/Im2ColFixture.h')
-rw-r--r--tests/validation/fixtures/Im2ColFixture.h29
1 files changed, 14 insertions, 15 deletions
diff --git a/tests/validation/fixtures/Im2ColFixture.h b/tests/validation/fixtures/Im2ColFixture.h
index f403aa9d21..b6e4cd0023 100644
--- a/tests/validation/fixtures/Im2ColFixture.h
+++ b/tests/validation/fixtures/Im2ColFixture.h
@@ -49,7 +49,7 @@ class Im2ColValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, DataType data_type, const Size2D &kernel_dims, const PadStrideInfo &conv_info, const QuantizationInfo &quant_info, const DataLayout &data_layout)
+ void setup(TensorShape input_shape, DataType data_type, const Size2D &kernel_dims, const PadStrideInfo &conv_info, const QuantizationInfo &quant_info, const DataLayout &data_layout)
{
_kernel_dims = kernel_dims;
_conv_info = conv_info;
@@ -59,16 +59,17 @@ public:
if(_data_layout == DataLayout::NHWC)
{
- permute(shape, PermutationVector(2U, 0U, 1U));
+ permute(input_shape, PermutationVector(2U, 0U, 1U));
}
- TensorShape output_shape;
- TensorInfo input_info(shape, 1, data_type);
+ TensorInfo input_info(input_shape, 1, data_type);
input_info.set_data_layout(_data_layout);
- output_shape = compute_im2col_conv_shape(&input_info, _kernel_dims, _conv_info, _has_bias, Size2D(1U, 1U));
- _target = compute_target(shape, output_shape, data_type);
- _reference = compute_reference(shape, output_shape, data_type);
+ const TensorShape output_shape = compute_im2col_conv_shape(&input_info, _kernel_dims, _conv_info, _has_bias, Size2D(1U, 1U));
+
+ _target = compute_target(input_shape, output_shape, data_type);
+
+ compute_reference(input_shape, output_shape, data_type);
}
protected:
@@ -78,10 +79,10 @@ protected:
library->fill_tensor_uniform(tensor, 0);
}
- TensorType compute_target(const TensorShape &shape, const TensorShape &output_shape, DataType data_type)
+ TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(shape, data_type, 1, 0, _quant_info, _data_layout);
+ TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, 0, _quant_info, _data_layout);
TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, 0, _quant_info, _data_layout);
// Create and configure function
@@ -107,17 +108,15 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, const TensorShape &output_shape, DataType data_type)
+ void compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type)
{
// Create reference
- SimpleTensor<T> src{ shape, data_type, 1, 0, _quant_info, _data_layout };
-
+ SimpleTensor<T> src{ input_shape, data_type, 1, 0, _quant_info, _data_layout };
+ _reference = SimpleTensor<T>(output_shape, data_type, 1, 0, _quant_info, DataLayout::NCHW);
// Fill reference
fill(src);
-
- return reference::im2col<T>(src, output_shape, _kernel_dims, _conv_info, _has_bias);
+ reference::im2col<T>(src, _reference, _kernel_dims, _conv_info, _has_bias);
}
-
TensorType _target{};
SimpleTensor<T> _reference{};
Size2D _kernel_dims{};