aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/PoolingLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/PoolingLayerFixture.h')
-rw-r--r--tests/validation/fixtures/PoolingLayerFixture.h12
1 files changed, 12 insertions, 0 deletions
diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h
index 1813ef4c84..cdc2cae584 100644
--- a/tests/validation/fixtures/PoolingLayerFixture.h
+++ b/tests/validation/fixtures/PoolingLayerFixture.h
@@ -141,6 +141,18 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class PoolingLayerValidationMixedPrecisionFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout, bool fp_mixed_precision = false)
+ {
+ PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, pad_stride_info, exclude_padding, fp_mixed_precision),
+ data_type, data_layout);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class PoolingLayerValidationQuantizedFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
{
public: