aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py68
1 files changed, 68 insertions, 0 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index a850699..c9d35c7 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -73,6 +73,9 @@ class ErrorIf(object):
CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
+ CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
+ CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
+ CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
class TosaErrorIfArgGen:
@@ -2191,6 +2194,47 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def evCondIfCondNotMatchingBool(check=False, **kwargs):
+ error_name = ErrorIf.CondIfCondNotMatchingBool
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Conditional tensor does not match bool type"
+
+ if check:
+ cond = kwargs["cond"]
+ if cond.dtype != DType.BOOL:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
+ error_name = ErrorIf.CondIfCondShapeNotSizeOne
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Conditional tensor is not equal to a size of one"
+
+ if check:
+ cond = kwargs["cond"]
+ # Size of 1 is equivalent to rank 0
+ if len(cond.shape) != 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
def evInputListOutputListMismatch(check=False, **kwargs):
error_name = ErrorIf.InputListOutputListMismatch
param_reqs = {"rank": None, "dtype": None, "shape": None}
@@ -2324,6 +2368,30 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
+ error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Cond graph output is not a shape of size one"
+
+ if check:
+ basicBlocks = kwargs["basicBlocks"]
+ cond_block = basicBlocks[1]
+ cond_outputs = cond_block.outputs
+ cond_tens = cond_block.tensors
+ # Size of 1 is equivalent to rank 0
+ if len(cond_tens[cond_outputs[0]].shape) != 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
class TosaInvalidValidator:
@staticmethod