diff options
Diffstat (limited to 'tests/validation/fixtures/ReorderFixture.h')
-rw-r--r-- | tests/validation/fixtures/ReorderFixture.h | 27 |
1 files changed, 21 insertions, 6 deletions
diff --git a/tests/validation/fixtures/ReorderFixture.h b/tests/validation/fixtures/ReorderFixture.h index 36e62696bc..8e28484c48 100644 --- a/tests/validation/fixtures/ReorderFixture.h +++ b/tests/validation/fixtures/ReorderFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023 Arm Limited. + * Copyright (c) 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ACL_TESTS_VALIDATION_FIXTURES_REORDERFIXTURE -#define ACL_TESTS_VALIDATION_FIXTURES_REORDERFIXTURE +#ifndef ACL_TESTS_VALIDATION_FIXTURES_REORDERFIXTURE_H +#define ACL_TESTS_VALIDATION_FIXTURES_REORDERFIXTURE_H #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" @@ -32,6 +32,7 @@ #include "tests/framework/Asserts.h" #include "tests/framework/Fixture.h" #include "tests/validation/reference/Reorder.h" +#include "src/core/NEON/kernels/arm_gemm/utils.hpp" namespace arm_compute { @@ -44,10 +45,23 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class ReorderValidationFixture : public framework::Fixture { public: + void check_hardware_supports(WeightFormat output_wf){ + if(!Scheduler::get().cpu_info().has_sve() && output_wf!=WeightFormat::OHWIo4){ + _hardware_supports = false; + } + if (Scheduler::get().cpu_info().has_sve() && arm_gemm::utils::get_vector_length<float>() != 8 && output_wf==WeightFormat::OHWIo8) + { + _hardware_supports = false; + } + } + void setup(TensorShape input_shape, TensorShape output_shape, WeightFormat input_wf, WeightFormat output_wf, DataType data_type) { - _target = compute_target(input_shape, output_shape, input_wf, output_wf, data_type); - _reference = compute_reference(input_shape, output_shape, output_wf, data_type); + check_hardware_supports(output_wf); + if (_hardware_supports){ + _target = compute_target(input_shape, output_shape, input_wf, output_wf, data_type); + _reference = compute_reference(input_shape, output_shape, output_wf, data_type); + } } protected: @@ -98,6 +112,7 @@ public: return reference::reorder_layer<T>(src, output_shape, output_wf); } + bool _hardware_supports = true; TensorType _target{}; SimpleTensor<T> _reference{}; }; @@ -105,4 +120,4 @@ public: } // namespace validation } // namespace test } // namespace arm_compute -#endif /* ACL_TESTS_VALIDATION_FIXTURES_REORDERFIXTURE */ +#endif // ACL_TESTS_VALIDATION_FIXTURES_REORDERFIXTURE_H |