aboutsummaryrefslogtreecommitdiff
path: root/tests/datasets/ShapeDatasets.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/datasets/ShapeDatasets.h')
-rw-r--r--tests/datasets/ShapeDatasets.h81
1 files changed, 81 insertions, 0 deletions
diff --git a/tests/datasets/ShapeDatasets.h b/tests/datasets/ShapeDatasets.h
index 79e052c697..dbcd9d5000 100644
--- a/tests/datasets/ShapeDatasets.h
+++ b/tests/datasets/ShapeDatasets.h
@@ -51,6 +51,19 @@ public:
}
};
+/** Data set containing tiny 2D tensor shapes. */
+class Tiny2DShapes final : public ShapeDataset
+{
+public:
+ Tiny2DShapes()
+ : ShapeDataset("Shape",
+ {
+ TensorShape{ 7U, 7U },
+ TensorShape{ 11U, 13U },
+ })
+ {
+ }
+};
/** Data set containing small 2D tensor shapes. */
class Small2DShapes final : public ShapeDataset
{
@@ -66,6 +79,20 @@ public:
}
};
+/** Data set containing tiny 3D tensor shapes. */
+class Tiny3DShapes final : public ShapeDataset
+{
+public:
+ Tiny3DShapes()
+ : ShapeDataset("Shape",
+ {
+ TensorShape{ 7U, 7U, 5U },
+ TensorShape{ 23U, 13U, 9U },
+ })
+ {
+ }
+};
+
/** Data set containing small 3D tensor shapes. */
class Small3DShapes final : public ShapeDataset
{
@@ -81,6 +108,19 @@ public:
}
};
+/** Data set containing tiny 4D tensor shapes. */
+class Tiny4DShapes final : public ShapeDataset
+{
+public:
+ Tiny4DShapes()
+ : ShapeDataset("Shape",
+ {
+ TensorShape{ 7U, 7U, 5U, 3U },
+ TensorShape{ 17U, 13U, 7U, 2U },
+ })
+ {
+ }
+};
/** Data set containing small 4D tensor shapes. */
class Small4DShapes final : public ShapeDataset
{
@@ -97,6 +137,20 @@ public:
};
/** Data set containing small tensor shapes. */
+class TinyShapes final : public ShapeDataset
+{
+public:
+ TinyShapes()
+ : ShapeDataset("Shape",
+ {
+ // Batch size 1
+ TensorShape{ 9U, 9U },
+ TensorShape{ 27U, 13U, 2U },
+ })
+ {
+ }
+};
+/** Data set containing small tensor shapes. */
class SmallShapes final : public ShapeDataset
{
public:
@@ -299,6 +353,20 @@ public:
}
};
+/** Data set containing tiny tensor shapes for direct convolution. */
+class TinyDirectConvolutionShapes final : public ShapeDataset
+{
+public:
+ TinyDirectConvolutionShapes()
+ : ShapeDataset("InputShape",
+ {
+ // Batch size 1
+ TensorShape{ 11U, 13U, 3U },
+ TensorShape{ 7U, 27U, 3U }
+ })
+ {
+ }
+};
/** Data set containing small tensor shapes for direct convolution. */
class SmallDirectConvolutionShapes final : public ShapeDataset
{
@@ -352,6 +420,19 @@ public:
{
}
};
+/** Data set containing tiny softmax layer shapes. */
+class SoftmaxLayerTinyShapes final : public ShapeDataset
+{
+public:
+ SoftmaxLayerTinyShapes()
+ : ShapeDataset("Shape",
+ {
+ TensorShape{ 9U, 9U },
+ TensorShape{ 128U, 10U, 2U },
+ })
+ {
+ }
+};
/** Data set containing small softmax layer shapes. */
class SoftmaxLayerSmallShapes final : public ShapeDataset