diff options
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r-- | verif/generator/tosa_error_if.py | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index 5fd647a..9a88acb 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -84,6 +84,7 @@ class ErrorIf(object): ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference" ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger" BroadcastShapesMismatch = "BroadcastShapesMismatch" + WrongAccumulatorType = "WrongAccumulatorType" class TosaErrorIfArgGen: @@ -2580,6 +2581,49 @@ class TosaErrorValidator: } return info_dict + def evWrongAccumulatorType(check=False, **kwargs): + error_name = ErrorIf.WrongAccumulatorType + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "An unsupported accumulator data type was requested" + + if check: + op = kwargs["op"] + input_dtype = kwargs["input_dtype"] + accum_dtype = kwargs["accum_dtype"] + if op["op"] == Op.AVG_POOL2D: + if ( + input_dtype + in ( + DType.INT8, + DType.INT16, + ) + and accum_dtype != DType.INT32 + ): + error_result = True + elif ( + input_dtype + in ( + DType.FP32, + DType.BF16, + ) + and accum_dtype != DType.FP32 + ): + error_result = True + elif input_dtype == DType.FP16 and accum_dtype not in ( + DType.FP16, + DType.FP32, + ): + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + class TosaInvalidValidator: @staticmethod |