aboutsummaryrefslogtreecommitdiff
path: root/tests/dataset/PoolingLayerDataset.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dataset/PoolingLayerDataset.h')
-rw-r--r--tests/dataset/PoolingLayerDataset.h4
1 files changed, 3 insertions, 1 deletions
diff --git a/tests/dataset/PoolingLayerDataset.h b/tests/dataset/PoolingLayerDataset.h
index 5cdece4f66..1496cad379 100644
--- a/tests/dataset/PoolingLayerDataset.h
+++ b/tests/dataset/PoolingLayerDataset.h
@@ -134,7 +134,7 @@ public:
~GoogLeNetPoolingLayerDataset() = default;
};
-class RandomPoolingLayerDataset final : public PoolingLayerDataset<8>
+class RandomPoolingLayerDataset final : public PoolingLayerDataset<10>
{
public:
RandomPoolingLayerDataset()
@@ -148,6 +148,8 @@ public:
PoolingLayerDataObject{ TensorShape(13U, 13U, 32U), TensorShape(6U, 6U, 32U), PoolingLayerInfo(PoolingType::AVG, 3, PadStrideInfo(2, 2, 0, 0)) },
PoolingLayerDataObject{ TensorShape(24U, 24U, 10U), TensorShape(12U, 12U, 10U), PoolingLayerInfo(PoolingType::AVG, 2, PadStrideInfo(2, 2, 0, 0)) },
PoolingLayerDataObject{ TensorShape(8U, 8U, 30U), TensorShape(4U, 4U, 30U), PoolingLayerInfo(PoolingType::AVG, 2, PadStrideInfo(2, 2, 0, 0)) },
+ PoolingLayerDataObject{ TensorShape(7U, 7U, 10U), TensorShape(7U, 7U, 10U), PoolingLayerInfo(PoolingType::AVG, 3, PadStrideInfo(1, 1, 1, 1)) },
+ PoolingLayerDataObject{ TensorShape(7U, 7U, 10U), TensorShape(7U, 7U, 10U), PoolingLayerInfo(PoolingType::MAX, 3, PadStrideInfo(1, 1, 1, 1)) },
}
{
}