aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/ScatterLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CL/ScatterLayer.cpp')
-rw-r--r--tests/validation/CL/ScatterLayer.cpp70
1 files changed, 68 insertions, 2 deletions
diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp
index 66b71ef650..56338f489f 100644
--- a/tests/validation/CL/ScatterLayer.cpp
+++ b/tests/validation/CL/ScatterLayer.cpp
@@ -24,8 +24,13 @@
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/functions/CLScatter.h"
#include "tests/validation/fixtures/ScatterLayerFixture.h"
+#include "tests/datasets/ScatterDataset.h"
#include "tests/CL/CLAccessor.h"
+#include "arm_compute/function_info/ScatterInfo.h"
+#include "tests/framework/Asserts.h"
#include "tests/framework/Macros.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/validation/Validation.h"
namespace arm_compute
{
@@ -37,13 +42,74 @@ namespace validation
template <typename T>
using CLScatterLayerFixture = ScatterValidationFixture<CLTensor, CLAccessor, CLScatter, T>;
+using framework::dataset::make;
+
TEST_SUITE(CL)
-TEST_SUITE(ScatterLayer)
+TEST_SUITE(Scatter)
+DATA_TEST_CASE(Validate, framework::DatasetMode::DISABLED, zip(
+ make("InputInfo", { TensorInfo(TensorShape(9U), 1, DataType::F32), // Mismatching data types
+ TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid
+ TensorInfo(TensorShape(8U), 1, DataType::F32),
+ TensorInfo(TensorShape(217U), 1, DataType::F32), // Mismatch input/output dims.
+ TensorInfo(TensorShape(217U), 1, DataType::F32), // Updates dim higher than Input/Output dims.
+ TensorInfo(TensorShape(12U), 1, DataType::F32), // Indices wrong datatype.
+ }),
+ make("UpdatesInfo",{ TensorInfo(TensorShape(3U), 1, DataType::F16),
+ TensorInfo(TensorShape(15U), 1, DataType::F32),
+ TensorInfo(TensorShape(2U), 1, DataType::F32),
+ TensorInfo(TensorShape(217U), 1, DataType::F32),
+ TensorInfo(TensorShape(217U, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape(2U), 1, DataType::F32),
+ }),
+ make("IndicesInfo",{ TensorInfo(TensorShape(3U), 1, DataType::U32),
+ TensorInfo(TensorShape(15U), 1, DataType::U32),
+ TensorInfo(TensorShape(2U), 1, DataType::U32),
+ TensorInfo(TensorShape(271U), 1, DataType::U32),
+ TensorInfo(TensorShape(271U), 1, DataType::U32),
+ TensorInfo(TensorShape(2U), 1 , DataType::S32)
+ }),
+ make("OutputInfo",{ TensorInfo(TensorShape(9U), 1, DataType::F16),
+ TensorInfo(TensorShape(15U), 1, DataType::F32),
+ TensorInfo(TensorShape(8U), 1, DataType::F32),
+ TensorInfo(TensorShape(271U, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape(271U), 1, DataType::F32),
+ TensorInfo(TensorShape(12U), 1, DataType::F32)
+ }),
+ make("ScatterInfo",{ ScatterInfo(ScatterFunction::Add, false),
+ }),
+ make("Expected", { false, true, true, false, false, false })),
+ input_info, updates_info, indices_info, output_info, scatter_info, expected)
+{
+ // TODO: Enable validation tests.
+ ARM_COMPUTE_UNUSED(input_info);
+ ARM_COMPUTE_UNUSED(updates_info);
+ ARM_COMPUTE_UNUSED(indices_info);
+ ARM_COMPUTE_UNUSED(output_info);
+ ARM_COMPUTE_UNUSED(scatter_info);
+ ARM_COMPUTE_UNUSED(expected);
+}
+
TEST_SUITE(Float)
TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(),
+ make("DataType", {DataType::F32}),
+ make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max}),
+ make("ZeroInit", {false})))
+{
+ // TODO: Add validate() here.
+}
+
+// With this test, src should be passed as nullptr.
+FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(),
+ make("DataType", {DataType::F32}),
+ make("ScatterFunction", {ScatterFunction::Add}),
+ make("ZeroInit", {true})))
+{
+ // TODO: Add validate() here
+}
TEST_SUITE_END() // FP32
TEST_SUITE_END() // Float
-TEST_SUITE_END() // ScatterLayer
+TEST_SUITE_END() // Scatter
TEST_SUITE_END() // CL
} // namespace validation
} // namespace test