aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2022-12-12 18:00:41 +0000
committerEric Kunze <eric.kunze@arm.com>2022-12-15 16:43:24 +0000
commit05c711e8941b05f6c9502fe8ef482a077aca508b (patch)
treea8fe7dfbc652a5c5005567dee752f51c1e5949b5
parent3d3d45d669a460c6bc8e51b9dd9a8149c51e3d7f (diff)
downloadreference_model-05c711e8941b05f6c9502fe8ef482a077aca508b.tar.gz
Add extra control flow ERROR_IF tests
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I7276dc686d8d18ba44663b73e35ceca2a1cbaadf
-rw-r--r--verif/generator/tosa_error_if.py68
-rw-r--r--verif/generator/tosa_test_gen.py46
-rw-r--r--verif/generator/tosa_utils.py3
3 files changed, 110 insertions, 7 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index a850699..c9d35c7 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -73,6 +73,9 @@ class ErrorIf(object):
CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
+ CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
+ CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
+ CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
class TosaErrorIfArgGen:
@@ -2191,6 +2194,47 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def evCondIfCondNotMatchingBool(check=False, **kwargs):
+ error_name = ErrorIf.CondIfCondNotMatchingBool
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Conditional tensor does not match bool type"
+
+ if check:
+ cond = kwargs["cond"]
+ if cond.dtype != DType.BOOL:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
+ error_name = ErrorIf.CondIfCondShapeNotSizeOne
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Conditional tensor is not equal to a size of one"
+
+ if check:
+ cond = kwargs["cond"]
+ # Size of 1 is equivalent to rank 0
+ if len(cond.shape) != 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
def evInputListOutputListMismatch(check=False, **kwargs):
error_name = ErrorIf.InputListOutputListMismatch
param_reqs = {"rank": None, "dtype": None, "shape": None}
@@ -2324,6 +2368,30 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
+ error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Cond graph output is not a shape of size one"
+
+ if check:
+ basicBlocks = kwargs["basicBlocks"]
+ cond_block = basicBlocks[1]
+ cond_outputs = cond_block.outputs
+ cond_tens = cond_block.tensors
+ # Size of 1 is equivalent to rank 0
+ if len(cond_tens[cond_outputs[0]].shape) != 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
class TosaInvalidValidator:
@staticmethod
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index f3ca512..515e8bb 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -14,6 +14,7 @@ from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_error_if import TosaErrorValidator
from generator.tosa_error_if import TosaInvalidValidator
from generator.tosa_utils import DTYPE_ATTRIBUTES
+from generator.tosa_utils import get_wrong_output_type
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import usableDTypes
from generator.tosa_utils import vect_f32_to_bf16
@@ -1785,15 +1786,32 @@ class TosaTestGen:
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
+ def _get_condition_tensor(self, op, cond, error_name):
+ if error_name == ErrorIf.CondIfCondNotMatchingBool:
+ cond_type = get_wrong_output_type(op, self.rng, DType.BOOL)
+ else:
+ cond_type = DType.BOOL
+ if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
+ choice = self.rng.choice([1, 2])
+ if choice == 1:
+ cond_shape = [2]
+ else:
+ cond_shape = [1, 2]
+ else:
+ # Must be of size 1 (rank 0)
+ cond_shape = []
+ cond_tens = self.ser.addConst(cond_shape, cond_type, [cond])
+ return cond_tens
+
def build_cond_if_const(
self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
# For cond_if with constants, we're supplied with then/else tensors that we ignore
- # (except for the generated shap) and the condition. Build Then/Else blocks
+ # (except for the generated shape) and the condition. Build Then/Else blocks
# and fill them with const nodes for the body.
# Condition tensor
- cond_tens = self.ser.addConst([], DType.BOOL, [cond])
+ cond_tens = self._get_condition_tensor(op, cond, error_name)
# Make then/else tensors
out_shape = then_tens.shape
@@ -1848,6 +1866,7 @@ class TosaTestGen:
error_name,
op=op,
basicBlocks=self.ser.basicBlocks,
+ cond=cond_tens,
):
return None
@@ -1860,7 +1879,7 @@ class TosaTestGen:
# alternately add or subtract them based on the condition
# Condition tensor
- cond_tens = self.ser.addConst([], DType.BOOL, [cond])
+ cond_tens = self._get_condition_tensor(op, cond, error_name)
result_tens = self.ser.addOutput(a.shape, a.dtype)
@@ -1930,6 +1949,7 @@ class TosaTestGen:
a=a,
b=b,
basicBlocks=self.ser.basicBlocks,
+ cond=cond_tens,
):
return None
@@ -1997,11 +2017,18 @@ class TosaTestGen:
zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
- cond_tens = self.ser.addOutput(
- [], self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
- )
+ cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
+ else:
+ cond_type = DType.BOOL
+ if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
+ choice = self.rng.choice([1, 2])
+ if choice == 1:
+ cond_shape = [3]
+ else:
+ cond_shape = [1, 2]
else:
- cond_tens = self.ser.addOutput([], DType.BOOL)
+ cond_shape = []
+ cond_tens = self.ser.addOutput(cond_shape, cond_type)
self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
@@ -3818,6 +3845,8 @@ class TosaTestGen:
"error_if_validators": (
TosaErrorValidator.evOutputListThenGraphMismatch,
TosaErrorValidator.evOutputListElseGraphMismatch,
+ TosaErrorValidator.evCondIfCondNotMatchingBool,
+ TosaErrorValidator.evCondIfCondShapeNotSizeOne,
),
},
"cond_if_binary": {
@@ -3835,6 +3864,8 @@ class TosaTestGen:
TosaErrorValidator.evInputListElseGraphMismatch,
TosaErrorValidator.evOutputListThenGraphMismatch,
TosaErrorValidator.evOutputListElseGraphMismatch,
+ TosaErrorValidator.evCondIfCondNotMatchingBool,
+ TosaErrorValidator.evCondIfCondShapeNotSizeOne,
),
},
# while_loop
@@ -3854,6 +3885,7 @@ class TosaTestGen:
TosaErrorValidator.evInputListBodyGraphInputMismatch,
TosaErrorValidator.evInputListBodyGraphOutputMismatch,
TosaErrorValidator.evCondGraphOutputNotMatchingBool,
+ TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
),
},
}
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index d79ab3c..29ae898 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -142,6 +142,9 @@ def get_wrong_output_type(op_name, rng, input_dtype):
DType.INT32,
DType.INT48,
)
+ else:
+ # Assume all types but the input type are incorrect
+ incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
return rng.choice(a=incorrect_types)