aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/DeconvolutionLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/DeconvolutionLayerFixture.h')
-rw-r--r--tests/validation/fixtures/DeconvolutionLayerFixture.h24
1 files changed, 22 insertions, 2 deletions
diff --git a/tests/validation/fixtures/DeconvolutionLayerFixture.h b/tests/validation/fixtures/DeconvolutionLayerFixture.h
index 9f90f07c97..a25a65f997 100644
--- a/tests/validation/fixtures/DeconvolutionLayerFixture.h
+++ b/tests/validation/fixtures/DeconvolutionLayerFixture.h
@@ -218,7 +218,27 @@ public:
const TensorShape weights_shape(kernel_size_x, kernel_size_y, input_shape.z(), num_kernels);
const TensorShape bias_shape(num_kernels);
const PadStrideInfo info(sx, sy, padx, pady, DimensionRoundingType::CEIL);
- auto out_dim = deconvolution_output_dimensions(input_shape.x(), input_shape.y(), kernel_size_x, kernel_size_y, padx, pady, sx, sy);
+ auto out_dim = deconvolution_output_dimensions(input_shape.x(), input_shape.y(), kernel_size_x, kernel_size_y, info);
+ TensorInfo input_info(input_shape, 1, data_type);
+ TensorInfo weights_info(weights_shape, 1, data_type);
+ TensorShape output_shape = compute_deconvolution_output_shape(out_dim, input_info, weights_info);
+ DeconvolutionLayerFixtureBase<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, data_type, data_layout, QuantizationInfo(), add_bias);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, unsigned int kernel_size_x, unsigned int kernel_size_y>
+class DeconvolutionValidationAsymmFixture : public DeconvolutionLayerFixtureBase<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(TensorShape input_shape, unsigned int sx, unsigned int sy, unsigned int pad_left, unsigned int pad_right, unsigned int pad_top,
+ unsigned int pad_bottom, unsigned int num_kernels, DataType data_type, DataLayout data_layout, bool add_bias)
+ {
+ ARM_COMPUTE_ERROR_ON_MSG(kernel_size_x != kernel_size_y, "Only square kernels supported");
+ const TensorShape weights_shape(kernel_size_x, kernel_size_y, input_shape.z(), num_kernels);
+ const TensorShape bias_shape(num_kernels);
+ const PadStrideInfo info(sx, sy, pad_left, pad_right, pad_top, pad_bottom, DimensionRoundingType::CEIL);
+ auto out_dim = deconvolution_output_dimensions(input_shape.x(), input_shape.y(), kernel_size_x, kernel_size_y, info);
TensorInfo input_info(input_shape, 1, data_type);
TensorInfo weights_info(weights_shape, 1, data_type);
TensorShape output_shape = compute_deconvolution_output_shape(out_dim, input_info, weights_info);
@@ -238,7 +258,7 @@ public:
const TensorShape weights_shape(kernel_size_x, kernel_size_y, input_shape.z(), num_kernels);
const TensorShape bias_shape(num_kernels);
const PadStrideInfo info(sx, sy, padx, pady, DimensionRoundingType::CEIL);
- auto out_dim = deconvolution_output_dimensions(input_shape.x(), input_shape.y(), kernel_size_x, kernel_size_y, padx, pady, sx, sy);
+ auto out_dim = deconvolution_output_dimensions(input_shape.x(), input_shape.y(), kernel_size_x, kernel_size_y, info);
TensorInfo input_info(input_shape, 1, data_type, quantization_info);
TensorInfo weights_info(weights_shape, 1, data_type, quantization_info);
TensorShape output_shape = compute_deconvolution_output_shape(out_dim, input_info, weights_info);