diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2022-12-12 18:00:41 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2022-12-15 16:43:24 +0000 |
commit | 05c711e8941b05f6c9502fe8ef482a077aca508b (patch) | |
tree | a8fe7dfbc652a5c5005567dee752f51c1e5949b5 /verif/generator/tosa_error_if.py | |
parent | 3d3d45d669a460c6bc8e51b9dd9a8149c51e3d7f (diff) | |
download | reference_model-05c711e8941b05f6c9502fe8ef482a077aca508b.tar.gz |
Add extra control flow ERROR_IF tests
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I7276dc686d8d18ba44663b73e35ceca2a1cbaadf
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r-- | verif/generator/tosa_error_if.py | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index a850699..c9d35c7 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -73,6 +73,9 @@ class ErrorIf(object): CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool" U16InputZeroPointNotValid = "U16InputZeroPointNotValid" U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid" + CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool" + CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne" + CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne" class TosaErrorIfArgGen: @@ -2191,6 +2194,47 @@ class TosaErrorValidator: return info_dict @staticmethod + def evCondIfCondNotMatchingBool(check=False, **kwargs): + error_name = ErrorIf.CondIfCondNotMatchingBool + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Conditional tensor does not match bool type" + + if check: + cond = kwargs["cond"] + if cond.dtype != DType.BOOL: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + + @staticmethod + def evCondIfCondShapeNotSizeOne(check=False, **kwargs): + error_name = ErrorIf.CondIfCondShapeNotSizeOne + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Conditional tensor is not equal to a size of one" + + if check: + cond = kwargs["cond"] + # Size of 1 is equivalent to rank 0 + if len(cond.shape) != 0: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + + @staticmethod def evInputListOutputListMismatch(check=False, **kwargs): error_name = ErrorIf.InputListOutputListMismatch param_reqs = {"rank": None, "dtype": None, "shape": None} @@ -2324,6 +2368,30 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs): + error_name = ErrorIf.CondGraphOutputShapeNotSizeOne + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Cond graph output is not a shape of size one" + + if check: + basicBlocks = kwargs["basicBlocks"] + cond_block = basicBlocks[1] + cond_outputs = cond_block.outputs + cond_tens = cond_block.tensors + # Size of 1 is equivalent to rank 0 + if len(cond_tens[cond_outputs[0]].shape) != 0: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + class TosaInvalidValidator: @staticmethod |