aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/test/GatherNdTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/test/GatherNdTest.cpp')
-rw-r--r--delegate/src/test/GatherNdTest.cpp24
1 files changed, 12 insertions, 12 deletions
diff --git a/delegate/src/test/GatherNdTest.cpp b/delegate/src/test/GatherNdTest.cpp
index b56a931d27..2b4fd4207e 100644
--- a/delegate/src/test/GatherNdTest.cpp
+++ b/delegate/src/test/GatherNdTest.cpp
@@ -19,13 +19,13 @@ namespace armnnDelegate
void GatherNdUint8Test(std::vector<armnn::BackendId>& backends)
{
- std::vector<int32_t> paramsShape{8};
- std::vector<int32_t> indicesShape{3,1};
- std::vector<int32_t> expectedOutputShape{3};
+ std::vector<int32_t> paramsShape{ 5, 2 };
+ std::vector<int32_t> indicesShape{ 3, 1 };
+ std::vector<int32_t> expectedOutputShape{ 3, 2 };
- std::vector<uint8_t> paramsValues{1, 2, 3, 4, 5, 6, 7, 8};
- std::vector<int32_t> indicesValues{7, 6, 5};
- std::vector<uint8_t> expectedOutputValues{8, 7, 6};
+ std::vector<uint8_t> paramsValues{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
+ std::vector<int32_t> indicesValues{ 1, 0, 4 };
+ std::vector<uint8_t> expectedOutputValues{ 3, 4, 1, 2, 9, 10 };
GatherNdTest<uint8_t>(::tflite::TensorType_UINT8,
backends,
@@ -39,13 +39,13 @@ void GatherNdUint8Test(std::vector<armnn::BackendId>& backends)
void GatherNdFp32Test(std::vector<armnn::BackendId>& backends)
{
- std::vector<int32_t> paramsShape{8};
- std::vector<int32_t> indicesShape{3,1};
- std::vector<int32_t> expectedOutputShape{3};
+ std::vector<int32_t> paramsShape{ 5, 2 };
+ std::vector<int32_t> indicesShape{ 3, 1 };
+ std::vector<int32_t> expectedOutputShape{ 3, 2 };
- std::vector<float> paramsValues{1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f};
- std::vector<int32_t> indicesValues{7, 6, 5};
- std::vector<float> expectedOutputValues{8.8f, 7.7f, 6.6f};
+ std::vector<float> paramsValues{ 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.10f };
+ std::vector<int32_t> indicesValues{ 1, 0, 4 };
+ std::vector<float> expectedOutputValues{ 3.3f, 4.4f, 1.1f, 2.2f, 9.9f, 10.10f };
GatherNdTest<float>(::tflite::TensorType_FLOAT32,
backends,