aboutsummaryrefslogtreecommitdiff
path: root/tests/datasets/ShapeDatasets.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/datasets/ShapeDatasets.h')
-rw-r--r--tests/datasets/ShapeDatasets.h70
1 files changed, 70 insertions, 0 deletions
diff --git a/tests/datasets/ShapeDatasets.h b/tests/datasets/ShapeDatasets.h
index e4277a981e..047457c99e 100644
--- a/tests/datasets/ShapeDatasets.h
+++ b/tests/datasets/ShapeDatasets.h
@@ -212,6 +212,25 @@ public:
}
};
+/** Data set containing small tensor shapes. */
+class SmallShapesNoBatches final : public ShapeDataset
+{
+public:
+ SmallShapesNoBatches()
+ : ShapeDataset("Shape",
+ {
+ // Batch size 1
+ TensorShape{ 3U, 11U },
+ TensorShape{ 1U, 16U },
+ TensorShape{ 27U, 13U, 7U },
+ TensorShape{ 7U, 7U, 17U },
+ TensorShape{ 33U, 13U, 2U },
+ TensorShape{ 11U, 11U, 3U }
+ })
+ {
+ }
+};
+
/** Data set containing pairs of tiny tensor shapes that are broadcast compatible. */
class TinyShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
{
@@ -282,6 +301,44 @@ public:
}
};
+class TemporaryLimitedSmallShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
+{
+public:
+ TemporaryLimitedSmallShapesBroadcast()
+ : ZipDataset<ShapeDataset, ShapeDataset>(
+ ShapeDataset("Shape0",
+ {
+ TensorShape{ 9U, 9U, 5U },
+ TensorShape{ 27U, 13U, 2U },
+ }),
+ ShapeDataset("Shape1",
+ {
+ TensorShape{ 1U, 1U, 1U }, // Broadcast in X, Y, Z
+ TensorShape{ 27U, 1U, 1U }, // Broadcast in Y and Z
+ }))
+ {
+ }
+};
+
+class TemporaryLimitedLargeShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
+{
+public:
+ TemporaryLimitedLargeShapesBroadcast()
+ : ZipDataset<ShapeDataset, ShapeDataset>(
+ ShapeDataset("Shape0",
+ {
+ TensorShape{ 127U, 25U, 5U },
+ TensorShape{ 485, 40U, 10U }
+ }),
+ ShapeDataset("Shape1",
+ {
+ TensorShape{ 1U, 1U, 1U }, // Broadcast in X, Y, Z
+ TensorShape{ 485U, 1U, 1U }, // Broadcast in Y, Z
+ }))
+ {
+ }
+};
+
/** Data set containing medium tensor shapes. */
class MediumShapes final : public ShapeDataset
{
@@ -359,6 +416,19 @@ public:
}
};
+/** Data set containing large tensor shapes. */
+class LargeShapesNoBatches final : public ShapeDataset
+{
+public:
+ LargeShapesNoBatches()
+ : ShapeDataset("Shape",
+ {
+ TensorShape{ 582U, 131U, 2U },
+ })
+ {
+ }
+};
+
/** Data set containing pairs of large tensor shapes that are broadcast compatible. */
class LargeShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset>
{