diff options
Diffstat (limited to 'tests/datasets')
-rw-r--r-- | tests/datasets/SplitDataset.h | 86 |
1 files changed, 85 insertions, 1 deletions
diff --git a/tests/datasets/SplitDataset.h b/tests/datasets/SplitDataset.h index b38252a489..3d4c289ba7 100644 --- a/tests/datasets/SplitDataset.h +++ b/tests/datasets/SplitDataset.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -128,6 +128,90 @@ public: add_config(TensorShape(128U, 64U, 32U, 4U), 3U, 4U); } }; + +class SplitShapesDataset +{ +public: + using type = std::tuple<TensorShape, unsigned int, std::vector<TensorShape>>; + + struct iterator + { + iterator(std::vector<TensorShape>::const_iterator tensor_shapes_it, + std::vector<unsigned int>::const_iterator axis_values_it, + std::vector<std::vector<TensorShape>>::const_iterator split_shapes_values_it) + : _tensor_shapes_it{ std::move(tensor_shapes_it) }, + _axis_values_it{ std::move(axis_values_it) }, + _split_shapes_values_it{ std::move(split_shapes_values_it) } + { + } + + std::string description() const + { + std::stringstream description; + description << "Shape=" << *_tensor_shapes_it << ":"; + description << "Axis=" << *_axis_values_it << ":"; + description << "Split shapes=" << *_split_shapes_values_it << ":"; + return description.str(); + } + + SplitShapesDataset::type operator*() const + { + return std::make_tuple(*_tensor_shapes_it, *_axis_values_it, *_split_shapes_values_it); + } + + iterator &operator++() + { + ++_tensor_shapes_it; + ++_axis_values_it; + ++_split_shapes_values_it; + return *this; + } + + private: + std::vector<TensorShape>::const_iterator _tensor_shapes_it; + std::vector<unsigned int>::const_iterator _axis_values_it; + std::vector<std::vector<TensorShape>>::const_iterator _split_shapes_values_it; + }; + + iterator begin() const + { + return iterator(_tensor_shapes.begin(), _axis_values.begin(), _split_shapes_values.begin()); + } + + int size() const + { + return std::min(_tensor_shapes.size(), std::min(_axis_values.size(), _split_shapes_values.size())); + } + + void add_config(TensorShape shape, unsigned int axis, std::vector<TensorShape> split_shapes) + { + _tensor_shapes.emplace_back(std::move(shape)); + _axis_values.emplace_back(axis); + _split_shapes_values.emplace_back(split_shapes); + } + +protected: + SplitShapesDataset() = default; + SplitShapesDataset(SplitShapesDataset &&) = default; + +private: + std::vector<TensorShape> _tensor_shapes{}; + std::vector<unsigned int> _axis_values{}; + std::vector<std::vector<TensorShape>> _split_shapes_values{}; +}; + +class SmallSplitShapesDataset final : public SplitShapesDataset +{ +public: + SmallSplitShapesDataset() + { + add_config(TensorShape(27U, 3U, 16U, 2U), 2U, std::vector<TensorShape> { TensorShape(27U, 3U, 4U, 2U), + TensorShape(27U, 3U, 4U, 2U), + TensorShape(27U, 3U, 8U, 2U) + }); + } +}; + } // namespace datasets } // namespace test } // namespace arm_compute |