diff options
Diffstat (limited to 'tests/datasets/ShapeDatasets.h')
-rw-r--r-- | tests/datasets/ShapeDatasets.h | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/tests/datasets/ShapeDatasets.h b/tests/datasets/ShapeDatasets.h index e4277a981e..047457c99e 100644 --- a/tests/datasets/ShapeDatasets.h +++ b/tests/datasets/ShapeDatasets.h @@ -212,6 +212,25 @@ public: } }; +/** Data set containing small tensor shapes. */ +class SmallShapesNoBatches final : public ShapeDataset +{ +public: + SmallShapesNoBatches() + : ShapeDataset("Shape", + { + // Batch size 1 + TensorShape{ 3U, 11U }, + TensorShape{ 1U, 16U }, + TensorShape{ 27U, 13U, 7U }, + TensorShape{ 7U, 7U, 17U }, + TensorShape{ 33U, 13U, 2U }, + TensorShape{ 11U, 11U, 3U } + }) + { + } +}; + /** Data set containing pairs of tiny tensor shapes that are broadcast compatible. */ class TinyShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> { @@ -282,6 +301,44 @@ public: } }; +class TemporaryLimitedSmallShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> +{ +public: + TemporaryLimitedSmallShapesBroadcast() + : ZipDataset<ShapeDataset, ShapeDataset>( + ShapeDataset("Shape0", + { + TensorShape{ 9U, 9U, 5U }, + TensorShape{ 27U, 13U, 2U }, + }), + ShapeDataset("Shape1", + { + TensorShape{ 1U, 1U, 1U }, // Broadcast in X, Y, Z + TensorShape{ 27U, 1U, 1U }, // Broadcast in Y and Z + })) + { + } +}; + +class TemporaryLimitedLargeShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> +{ +public: + TemporaryLimitedLargeShapesBroadcast() + : ZipDataset<ShapeDataset, ShapeDataset>( + ShapeDataset("Shape0", + { + TensorShape{ 127U, 25U, 5U }, + TensorShape{ 485, 40U, 10U } + }), + ShapeDataset("Shape1", + { + TensorShape{ 1U, 1U, 1U }, // Broadcast in X, Y, Z + TensorShape{ 485U, 1U, 1U }, // Broadcast in Y, Z + })) + { + } +}; + /** Data set containing medium tensor shapes. */ class MediumShapes final : public ShapeDataset { @@ -359,6 +416,19 @@ public: } }; +/** Data set containing large tensor shapes. */ +class LargeShapesNoBatches final : public ShapeDataset +{ +public: + LargeShapesNoBatches() + : ShapeDataset("Shape", + { + TensorShape{ 582U, 131U, 2U }, + }) + { + } +}; + /** Data set containing pairs of large tensor shapes that are broadcast compatible. */ class LargeShapesBroadcast final : public framework::dataset::ZipDataset<ShapeDataset, ShapeDataset> { |