diff options
Diffstat (limited to 'tests/datasets_new/ShapeDatasets.h')
-rw-r--r-- | tests/datasets_new/ShapeDatasets.h | 21 |
1 files changed, 18 insertions, 3 deletions
diff --git a/tests/datasets_new/ShapeDatasets.h b/tests/datasets_new/ShapeDatasets.h index ba142cae0c..14f7851621 100644 --- a/tests/datasets_new/ShapeDatasets.h +++ b/tests/datasets_new/ShapeDatasets.h @@ -35,7 +35,7 @@ namespace test { namespace datasets { -/** Data set containing one 1D tensor shape. */ +/** Data set containing 1D tensor shapes. */ class Small1DShape final : public framework::dataset::SingletonDataset<TensorShape> { public: @@ -48,7 +48,7 @@ public: /** Parent type for all for shape datasets. */ using ShapeDataset = framework::dataset::ContainerDataset<std::vector<TensorShape>>; -/** Data set containing two small 2D tensor shapes. */ +/** Data set containing small 2D tensor shapes. */ class Small2DShapes final : public ShapeDataset { public: @@ -93,7 +93,7 @@ public: } }; -/** Data set containing two 2D large tensor shapes. */ +/** Data set containing large 2D tensor shapes. */ class Large2DShapes final : public ShapeDataset { public: @@ -107,6 +107,21 @@ public: { } }; + +/** Data set containing small tensor shapes for direct convolution. */ +class SmallDirectConvolutionShapes final : public ShapeDataset +{ +public: + SmallDirectConvolutionShapes() + : ShapeDataset("InputShape", + { + TensorShape{ 3U, 3U, 3U, 2U, 4U, 5U }, + TensorShape{ 32U, 37U, 3U }, + TensorShape{ 13U, 15U, 8U, 3U } + }) + { + } +}; } // namespace datasets } // namespace test } // namespace arm_compute |