aboutsummaryrefslogtreecommitdiff
path: root/tests/datasets/ShapeDatasets.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/datasets/ShapeDatasets.h')
-rw-r--r--tests/datasets/ShapeDatasets.h32
1 files changed, 32 insertions, 0 deletions
diff --git a/tests/datasets/ShapeDatasets.h b/tests/datasets/ShapeDatasets.h
index 4b563708e1..e939a6f5a7 100644
--- a/tests/datasets/ShapeDatasets.h
+++ b/tests/datasets/ShapeDatasets.h
@@ -238,6 +238,38 @@ public:
}
};
+/** Data set containing medium 3D tensor shapes. */
+class Medium3DShapes final : public ShapeDataset
+{
+public:
+ Medium3DShapes()
+ : ShapeDataset("Shape",
+ {
+ TensorShape{ 42U, 37U, 8U },
+ TensorShape{ 57U, 60U, 13U },
+ TensorShape{ 128U, 64U, 21U },
+ TensorShape{ 83U, 72U, 14U }
+ })
+ {
+ }
+};
+
+/** Data set containing medium 4D tensor shapes. */
+class Medium4DShapes final : public ShapeDataset
+{
+public:
+ Medium4DShapes()
+ : ShapeDataset("Shape",
+ {
+ TensorShape{ 42U, 37U, 8U, 15U },
+ TensorShape{ 57U, 60U, 13U, 8U },
+ TensorShape{ 128U, 64U, 21U, 13U },
+ TensorShape{ 83U, 72U, 14U, 5U }
+ })
+ {
+ }
+};
+
/** Data set containing large tensor shapes. */
class LargeShapes final : public ShapeDataset
{