diff options
author | Kurtis Charnock <kurtis.charnock@arm.com> | 2019-11-29 11:42:30 +0000 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2020-01-13 17:43:02 +0000 |
commit | ec00da149b7884114cd6c43cc9c1cff62ddaa710 (patch) | |
tree | ef02ea4d3c5b45178606f40cbea6af6b33005d6f /tests/datasets/SplitDataset.h | |
parent | 6f314db14f0fd242d53f3a9f780158169259b31b (diff) | |
download | ComputeLibrary-ec00da149b7884114cd6c43cc9c1cff62ddaa710.tar.gz |
COMPMID-2728: Add support for split sizes in CLSplit
Signed-off-by: Kurtis Charnock <kurtis.charnock@arm.com>
Change-Id: I69ea9e812478904c3e10379bb5943d534c45f942
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/214132
Tested-by: bsgcomp <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2432
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Diffstat (limited to 'tests/datasets/SplitDataset.h')
-rw-r--r-- | tests/datasets/SplitDataset.h | 86 |
1 files changed, 85 insertions, 1 deletions
diff --git a/tests/datasets/SplitDataset.h b/tests/datasets/SplitDataset.h index b38252a489..3d4c289ba7 100644 --- a/tests/datasets/SplitDataset.h +++ b/tests/datasets/SplitDataset.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -128,6 +128,90 @@ public: add_config(TensorShape(128U, 64U, 32U, 4U), 3U, 4U); } }; + +class SplitShapesDataset +{ +public: + using type = std::tuple<TensorShape, unsigned int, std::vector<TensorShape>>; + + struct iterator + { + iterator(std::vector<TensorShape>::const_iterator tensor_shapes_it, + std::vector<unsigned int>::const_iterator axis_values_it, + std::vector<std::vector<TensorShape>>::const_iterator split_shapes_values_it) + : _tensor_shapes_it{ std::move(tensor_shapes_it) }, + _axis_values_it{ std::move(axis_values_it) }, + _split_shapes_values_it{ std::move(split_shapes_values_it) } + { + } + + std::string description() const + { + std::stringstream description; + description << "Shape=" << *_tensor_shapes_it << ":"; + description << "Axis=" << *_axis_values_it << ":"; + description << "Split shapes=" << *_split_shapes_values_it << ":"; + return description.str(); + } + + SplitShapesDataset::type operator*() const + { + return std::make_tuple(*_tensor_shapes_it, *_axis_values_it, *_split_shapes_values_it); + } + + iterator &operator++() + { + ++_tensor_shapes_it; + ++_axis_values_it; + ++_split_shapes_values_it; + return *this; + } + + private: + std::vector<TensorShape>::const_iterator _tensor_shapes_it; + std::vector<unsigned int>::const_iterator _axis_values_it; + std::vector<std::vector<TensorShape>>::const_iterator _split_shapes_values_it; + }; + + iterator begin() const + { + return iterator(_tensor_shapes.begin(), _axis_values.begin(), _split_shapes_values.begin()); + } + + int size() const + { + return std::min(_tensor_shapes.size(), std::min(_axis_values.size(), _split_shapes_values.size())); + } + + void add_config(TensorShape shape, unsigned int axis, std::vector<TensorShape> split_shapes) + { + _tensor_shapes.emplace_back(std::move(shape)); + _axis_values.emplace_back(axis); + _split_shapes_values.emplace_back(split_shapes); + } + +protected: + SplitShapesDataset() = default; + SplitShapesDataset(SplitShapesDataset &&) = default; + +private: + std::vector<TensorShape> _tensor_shapes{}; + std::vector<unsigned int> _axis_values{}; + std::vector<std::vector<TensorShape>> _split_shapes_values{}; +}; + +class SmallSplitShapesDataset final : public SplitShapesDataset +{ +public: + SmallSplitShapesDataset() + { + add_config(TensorShape(27U, 3U, 16U, 2U), 2U, std::vector<TensorShape> { TensorShape(27U, 3U, 4U, 2U), + TensorShape(27U, 3U, 4U, 2U), + TensorShape(27U, 3U, 8U, 2U) + }); + } +}; + } // namespace datasets } // namespace test } // namespace arm_compute |