aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/BatchToSpaceLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/BatchToSpaceLayerFixture.h')
-rw-r--r--tests/validation/fixtures/BatchToSpaceLayerFixture.h29
1 files changed, 20 insertions, 9 deletions
diff --git a/tests/validation/fixtures/BatchToSpaceLayerFixture.h b/tests/validation/fixtures/BatchToSpaceLayerFixture.h
index 6554c09de4..5a23261a6e 100644
--- a/tests/validation/fixtures/BatchToSpaceLayerFixture.h
+++ b/tests/validation/fixtures/BatchToSpaceLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -36,14 +36,14 @@ namespace test
namespace validation
{
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class BatchToSpaceLayerValidationFixture : public framework::Fixture
+class BatchToSpaceLayerValidationGenericFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout)
+ void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout, const CropInfo &crop_info = CropInfo{})
{
- _target = compute_target(input_shape, block_shape_shape, output_shape, data_type, data_layout);
- _reference = compute_reference(input_shape, block_shape_shape, output_shape, data_type);
+ _target = compute_target(input_shape, block_shape_shape, output_shape, data_type, data_layout, crop_info);
+ _reference = compute_reference(input_shape, block_shape_shape, output_shape, data_type, crop_info);
}
protected:
@@ -57,7 +57,7 @@ protected:
library->fill(tensor, distribution, i);
}
TensorType compute_target(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape,
- DataType data_type, DataLayout data_layout)
+ DataType data_type, DataLayout data_layout, const CropInfo &crop_info)
{
if(data_layout == DataLayout::NHWC)
{
@@ -72,7 +72,7 @@ protected:
// Create and configure function
FunctionType batch_to_space;
- batch_to_space.configure(&input, &block_shape, &output);
+ batch_to_space.configure(&input, &block_shape, &output, crop_info);
ARM_COMPUTE_ASSERT(input.info()->is_resizable());
ARM_COMPUTE_ASSERT(block_shape.info()->is_resizable());
@@ -104,7 +104,7 @@ protected:
}
SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &block_shape_shape,
- const TensorShape &output_shape, DataType data_type)
+ const TensorShape &output_shape, DataType data_type, const CropInfo &crop_info)
{
// Create reference
SimpleTensor<T> input{ input_shape, data_type };
@@ -118,12 +118,23 @@ protected:
}
// Compute reference
- return reference::batch_to_space(input, block_shape, output_shape);
+ return reference::batch_to_space(input, block_shape, output_shape, crop_info);
}
TensorType _target{};
SimpleTensor<T> _reference{};
};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class BatchToSpaceLayerValidationFixture : public BatchToSpaceLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(TensorShape input_shape, TensorShape block_shape_shape, TensorShape output_shape, DataType data_type, DataLayout data_layout)
+ {
+ BatchToSpaceLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, block_shape_shape, output_shape, data_type, data_layout, CropInfo{});
+ }
+};
} // namespace validation
} // namespace test
} // namespace arm_compute