diff options
Diffstat (limited to 'tests/validation/fixtures/GEMMReshapeLHSMatrixFixture.h')
-rw-r--r-- | tests/validation/fixtures/GEMMReshapeLHSMatrixFixture.h | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/tests/validation/fixtures/GEMMReshapeLHSMatrixFixture.h b/tests/validation/fixtures/GEMMReshapeLHSMatrixFixture.h index 3a5ab7c5e1..d88029f93e 100644 --- a/tests/validation/fixtures/GEMMReshapeLHSMatrixFixture.h +++ b/tests/validation/fixtures/GEMMReshapeLHSMatrixFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -46,11 +46,10 @@ namespace validation { using namespace arm_compute::misc::shape_calculator; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool reinterpret_input_as_3d = false> +template <typename TensorType, typename AccessorType, typename OperatorType, typename T, bool reinterpret_input_as_3d = false> class GEMMReshapeLHSMatrixValidationFixture : public framework::Fixture { public: - template <typename...> void setup(TensorShape shape_in, unsigned int batch_size, DataType data_type, unsigned int m0, unsigned int k0, unsigned int v0, bool interleave, bool transpose) { GEMMLHSMatrixInfo lhs_info; @@ -86,23 +85,26 @@ protected: // The output tensor will be auto-initialized within the function // Create and configure function - FunctionType gemm_lhs_reshape; - gemm_lhs_reshape.configure(&src, &dst, lhs_info, reinterpret_input_as_3d); + OperatorType gemm_lhs_reshape; + gemm_lhs_reshape.configure(src.info(), dst.info(), lhs_info, reinterpret_input_as_3d); - ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(src.info()->is_resizable()); + + add_padding_x({ &src, &dst }); // Allocate tensors src.allocator()->allocate(); dst.allocator()->allocate(); - ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(!src.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); // Fill tensors fill(AccessorType(src)); // Compute GEMM LHS matrix reshape function - gemm_lhs_reshape.run(); + ITensorPack tensors = { { ACL_SRC, &src }, { ACL_DST, &dst } }; + gemm_lhs_reshape.run(tensors); return dst; } |