aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/Pooling3dLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/Pooling3dLayerFixture.h')
-rw-r--r--tests/validation/fixtures/Pooling3dLayerFixture.h37
1 files changed, 25 insertions, 12 deletions
diff --git a/tests/validation/fixtures/Pooling3dLayerFixture.h b/tests/validation/fixtures/Pooling3dLayerFixture.h
index c1b3519e80..563f1dcced 100644
--- a/tests/validation/fixtures/Pooling3dLayerFixture.h
+++ b/tests/validation/fixtures/Pooling3dLayerFixture.h
@@ -46,10 +46,10 @@ class Pooling3dLayerValidationGenericFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, Pooling3dLayerInfo pool_info, DataType data_type)
+ void setup(TensorShape shape, Pooling3dLayerInfo pool_info, DataType data_type, QuantizationInfo input_qinfo = QuantizationInfo(), QuantizationInfo output_qinfo = QuantizationInfo())
{
- _target = compute_target(shape, pool_info, data_type);
- _reference = compute_reference(shape, pool_info, data_type);
+ _target = compute_target(shape, pool_info, data_type, input_qinfo, output_qinfo);
+ _reference = compute_reference(shape, pool_info, data_type, input_qinfo, output_qinfo);
}
protected:
@@ -68,17 +68,17 @@ protected:
}
else // data type is quantized_asymmetric
{
- ARM_COMPUTE_ERROR("Passed Type Not Supported");
+ library->fill_tensor_uniform(tensor, 0);
}
}
TensorType compute_target(TensorShape shape, Pooling3dLayerInfo info,
- DataType data_type)
+ DataType data_type, QuantizationInfo input_qinfo, QuantizationInfo output_qinfo)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(shape, data_type, 1, QuantizationInfo(), DataLayout::NDHWC);
+ TensorType src = create_tensor<TensorType>(shape, data_type, 1, input_qinfo, DataLayout::NDHWC);
const TensorShape dst_shape = misc::shape_calculator::compute_pool3d_shape((src.info()->tensor_shape()), info);
- TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, QuantizationInfo(), DataLayout::NDHWC);
+ TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, output_qinfo, DataLayout::NDHWC);
// Create and configure function
FunctionType pool_layer;
@@ -103,17 +103,17 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(TensorShape shape, Pooling3dLayerInfo info, DataType data_type)
+ SimpleTensor<T> compute_reference(TensorShape shape, Pooling3dLayerInfo info, DataType data_type, QuantizationInfo input_qinfo, QuantizationInfo output_qinfo)
{
// Create reference
- SimpleTensor<T> src(shape, data_type, 1, QuantizationInfo(), DataLayout::NDHWC);
+ SimpleTensor<T> src(shape, data_type, 1, input_qinfo, DataLayout::NDHWC);
// Fill reference
fill(src);
- return reference::pooling_3d_layer<T>(src, info);
+ return reference::pooling_3d_layer<T>(src, info, output_qinfo);
}
- TensorType _target{};
- SimpleTensor<T> _reference{};
+ TensorType _target{};
+ SimpleTensor<T> _reference{};
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
@@ -129,6 +129,19 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class Pooling3dLayerValidationQuantizedFixture : public Pooling3dLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(TensorShape shape, PoolingType pool_type, Size3D pool_size, Size3D stride, Padding3D padding, bool exclude_padding, DataType data_type,
+ QuantizationInfo input_qinfo = QuantizationInfo(), QuantizationInfo output_qinfo = QuantizationInfo())
+ {
+ Pooling3dLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, Pooling3dLayerInfo(pool_type, pool_size, stride, padding, exclude_padding),
+ data_type, input_qinfo, output_qinfo);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class Pooling3dLayerGlobalValidationFixture : public Pooling3dLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
{
public: