diff options
author | Sang-Hoon Park <sang-hoon.park@arm.com> | 2019-09-18 13:39:00 +0100 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-10-01 12:02:45 +0000 |
commit | 2aa7fd011a4baff52dceb00a71b3674f819df8fc (patch) | |
tree | 081a8b0a75ff130d2c6179acf1fe1f1b58943412 /tests/validation/fixtures | |
parent | 5c4a8e96460eb83a6caef1c69ea5cbb4893858d7 (diff) | |
download | ComputeLibrary-2aa7fd011a4baff52dceb00a71b3674f819df8fc.tar.gz |
COMPMID-2601 [CL] add mixed precision support to PoolingLayer
* PoolingLayerInfo is updated with a new flag.
* CL Kernel is updated to use FP32 accumulation.
* CL pooling layer testscases are added for mixed precision.
* Reference pooling layer is updated to use FP32 accumulation.
Change-Id: I4ab2167cc7f86c86293cf50a0ca5119c04dc9c7e
Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1973
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: VidhyaSudhan Loganathan <vidhyasudhan.loganathan@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/PoolingLayerFixture.h | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h index 1813ef4c84..cdc2cae584 100644 --- a/tests/validation/fixtures/PoolingLayerFixture.h +++ b/tests/validation/fixtures/PoolingLayerFixture.h @@ -141,6 +141,18 @@ public: }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class PoolingLayerValidationMixedPrecisionFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + template <typename...> + void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout, bool fp_mixed_precision = false) + { + PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, pad_stride_info, exclude_padding, fp_mixed_precision), + data_type, data_layout); + } +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> class PoolingLayerValidationQuantizedFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: |