aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/WinogradLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/WinogradLayerFixture.h')
-rw-r--r--tests/validation/fixtures/WinogradLayerFixture.h85
1 files changed, 38 insertions, 47 deletions
diff --git a/tests/validation/fixtures/WinogradLayerFixture.h b/tests/validation/fixtures/WinogradLayerFixture.h
index 481eb93e80..17229cac25 100644
--- a/tests/validation/fixtures/WinogradLayerFixture.h
+++ b/tests/validation/fixtures/WinogradLayerFixture.h
@@ -142,8 +142,9 @@ protected:
fill(bias, 2, 0.f, 0.f);
}
- return (act_info.enabled()) ? reference::activation_layer<T>(reference::convolution_layer<T>(src, weights, bias, output_shape, info), act_info) : reference::convolution_layer<T>(src, weights, bias,
- output_shape, info);
+ SimpleTensor<T> conv_out = reference::convolution_layer<T>(src, weights, bias, output_shape, info);
+
+ return (act_info.enabled()) ? reference::activation_layer<T>(conv_out, act_info) : conv_out;
}
TensorType _target{};
@@ -155,12 +156,12 @@ class WinogradInputTransformValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape input_shape, PadStrideInfo conv_info, Size2D kernel_dims, bool is_nchw_format, DataType data_type)
+ void setup(TensorShape input_shape, WinogradInfo winograd_info, DataLayout data_layout, DataType data_type)
{
- TensorShape output_shape = compute_winograd_input_transform_shape(TensorInfo(input_shape, 1, data_type), conv_info, kernel_dims);
+ TensorShape output_shape = compute_winograd_input_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info);
- _target = compute_target(input_shape, output_shape, conv_info, kernel_dims, is_nchw_format, data_type);
- _reference = compute_reference(input_shape, output_shape, conv_info, kernel_dims, is_nchw_format, data_type);
+ _target = compute_target(input_shape, output_shape, winograd_info, data_layout, data_type);
+ _reference = compute_reference(input_shape, output_shape, winograd_info, data_layout, data_type);
}
protected:
@@ -184,16 +185,14 @@ protected:
}
}
- TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, const PadStrideInfo &conv_info, const Size2D &kernel_dims, bool is_nchw_format, DataType data_type)
+ TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, const WinogradInfo &winograd_info, DataLayout data_layout, DataType data_type)
{
- ARM_COMPUTE_UNUSED(is_nchw_format);
-
- TensorType src = create_tensor<TensorType>(input_shape, data_type);
- TensorType dst = create_tensor<TensorType>(output_shape, data_type);
+ TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, 0, QuantizationInfo(), data_layout);
+ TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, 0, QuantizationInfo(), data_layout);
// Create and configure function
FunctionType transf;
- transf.configure(&src, &dst, conv_info, kernel_dims);
+ transf.configure(&src, &dst, winograd_info);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -208,23 +207,21 @@ protected:
// Fill tensors
fill(AccessorType(src), 0, -1.f, 1.f);
- // Compute CLWinogradInputTransform function
+ // Compute Winograd input transform function
transf.run();
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, const PadStrideInfo &conv_info, const Size2D &kernel_dims, bool is_nchw_format, DataType data_type)
+ SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, const WinogradInfo &winograd_info, DataLayout data_layout, DataType data_type)
{
- ARM_COMPUTE_UNUSED(is_nchw_format);
-
// Create reference
- SimpleTensor<T> src{ input_shape, data_type };
+ SimpleTensor<T> src{ input_shape, data_type, 1, 0, QuantizationInfo(), data_layout };
// Fill reference
fill(src, 0, -1.f, 1.f);
- return reference::winograd_input_transform<T>(src, output_shape, conv_info, kernel_dims);
+ return reference::winograd_input_transform<T>(src, output_shape, winograd_info);
}
TensorType _target{};
@@ -236,12 +233,13 @@ class WinogradFilterTransformValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape input_shape, bool is_nchw_format, Size2D output_tile, DataType data_type)
+ void setup(TensorShape input_shape, Size2D output_tile, DataLayout data_layout, DataType data_type)
{
- TensorShape output_shape = compute_winograd_filter_transform_shape(TensorInfo(input_shape, 1, data_type), output_tile);
+ WinogradInfo winograd_info(output_tile, Size2D(input_shape[0], input_shape[1]), Size2D() /* Not needed */, PadStrideInfo() /* Not needed */, DataLayout::NCHW /* Not needed */);
+ TensorShape output_shape = compute_winograd_filter_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info);
- _target = compute_target(input_shape, output_shape, is_nchw_format, output_tile, data_type);
- _reference = compute_reference(input_shape, output_shape, is_nchw_format, output_tile, data_type);
+ _target = compute_target(input_shape, output_shape, winograd_info, data_layout, data_type);
+ _reference = compute_reference(input_shape, output_shape, winograd_info, data_layout, data_type);
}
protected:
@@ -265,17 +263,15 @@ protected:
}
}
- TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, bool is_nchw_format, const Size2D &output_tile, DataType data_type)
+ TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, const WinogradInfo &winograd_info, DataLayout data_layout, DataType data_type)
{
- ARM_COMPUTE_UNUSED(is_nchw_format);
-
// Create tensors
- TensorType src = create_tensor<TensorType>(input_shape, data_type, 1);
- TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1);
+ TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, 0, QuantizationInfo(), data_layout);
+ TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, 0, QuantizationInfo(), data_layout);
// Create and configure function
FunctionType filter_transform;
- filter_transform.configure(&src, &dst, output_tile);
+ filter_transform.configure(&src, &dst, winograd_info);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -295,17 +291,15 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, bool is_nchw_format, const Size2D &output_tile, DataType data_type)
+ SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, const WinogradInfo &winograd_info, DataLayout data_layout, DataType data_type)
{
- ARM_COMPUTE_UNUSED(is_nchw_format);
-
// Create reference
- SimpleTensor<T> src{ input_shape, data_type, 1 };
+ SimpleTensor<T> src{ input_shape, data_type, 1, 0, QuantizationInfo(), data_layout };
// Fill reference
fill(src, 0, -1.f, 1.f);
- return reference::winograd_filter_transform<T>(src, output_shape, output_tile);
+ return reference::winograd_filter_transform<T>(src, output_shape, winograd_info);
}
TensorType _target{};
@@ -317,12 +311,12 @@ class WinogradOutputTransformValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape input_shape, Size2D kernel_dims, Size2D output_convolved_dims, Size2D num_tiles, DataLayout data_layout, DataType data_type)
+ void setup(TensorShape input_shape, WinogradInfo winograd_info, DataType data_type)
{
- TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), output_convolved_dims, data_layout);
+ TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info);
- _target = compute_target(input_shape, output_shape, kernel_dims, output_convolved_dims, num_tiles, data_layout, data_type);
- _reference = compute_reference(input_shape, output_shape, kernel_dims, output_convolved_dims, num_tiles, data_layout, data_type);
+ _target = compute_target(input_shape, output_shape, winograd_info, data_type);
+ _reference = compute_reference(input_shape, output_shape, winograd_info, data_type);
}
protected:
@@ -346,16 +340,15 @@ protected:
}
}
- TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, const Size2D &kernel_dims, const Size2D &output_convolved_dims, Size2D &num_tiles, DataLayout data_layout,
- DataType data_type)
+ TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, const WinogradInfo &winograd_info, DataType data_type)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, 0, QuantizationInfo(), data_layout);
- TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, 0, QuantizationInfo(), data_layout);
+ TensorType src = create_tensor<TensorType>(input_shape, data_type);
+ TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, 0, QuantizationInfo(), winograd_info.output_data_layout);
// Create and configure function
FunctionType output_transform;
- output_transform.configure(&src, nullptr, &dst, kernel_dims, output_convolved_dims, num_tiles);
+ output_transform.configure(&src, nullptr, &dst, winograd_info);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -375,17 +368,15 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, const Size2D &kernel_dims, const Size2D &output_convolved_dims, Size2D &num_tiles,
- DataLayout data_layout,
- DataType data_type)
+ SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, const WinogradInfo &winograd_info, DataType data_type)
{
// Create reference
- SimpleTensor<T> src{ input_shape, data_type, 1, 0, QuantizationInfo(), data_layout };
+ SimpleTensor<T> src{ input_shape, data_type };
// Fill reference
fill(src, 0, -1.f, 1.f);
- return reference::winograd_output_transform<T>(src, output_shape, kernel_dims, num_tiles);
+ return reference::winograd_output_transform<T>(src, output_shape, winograd_info);
}
TensorType _target{};