aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2022-12-12 18:00:41 +0000
committerEric Kunze <eric.kunze@arm.com>2022-12-15 16:43:24 +0000
commit05c711e8941b05f6c9502fe8ef482a077aca508b (patch)
treea8fe7dfbc652a5c5005567dee752f51c1e5949b5 /verif/generator/tosa_error_if.py
parent3d3d45d669a460c6bc8e51b9dd9a8149c51e3d7f (diff)
downloadreference_model-05c711e8941b05f6c9502fe8ef482a077aca508b.tar.gz
Add extra control flow ERROR_IF tests
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I7276dc686d8d18ba44663b73e35ceca2a1cbaadf
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py68
1 files changed, 68 insertions, 0 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index a850699..c9d35c7 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -73,6 +73,9 @@ class ErrorIf(object):
CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
+ CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
+ CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
+ CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
class TosaErrorIfArgGen:
@@ -2191,6 +2194,47 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def evCondIfCondNotMatchingBool(check=False, **kwargs):
+ error_name = ErrorIf.CondIfCondNotMatchingBool
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Conditional tensor does not match bool type"
+
+ if check:
+ cond = kwargs["cond"]
+ if cond.dtype != DType.BOOL:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
+ error_name = ErrorIf.CondIfCondShapeNotSizeOne
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Conditional tensor is not equal to a size of one"
+
+ if check:
+ cond = kwargs["cond"]
+ # Size of 1 is equivalent to rank 0
+ if len(cond.shape) != 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
def evInputListOutputListMismatch(check=False, **kwargs):
error_name = ErrorIf.InputListOutputListMismatch
param_reqs = {"rank": None, "dtype": None, "shape": None}
@@ -2324,6 +2368,30 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
+ error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Cond graph output is not a shape of size one"
+
+ if check:
+ basicBlocks = kwargs["basicBlocks"]
+ cond_block = basicBlocks[1]
+ cond_outputs = cond_block.outputs
+ cond_tens = cond_block.tensors
+ # Size of 1 is equivalent to rank 0
+ if len(cond_tens[cond_outputs[0]].shape) != 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
class TosaInvalidValidator:
@staticmethod