// // Copyright © 2022-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "GatherNdTestHelper.hpp" #include namespace armnnDelegate { // Gather_Nd Operator void GatherNdUint8Test() { std::vector paramsShape{ 5, 2 }; std::vector indicesShape{ 3, 1 }; std::vector expectedOutputShape{ 3, 2 }; std::vector paramsValues{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; std::vector indicesValues{ 1, 0, 4 }; std::vector expectedOutputValues{ 3, 4, 1, 2, 9, 10 }; GatherNdTest(::tflite::TensorType_UINT8, paramsShape, indicesShape, expectedOutputShape, paramsValues, indicesValues, expectedOutputValues); } void GatherNdFp32Test() { std::vector paramsShape{ 5, 2 }; std::vector indicesShape{ 3, 1 }; std::vector expectedOutputShape{ 3, 2 }; std::vector paramsValues{ 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.10f }; std::vector indicesValues{ 1, 0, 4 }; std::vector expectedOutputValues{ 3.3f, 4.4f, 1.1f, 2.2f, 9.9f, 10.10f }; GatherNdTest(::tflite::TensorType_FLOAT32, paramsShape, indicesShape, expectedOutputShape, paramsValues, indicesValues, expectedOutputValues); } // Gather_Nd Test Suite TEST_SUITE("Gather_NdTests") { TEST_CASE ("Gather_Nd_Uint8_Test") { GatherNdUint8Test(); } TEST_CASE ("Gather_Nd_Fp32_Test") { GatherNdFp32Test(); } } // End of Gather_Nd Test Suite } // namespace armnnDelegate