aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/StackLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CL/StackLayer.cpp')
-rw-r--r--tests/validation/CL/StackLayer.cpp33
1 files changed, 32 insertions, 1 deletions
diff --git a/tests/validation/CL/StackLayer.cpp b/tests/validation/CL/StackLayer.cpp
index 089911272a..fa2e4acc11 100644
--- a/tests/validation/CL/StackLayer.cpp
+++ b/tests/validation/CL/StackLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -117,6 +117,37 @@ using namespace arm_compute::misc::shape_calculator;
TEST_SUITE(CL)
TEST_SUITE(StackLayer)
+
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
+ framework::dataset::make("InputInfo",
+{
+ std::vector<TensorInfo>{ TensorInfo(TensorShape(9U, 8U), 1, DataType::U8) },
+ std::vector<TensorInfo>{ TensorInfo(TensorShape(1U, 2U), 1, DataType::U8) , TensorInfo(TensorShape(1U, 2U), 1, DataType::U8), TensorInfo(TensorShape(1U, 2U), 1, DataType::U8)},
+ std::vector<TensorInfo>{ TensorInfo(TensorShape(2U, 3U), 1, DataType::S32) },
+ std::vector<TensorInfo>{ TensorInfo(TensorShape(7U, 5U, 3U, 8U, 2U), 1, DataType::S32), TensorInfo(TensorShape(7U, 5U, 3U, 8U, 2U), 1, DataType::S32)},
+ std::vector<TensorInfo>{ TensorInfo(TensorShape(9U, 8U), 1, DataType::S32) },
+}),
+framework::dataset::make("OutputInfo",
+{
+ TensorInfo(TensorShape(1U, 9U, 8U), 1, DataType::U8), // Passes, stack 1 tensor on x axis
+ TensorInfo(TensorShape(1U, 3U, 2U), 1, DataType::U8), // Passes, stack 3 tensors on y axis
+ TensorInfo(TensorShape(1U, 2U, 3U), 1, DataType::S32), // fails axis < (- input's rank)
+ TensorInfo(TensorShape(3U, 7U, 5U), 1, DataType::S32), // fails, input dimensions > 4
+ TensorInfo(TensorShape(1U, 2U, 3U), 1, DataType::U8), // fails mismatching data types
+})),
+framework::dataset::make("Axis", { -3, 1, -4, -3, 1 })),
+framework::dataset::make("Expected", { true, true, false, false, false })),
+input_info, output_info, axis, expected)
+{
+ std::vector<TensorInfo> ti(input_info);
+ std::vector<ITensorInfo *> vec(input_info.size());
+ for(size_t j = 0; j < vec.size(); ++j)
+ {
+ vec[j] = &ti[j];
+ }
+ ARM_COMPUTE_EXPECT(bool(CLStackLayer::validate(vec, axis, &output_info)) == expected, framework::LogLevel::ERRORS);
+}
+
TEST_SUITE(Shapes1D)
DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(shapes_1d_small,