From 01e1c1c7f965ceb07e78a3b1ad063161c0f47b94 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Wed, 7 Feb 2024 16:09:09 +0000 Subject: Improve Avg_Pool2D ErrorIf Testing * Add test for invalid accumulator dtype Signed-off-by: Jeremy Johnson Change-Id: I506e2047623372670b82db6e9c0010fa89802851 --- verif/generator/tosa_arg_gen.py | 3 +++ verif/generator/tosa_error_if.py | 44 ++++++++++++++++++++++++++++++++++++++++ verif/generator/tosa_test_gen.py | 2 ++ 3 files changed, 49 insertions(+) diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 4630f35..33e74b5 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -2485,6 +2485,9 @@ class TosaArgGen: # incorrect input data-type accum_dtypes = [DType.INT32] + if error_name == ErrorIf.WrongAccumulatorType: + accum_dtypes = list(gtu.usableDTypes(excludes=accum_dtypes)) + if not test_level8k: if testGen.args.oversize: # add some oversize argument values diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index 5fd647a..9a88acb 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -84,6 +84,7 @@ class ErrorIf(object): ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference" ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger" BroadcastShapesMismatch = "BroadcastShapesMismatch" + WrongAccumulatorType = "WrongAccumulatorType" class TosaErrorIfArgGen: @@ -2580,6 +2581,49 @@ class TosaErrorValidator: } return info_dict + def evWrongAccumulatorType(check=False, **kwargs): + error_name = ErrorIf.WrongAccumulatorType + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "An unsupported accumulator data type was requested" + + if check: + op = kwargs["op"] + input_dtype = kwargs["input_dtype"] + accum_dtype = kwargs["accum_dtype"] + if op["op"] == Op.AVG_POOL2D: + if ( + input_dtype + in ( + DType.INT8, + DType.INT16, + ) + and accum_dtype != DType.INT32 + ): + error_result = True + elif ( + input_dtype + in ( + DType.FP32, + DType.BF16, + ) + and accum_dtype != DType.FP32 + ): + error_result = True + elif input_dtype == DType.FP16 and accum_dtype not in ( + DType.FP16, + DType.FP32, + ): + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + class TosaInvalidValidator: @staticmethod diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index f5eca18..2d471c0 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -835,6 +835,7 @@ class TosaTestGen: input_dtype=input.dtype, output_shape=result_tensor.shape, output_dtype=result_tensor.dtype, + accum_dtype=accum_dtype, kernel=kernel, stride=stride, pad=pad, @@ -3218,6 +3219,7 @@ class TosaTestGen: TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch, TosaErrorValidator.evPoolingOutputShapeNonInteger, + TosaErrorValidator.evWrongAccumulatorType, ), "data_gen": { "fp": (gtu.DataGenType.DOT_PRODUCT,), -- cgit v1.2.1