diff options
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r-- | verif/generator/tosa_error_if.py | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index a850699..c9d35c7 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -73,6 +73,9 @@ class ErrorIf(object): CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool" U16InputZeroPointNotValid = "U16InputZeroPointNotValid" U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid" + CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool" + CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne" + CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne" class TosaErrorIfArgGen: @@ -2191,6 +2194,47 @@ class TosaErrorValidator: return info_dict @staticmethod + def evCondIfCondNotMatchingBool(check=False, **kwargs): + error_name = ErrorIf.CondIfCondNotMatchingBool + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Conditional tensor does not match bool type" + + if check: + cond = kwargs["cond"] + if cond.dtype != DType.BOOL: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + + @staticmethod + def evCondIfCondShapeNotSizeOne(check=False, **kwargs): + error_name = ErrorIf.CondIfCondShapeNotSizeOne + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Conditional tensor is not equal to a size of one" + + if check: + cond = kwargs["cond"] + # Size of 1 is equivalent to rank 0 + if len(cond.shape) != 0: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + + @staticmethod def evInputListOutputListMismatch(check=False, **kwargs): error_name = ErrorIf.InputListOutputListMismatch param_reqs = {"rank": None, "dtype": None, "shape": None} @@ -2324,6 +2368,30 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs): + error_name = ErrorIf.CondGraphOutputShapeNotSizeOne + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Cond graph output is not a shape of size one" + + if check: + basicBlocks = kwargs["basicBlocks"] + cond_block = basicBlocks[1] + cond_outputs = cond_block.outputs + cond_tens = cond_block.tensors + # Size of 1 is equivalent to rank 0 + if len(cond_tens[cond_outputs[0]].shape) != 0: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + class TosaInvalidValidator: @staticmethod |