aboutsummaryrefslogtreecommitdiff
path: root/tests/datasets
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-08-24 11:25:32 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commitc1a72451273ec019e3e74c4b53ea847afe8ddf7c (patch)
treeb4bd62a7ccd22a2c60070d7fa23ceba794dcac5c /tests/datasets
parent6a8d3b6db13042a859972c33cf40cfeb6d7cfcda (diff)
downloadComputeLibrary-c1a72451273ec019e3e74c4b53ea847afe8ddf7c.tar.gz
COMPMID-1332: Implement Slice for CL
Change-Id: I0dbc4fd7f640d31daa1970eb3da0e941cb771f2b Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/146145 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Diffstat (limited to 'tests/datasets')
-rw-r--r--tests/datasets/SliceOperationsDataset.h (renamed from tests/datasets/StridedSliceDataset.h)113
1 files changed, 109 insertions, 4 deletions
diff --git a/tests/datasets/StridedSliceDataset.h b/tests/datasets/SliceOperationsDataset.h
index 00f19920b8..b6df4040fd 100644
--- a/tests/datasets/StridedSliceDataset.h
+++ b/tests/datasets/SliceOperationsDataset.h
@@ -34,6 +34,77 @@ namespace test
{
namespace datasets
{
+class SliceDataset
+{
+public:
+ using type = std::tuple<TensorShape, Coordinates, Coordinates>;
+
+ struct iterator
+ {
+ iterator(std::vector<TensorShape>::const_iterator tensor_shapes_it,
+ std::vector<Coordinates>::const_iterator starts_values_it,
+ std::vector<Coordinates>::const_iterator ends_values_it)
+ : _tensor_shapes_it{ std::move(tensor_shapes_it) },
+ _starts_values_it{ std::move(starts_values_it) },
+ _ends_values_it{ std::move(ends_values_it) }
+ {
+ }
+
+ std::string description() const
+ {
+ std::stringstream description;
+ description << "Shape=" << *_tensor_shapes_it << ":";
+ description << "Starts=" << *_starts_values_it << ":";
+ description << "Ends=" << *_ends_values_it << ":";
+ return description.str();
+ }
+
+ SliceDataset::type operator*() const
+ {
+ return std::make_tuple(*_tensor_shapes_it, *_starts_values_it, *_ends_values_it);
+ }
+
+ iterator &operator++()
+ {
+ ++_tensor_shapes_it;
+ ++_starts_values_it;
+ ++_ends_values_it;
+ return *this;
+ }
+
+ private:
+ std::vector<TensorShape>::const_iterator _tensor_shapes_it;
+ std::vector<Coordinates>::const_iterator _starts_values_it;
+ std::vector<Coordinates>::const_iterator _ends_values_it;
+ };
+
+ iterator begin() const
+ {
+ return iterator(_tensor_shapes.begin(), _starts_values.begin(), _ends_values.begin());
+ }
+
+ int size() const
+ {
+ return std::min(_tensor_shapes.size(), std::min(_starts_values.size(), _ends_values.size()));
+ }
+
+ void add_config(TensorShape shape, Coordinates starts, Coordinates ends)
+ {
+ _tensor_shapes.emplace_back(std::move(shape));
+ _starts_values.emplace_back(std::move(starts));
+ _ends_values.emplace_back(std::move(ends));
+ }
+
+protected:
+ SliceDataset() = default;
+ SliceDataset(SliceDataset &&) = default;
+
+private:
+ std::vector<TensorShape> _tensor_shapes{};
+ std::vector<Coordinates> _starts_values{};
+ std::vector<Coordinates> _ends_values{};
+};
+
class StridedSliceDataset
{
public:
@@ -140,6 +211,41 @@ private:
std::vector<int32_t> _shrink_mask_values{};
};
+class SmallSliceDataset final : public SliceDataset
+{
+public:
+ SmallSliceDataset()
+ {
+ // 1D
+ add_config(TensorShape(15U), Coordinates(4), Coordinates(9));
+ add_config(TensorShape(15U), Coordinates(0), Coordinates(-1));
+ // 2D
+ add_config(TensorShape(15U, 16U), Coordinates(0, 1), Coordinates(5, -1));
+ add_config(TensorShape(15U, 16U), Coordinates(4, 1), Coordinates(12, -1));
+ // 3D
+ add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4));
+ add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4));
+ // 4D
+ add_config(TensorShape(15U, 16U, 4U, 12U), Coordinates(0, 1, 2, 2), Coordinates(5, -1, 4, 5));
+ }
+};
+
+class LargeSliceDataset final : public SliceDataset
+{
+public:
+ LargeSliceDataset()
+ {
+ // 1D
+ add_config(TensorShape(1025U), Coordinates(128), Coordinates(-100));
+ // 2D
+ add_config(TensorShape(372U, 68U), Coordinates(128, 7), Coordinates(368, -1));
+ // 3D
+ add_config(TensorShape(372U, 68U, 12U), Coordinates(128, 7, 2), Coordinates(368, -1, 4));
+ // 4D
+ add_config(TensorShape(372U, 68U, 7U, 4U), Coordinates(128, 7, 2), Coordinates(368, 17, 5));
+ }
+};
+
class SmallStridedSliceDataset final : public StridedSliceDataset
{
public:
@@ -167,14 +273,13 @@ public:
// 1D
add_config(TensorShape(1025U), Coordinates(128), Coordinates(-100), BiStrides(20));
// 2D
- add_config(TensorShape(372U, 68U), Coordinates(128U, 7U), Coordinates(368U, -30), BiStrides(10, 7));
+ add_config(TensorShape(372U, 68U), Coordinates(128, 7), Coordinates(368, -30), BiStrides(10, 7));
// 3D
- add_config(TensorShape(372U, 68U, 12U), Coordinates(128U, 7U, -1), Coordinates(368U, -30, -5), BiStrides(14, 7, -2));
+ add_config(TensorShape(372U, 68U, 12U), Coordinates(128, 7, -1), Coordinates(368, -30, -5), BiStrides(14, 7, -2));
// 4D
- add_config(TensorShape(372U, 68U, 7U, 4U), Coordinates(128U, 7U, 2U), Coordinates(368U, -30, 5U), BiStrides(20, 7, 2), 1, 1);
+ add_config(TensorShape(372U, 68U, 7U, 4U), Coordinates(128, 7, 2), Coordinates(368, -30, 5), BiStrides(20, 7, 2), 1, 1);
}
};
-
} // namespace datasets
} // namespace test
} // namespace arm_compute