aboutsummaryrefslogtreecommitdiff
path: root/tests/datasets/SplitDataset.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/datasets/SplitDataset.h')
-rw-r--r--tests/datasets/SplitDataset.h86
1 files changed, 85 insertions, 1 deletions
diff --git a/tests/datasets/SplitDataset.h b/tests/datasets/SplitDataset.h
index b38252a489..3d4c289ba7 100644
--- a/tests/datasets/SplitDataset.h
+++ b/tests/datasets/SplitDataset.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -128,6 +128,90 @@ public:
add_config(TensorShape(128U, 64U, 32U, 4U), 3U, 4U);
}
};
+
+class SplitShapesDataset
+{
+public:
+ using type = std::tuple<TensorShape, unsigned int, std::vector<TensorShape>>;
+
+ struct iterator
+ {
+ iterator(std::vector<TensorShape>::const_iterator tensor_shapes_it,
+ std::vector<unsigned int>::const_iterator axis_values_it,
+ std::vector<std::vector<TensorShape>>::const_iterator split_shapes_values_it)
+ : _tensor_shapes_it{ std::move(tensor_shapes_it) },
+ _axis_values_it{ std::move(axis_values_it) },
+ _split_shapes_values_it{ std::move(split_shapes_values_it) }
+ {
+ }
+
+ std::string description() const
+ {
+ std::stringstream description;
+ description << "Shape=" << *_tensor_shapes_it << ":";
+ description << "Axis=" << *_axis_values_it << ":";
+ description << "Split shapes=" << *_split_shapes_values_it << ":";
+ return description.str();
+ }
+
+ SplitShapesDataset::type operator*() const
+ {
+ return std::make_tuple(*_tensor_shapes_it, *_axis_values_it, *_split_shapes_values_it);
+ }
+
+ iterator &operator++()
+ {
+ ++_tensor_shapes_it;
+ ++_axis_values_it;
+ ++_split_shapes_values_it;
+ return *this;
+ }
+
+ private:
+ std::vector<TensorShape>::const_iterator _tensor_shapes_it;
+ std::vector<unsigned int>::const_iterator _axis_values_it;
+ std::vector<std::vector<TensorShape>>::const_iterator _split_shapes_values_it;
+ };
+
+ iterator begin() const
+ {
+ return iterator(_tensor_shapes.begin(), _axis_values.begin(), _split_shapes_values.begin());
+ }
+
+ int size() const
+ {
+ return std::min(_tensor_shapes.size(), std::min(_axis_values.size(), _split_shapes_values.size()));
+ }
+
+ void add_config(TensorShape shape, unsigned int axis, std::vector<TensorShape> split_shapes)
+ {
+ _tensor_shapes.emplace_back(std::move(shape));
+ _axis_values.emplace_back(axis);
+ _split_shapes_values.emplace_back(split_shapes);
+ }
+
+protected:
+ SplitShapesDataset() = default;
+ SplitShapesDataset(SplitShapesDataset &&) = default;
+
+private:
+ std::vector<TensorShape> _tensor_shapes{};
+ std::vector<unsigned int> _axis_values{};
+ std::vector<std::vector<TensorShape>> _split_shapes_values{};
+};
+
+class SmallSplitShapesDataset final : public SplitShapesDataset
+{
+public:
+ SmallSplitShapesDataset()
+ {
+ add_config(TensorShape(27U, 3U, 16U, 2U), 2U, std::vector<TensorShape> { TensorShape(27U, 3U, 4U, 2U),
+ TensorShape(27U, 3U, 4U, 2U),
+ TensorShape(27U, 3U, 8U, 2U)
+ });
+ }
+};
+
} // namespace datasets
} // namespace test
} // namespace arm_compute