aboutsummaryrefslogtreecommitdiff
path: root/tests/datasets_new/ShapeDatasets.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/datasets_new/ShapeDatasets.h')
-rw-r--r--tests/datasets_new/ShapeDatasets.h21
1 files changed, 18 insertions, 3 deletions
diff --git a/tests/datasets_new/ShapeDatasets.h b/tests/datasets_new/ShapeDatasets.h
index ba142cae0c..14f7851621 100644
--- a/tests/datasets_new/ShapeDatasets.h
+++ b/tests/datasets_new/ShapeDatasets.h
@@ -35,7 +35,7 @@ namespace test
{
namespace datasets
{
-/** Data set containing one 1D tensor shape. */
+/** Data set containing 1D tensor shapes. */
class Small1DShape final : public framework::dataset::SingletonDataset<TensorShape>
{
public:
@@ -48,7 +48,7 @@ public:
/** Parent type for all for shape datasets. */
using ShapeDataset = framework::dataset::ContainerDataset<std::vector<TensorShape>>;
-/** Data set containing two small 2D tensor shapes. */
+/** Data set containing small 2D tensor shapes. */
class Small2DShapes final : public ShapeDataset
{
public:
@@ -93,7 +93,7 @@ public:
}
};
-/** Data set containing two 2D large tensor shapes. */
+/** Data set containing large 2D tensor shapes. */
class Large2DShapes final : public ShapeDataset
{
public:
@@ -107,6 +107,21 @@ public:
{
}
};
+
+/** Data set containing small tensor shapes for direct convolution. */
+class SmallDirectConvolutionShapes final : public ShapeDataset
+{
+public:
+ SmallDirectConvolutionShapes()
+ : ShapeDataset("InputShape",
+ {
+ TensorShape{ 3U, 3U, 3U, 2U, 4U, 5U },
+ TensorShape{ 32U, 37U, 3U },
+ TensorShape{ 13U, 15U, 8U, 3U }
+ })
+ {
+ }
+};
} // namespace datasets
} // namespace test
} // namespace arm_compute