aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/PoolingLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/PoolingLayerFixture.h')
-rw-r--r--tests/validation/fixtures/PoolingLayerFixture.h11
1 files changed, 11 insertions, 0 deletions
diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h
index 5ce4aa6755..775c4125fc 100644
--- a/tests/validation/fixtures/PoolingLayerFixture.h
+++ b/tests/validation/fixtures/PoolingLayerFixture.h
@@ -128,6 +128,17 @@ public:
PoolingLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, pool_type, pool_size, pad_stride_info, data_type, 0);
}
};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class GlobalPoolingLayerValidationFixture : public PoolingLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(TensorShape shape, PoolingType pool_type, DataType data_type)
+ {
+ PoolingLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, pool_type, shape.x(), PadStrideInfo(1, 1, 0, 0), data_type, 0);
+ }
+};
} // namespace validation
} // namespace test
} // namespace arm_compute