aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ReshapeLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ReshapeLayerFixture.h')
-rw-r--r--tests/validation/fixtures/ReshapeLayerFixture.h43
1 files changed, 34 insertions, 9 deletions
diff --git a/tests/validation/fixtures/ReshapeLayerFixture.h b/tests/validation/fixtures/ReshapeLayerFixture.h
index 8a98379ef2..5be431f8cf 100644
--- a/tests/validation/fixtures/ReshapeLayerFixture.h
+++ b/tests/validation/fixtures/ReshapeLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_RESHAPE_LAYER_FIXTURE
-#define ARM_COMPUTE_TEST_RESHAPE_LAYER_FIXTURE
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_RESHAPELAYERFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_RESHAPELAYERFIXTURE_H
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
@@ -31,6 +31,7 @@
#include "tests/IAccessor.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
+#include "tests/validation/Helpers.h"
#include "tests/validation/reference/ReshapeLayer.h"
namespace arm_compute
@@ -41,13 +42,12 @@ namespace validation
{
/** [ReshapeLayer fixture] **/
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class ReshapeLayerValidationFixture : public framework::Fixture
+class ReshapeLayerGenericValidationFixture : public framework::Fixture
{
public:
- template <typename...>
- void setup(TensorShape input_shape, TensorShape output_shape, DataType data_type)
+ void setup(TensorShape input_shape, TensorShape output_shape, DataType data_type, bool add_x_padding = false)
{
- _target = compute_target(input_shape, output_shape, data_type);
+ _target = compute_target(input_shape, output_shape, data_type, add_x_padding);
_reference = compute_reference(input_shape, output_shape, data_type);
}
@@ -58,7 +58,7 @@ protected:
library->fill_tensor_uniform(tensor, i);
}
- TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type)
+ TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type, bool add_x_padding = false)
{
// Check if indeed the input shape can be reshape to the output one
ARM_COMPUTE_ASSERT(input_shape.total_size() == output_shape.total_size());
@@ -75,6 +75,12 @@ protected:
ARM_COMPUTE_ASSERT(src.info()->is_resizable());
ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
+ if(add_x_padding)
+ {
+ // Add random padding in x dimension
+ add_padding_x({ &src, &dst });
+ }
+
// Allocate tensors
src.allocator()->allocate();
dst.allocator()->allocate();
@@ -105,8 +111,27 @@ protected:
TensorType _target{};
SimpleTensor<T> _reference{};
};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class ReshapeLayerValidationFixture : public ReshapeLayerGenericValidationFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ void setup(TensorShape input_shape, TensorShape output_shape, DataType data_type)
+ {
+ ReshapeLayerGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, output_shape, data_type);
+ }
+};
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class ReshapeLayerPaddedValidationFixture : public ReshapeLayerGenericValidationFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ void setup(TensorShape input_shape, TensorShape output_shape, DataType data_type)
+ {
+ ReshapeLayerGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, output_shape, data_type, true /* add_x_padding */);
+ }
+};
/** [ReshapeLayer fixture] **/
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_RESHAPE_LAYER_FIXTURE */
+#endif // ACL_TESTS_VALIDATION_FIXTURES_RESHAPELAYERFIXTURE_H