diff options
Diffstat (limited to 'tests/validation/fixtures/StackLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/StackLayerFixture.h | 45 |
1 files changed, 34 insertions, 11 deletions
diff --git a/tests/validation/fixtures/StackLayerFixture.h b/tests/validation/fixtures/StackLayerFixture.h index 7bf63a3ebc..7dd8fe47dc 100644 --- a/tests/validation/fixtures/StackLayerFixture.h +++ b/tests/validation/fixtures/StackLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 Arm Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ARM_COMPUTE_TEST_STACK_LAYER_FIXTURE -#define ARM_COMPUTE_TEST_STACK_LAYER_FIXTURE +#ifndef ACL_TESTS_VALIDATION_FIXTURES_STACKLAYERFIXTURE_H +#define ACL_TESTS_VALIDATION_FIXTURES_STACKLAYERFIXTURE_H #include "arm_compute/core/Helpers.h" #include "arm_compute/core/TensorShape.h" @@ -52,10 +52,9 @@ template <typename TensorType, typename AbstractTensorType, typename AccessorTyp class StackLayerValidationFixture : public framework::Fixture { public: - template <typename...> void setup(TensorShape shape_src, int axis, DataType data_type, int num_tensors) { - _target = compute_target(shape_src, axis, data_type, num_tensors); + _target = compute_target(shape_src, axis, data_type, num_tensors, false /* add_x_padding */); _reference = compute_reference(shape_src, axis, data_type, num_tensors); } @@ -66,7 +65,7 @@ protected: library->fill_tensor_uniform(tensor, i); } - TensorType compute_target(TensorShape shape_src, int axis, DataType data_type, int num_tensors) + TensorType compute_target(TensorShape shape_src, int axis, DataType data_type, int num_tensors, bool add_x_padding) { std::vector<TensorType> tensors(num_tensors); std::vector<AbstractTensorType *> src(num_tensors); @@ -76,7 +75,7 @@ protected: { tensors[i] = create_tensor<TensorType>(shape_src, data_type); src[i] = &(tensors[i]); - ARM_COMPUTE_EXPECT(tensors[i].info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(tensors[i].info()->is_resizable()); } // Create tensors @@ -91,18 +90,28 @@ protected: // Allocate and fill the input tensors for(int i = 0; i < num_tensors; ++i) { - ARM_COMPUTE_EXPECT(tensors[i].info()->is_resizable(), framework::LogLevel::ERRORS); + if(add_x_padding) + { + add_padding_x({&tensors[i]}, DataLayout::NHWC); + } + + ARM_COMPUTE_ASSERT(tensors[i].info()->is_resizable()); tensors[i].allocator()->allocate(); - ARM_COMPUTE_EXPECT(!tensors[i].info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(!tensors[i].info()->is_resizable()); // Fill input tensor fill(AccessorType(tensors[i]), i); } + if(add_x_padding) + { + add_padding_x({&dst}, DataLayout::NHWC); + } + // Allocate output tensor dst.allocator()->allocate(); - ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); // Compute stack function stack.run(); @@ -132,7 +141,21 @@ protected: TensorType _target{}; SimpleTensor<T> _reference{}; }; + +template <typename TensorType, typename AbstractTensorType, typename AccessorType, typename FunctionType, typename T> +class StackLayerWithPaddingValidationFixture : + public StackLayerValidationFixture<TensorType, AbstractTensorType, AccessorType, FunctionType, T> +{ +public: + using Parent = StackLayerValidationFixture<TensorType, AbstractTensorType, AccessorType, FunctionType, T>; + + void setup(TensorShape shape_src, int axis, DataType data_type, int num_tensors) + { + Parent::_target = Parent::compute_target(shape_src, axis, data_type, num_tensors, true /* add_x_padding */); + Parent::_reference = Parent::compute_reference(shape_src, axis, data_type, num_tensors); + } +}; } // namespace validation } // namespace test } // namespace arm_compute -#endif /* ARM_COMPUTE_TEST_STACK_LAYER_FIXTURE */ +#endif // ACL_TESTS_VALIDATION_FIXTURES_STACKLAYERFIXTURE_H |