diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2018-08-24 11:25:32 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:54:54 +0000 |
commit | c1a72451273ec019e3e74c4b53ea847afe8ddf7c (patch) | |
tree | b4bd62a7ccd22a2c60070d7fa23ceba794dcac5c /tests/datasets | |
parent | 6a8d3b6db13042a859972c33cf40cfeb6d7cfcda (diff) | |
download | ComputeLibrary-c1a72451273ec019e3e74c4b53ea847afe8ddf7c.tar.gz |
COMPMID-1332: Implement Slice for CL
Change-Id: I0dbc4fd7f640d31daa1970eb3da0e941cb771f2b
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/146145
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Diffstat (limited to 'tests/datasets')
-rw-r--r-- | tests/datasets/SliceOperationsDataset.h (renamed from tests/datasets/StridedSliceDataset.h) | 113 |
1 files changed, 109 insertions, 4 deletions
diff --git a/tests/datasets/StridedSliceDataset.h b/tests/datasets/SliceOperationsDataset.h index 00f19920b8..b6df4040fd 100644 --- a/tests/datasets/StridedSliceDataset.h +++ b/tests/datasets/SliceOperationsDataset.h @@ -34,6 +34,77 @@ namespace test { namespace datasets { +class SliceDataset +{ +public: + using type = std::tuple<TensorShape, Coordinates, Coordinates>; + + struct iterator + { + iterator(std::vector<TensorShape>::const_iterator tensor_shapes_it, + std::vector<Coordinates>::const_iterator starts_values_it, + std::vector<Coordinates>::const_iterator ends_values_it) + : _tensor_shapes_it{ std::move(tensor_shapes_it) }, + _starts_values_it{ std::move(starts_values_it) }, + _ends_values_it{ std::move(ends_values_it) } + { + } + + std::string description() const + { + std::stringstream description; + description << "Shape=" << *_tensor_shapes_it << ":"; + description << "Starts=" << *_starts_values_it << ":"; + description << "Ends=" << *_ends_values_it << ":"; + return description.str(); + } + + SliceDataset::type operator*() const + { + return std::make_tuple(*_tensor_shapes_it, *_starts_values_it, *_ends_values_it); + } + + iterator &operator++() + { + ++_tensor_shapes_it; + ++_starts_values_it; + ++_ends_values_it; + return *this; + } + + private: + std::vector<TensorShape>::const_iterator _tensor_shapes_it; + std::vector<Coordinates>::const_iterator _starts_values_it; + std::vector<Coordinates>::const_iterator _ends_values_it; + }; + + iterator begin() const + { + return iterator(_tensor_shapes.begin(), _starts_values.begin(), _ends_values.begin()); + } + + int size() const + { + return std::min(_tensor_shapes.size(), std::min(_starts_values.size(), _ends_values.size())); + } + + void add_config(TensorShape shape, Coordinates starts, Coordinates ends) + { + _tensor_shapes.emplace_back(std::move(shape)); + _starts_values.emplace_back(std::move(starts)); + _ends_values.emplace_back(std::move(ends)); + } + +protected: + SliceDataset() = default; + SliceDataset(SliceDataset &&) = default; + +private: + std::vector<TensorShape> _tensor_shapes{}; + std::vector<Coordinates> _starts_values{}; + std::vector<Coordinates> _ends_values{}; +}; + class StridedSliceDataset { public: @@ -140,6 +211,41 @@ private: std::vector<int32_t> _shrink_mask_values{}; }; +class SmallSliceDataset final : public SliceDataset +{ +public: + SmallSliceDataset() + { + // 1D + add_config(TensorShape(15U), Coordinates(4), Coordinates(9)); + add_config(TensorShape(15U), Coordinates(0), Coordinates(-1)); + // 2D + add_config(TensorShape(15U, 16U), Coordinates(0, 1), Coordinates(5, -1)); + add_config(TensorShape(15U, 16U), Coordinates(4, 1), Coordinates(12, -1)); + // 3D + add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4)); + add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4)); + // 4D + add_config(TensorShape(15U, 16U, 4U, 12U), Coordinates(0, 1, 2, 2), Coordinates(5, -1, 4, 5)); + } +}; + +class LargeSliceDataset final : public SliceDataset +{ +public: + LargeSliceDataset() + { + // 1D + add_config(TensorShape(1025U), Coordinates(128), Coordinates(-100)); + // 2D + add_config(TensorShape(372U, 68U), Coordinates(128, 7), Coordinates(368, -1)); + // 3D + add_config(TensorShape(372U, 68U, 12U), Coordinates(128, 7, 2), Coordinates(368, -1, 4)); + // 4D + add_config(TensorShape(372U, 68U, 7U, 4U), Coordinates(128, 7, 2), Coordinates(368, 17, 5)); + } +}; + class SmallStridedSliceDataset final : public StridedSliceDataset { public: @@ -167,14 +273,13 @@ public: // 1D add_config(TensorShape(1025U), Coordinates(128), Coordinates(-100), BiStrides(20)); // 2D - add_config(TensorShape(372U, 68U), Coordinates(128U, 7U), Coordinates(368U, -30), BiStrides(10, 7)); + add_config(TensorShape(372U, 68U), Coordinates(128, 7), Coordinates(368, -30), BiStrides(10, 7)); // 3D - add_config(TensorShape(372U, 68U, 12U), Coordinates(128U, 7U, -1), Coordinates(368U, -30, -5), BiStrides(14, 7, -2)); + add_config(TensorShape(372U, 68U, 12U), Coordinates(128, 7, -1), Coordinates(368, -30, -5), BiStrides(14, 7, -2)); // 4D - add_config(TensorShape(372U, 68U, 7U, 4U), Coordinates(128U, 7U, 2U), Coordinates(368U, -30, 5U), BiStrides(20, 7, 2), 1, 1); + add_config(TensorShape(372U, 68U, 7U, 4U), Coordinates(128, 7, 2), Coordinates(368, -30, 5), BiStrides(20, 7, 2), 1, 1); } }; - } // namespace datasets } // namespace test } // namespace arm_compute |