diff options
Diffstat (limited to 'src/backends/reference/test')
-rw-r--r-- | src/backends/reference/test/RefEndToEndTests.cpp | 31 | ||||
-rw-r--r-- | src/backends/reference/test/RefLayerTests.cpp | 12 |
2 files changed, 43 insertions, 0 deletions
diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp index e1c2e2f2a7..2ed5ad812c 100644 --- a/src/backends/reference/test/RefEndToEndTests.cpp +++ b/src/backends/reference/test/RefEndToEndTests.cpp @@ -19,6 +19,7 @@ #include <backendsCommon/test/FillEndToEndTestImpl.hpp> #include <backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp> #include <backendsCommon/test/GatherEndToEndTestImpl.hpp> +#include <backendsCommon/test/GatherNdEndToEndTestImpl.hpp> #include <backendsCommon/test/InstanceNormalizationEndToEndTestImpl.hpp> #include <backendsCommon/test/LogSoftmaxEndToEndTestImpl.hpp> #include <backendsCommon/test/PreluEndToEndTestImpl.hpp> @@ -720,6 +721,36 @@ TEST_CASE("RefGatherMultiDimInt16Test") GatherMultiDimEndToEnd<armnn::DataType::QSymmS16>(defaultBackends); } +TEST_CASE("RefGatherNdFloatTest") +{ + GatherNdEndToEnd<armnn::DataType::Float32>(defaultBackends); +} + +TEST_CASE("RefGatherNdUint8Test") +{ + GatherNdEndToEnd<armnn::DataType::QAsymmU8>(defaultBackends); +} + +TEST_CASE("RefGatherNdInt16Test") +{ + GatherNdEndToEnd<armnn::DataType::QSymmS16>(defaultBackends); +} + +TEST_CASE("RefGatherNdMultiDimFloatTest") +{ + GatherNdMultiDimEndToEnd<armnn::DataType::Float32>(defaultBackends); +} + +TEST_CASE("RefGatherNdMultiDimUint8Test") +{ + GatherNdMultiDimEndToEnd<armnn::DataType::QAsymmU8>(defaultBackends); +} + +TEST_CASE("RefGatherNdMultiDimInt16Test") +{ + GatherNdMultiDimEndToEnd<armnn::DataType::QSymmS16>(defaultBackends); +} + // DepthToSpace TEST_CASE("DephtToSpaceEndToEndNchwFloat32") { diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 9dca621e13..496b11db91 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -2155,6 +2155,18 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesUint8, GatherMu ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesInt16, GatherMultiDimParamsMultiDimIndicesInt16Test) ARMNN_AUTO_TEST_CASE_WITH_THF(GatherMultiDimParamsMultiDimIndicesInt32, GatherMultiDimParamsMultiDimIndicesInt32Test) + +// GatherNd +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd2dFloat32, SimpleGatherNd2dTest<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd3dFloat32, SimpleGatherNd3dTest<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd4dFloat32, SimpleGatherNd4dTest<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd2dInt8, SimpleGatherNd2dTest<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd3dInt8, SimpleGatherNd3dTest<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd4dInt8, SimpleGatherNd4dTest<DataType::QAsymmS8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd2dInt32, SimpleGatherNd2dTest<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd3dInt32, SimpleGatherNd3dTest<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(GatherNd4dInt32, SimpleGatherNd4dTest<DataType::Signed32>) + // Abs ARMNN_AUTO_TEST_CASE_WITH_THF(Abs2d, Abs2dTest<DataType::Float32>) ARMNN_AUTO_TEST_CASE_WITH_THF(Abs3d, Abs3dTest<DataType::Float32>) |