aboutsummaryrefslogtreecommitdiff
path: root/tests/datasets/ShapeDatasets.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/datasets/ShapeDatasets.h')
-rw-r--r--tests/datasets/ShapeDatasets.h32
1 files changed, 25 insertions, 7 deletions
diff --git a/tests/datasets/ShapeDatasets.h b/tests/datasets/ShapeDatasets.h
index 3dc4566e18..173ee74958 100644
--- a/tests/datasets/ShapeDatasets.h
+++ b/tests/datasets/ShapeDatasets.h
@@ -35,19 +35,22 @@ namespace test
{
namespace datasets
{
-/** Data set containing 1D tensor shapes. */
-class Small1DShape final : public framework::dataset::SingletonDataset<TensorShape>
+/** Parent type for all for shape datasets. */
+using ShapeDataset = framework::dataset::ContainerDataset<std::vector<TensorShape>>;
+
+/** Data set containing small 1D tensor shapes. */
+class Small1DShapes final : public ShapeDataset
{
public:
- Small1DShape()
- : SingletonDataset("Shape", TensorShape{ 256U })
+ Small1DShapes()
+ : ShapeDataset("Shape",
+ {
+ TensorShape{ 256U }
+ })
{
}
};
-/** Parent type for all for shape datasets. */
-using ShapeDataset = framework::dataset::ContainerDataset<std::vector<TensorShape>>;
-
/** Data set containing small 2D tensor shapes. */
class Small2DShapes final : public ShapeDataset
{
@@ -169,6 +172,21 @@ public:
}
};
+/** Data set containing large 1D tensor shapes. */
+class Large1DShapes final : public ShapeDataset
+{
+public:
+ Large1DShapes()
+ : ShapeDataset("Shape",
+ {
+ TensorShape{ 1921U },
+ TensorShape{ 1245U },
+ TensorShape{ 4160U }
+ })
+ {
+ }
+};
+
/** Data set containing large 2D tensor shapes. */
class Large2DShapes final : public ShapeDataset
{