From 05c711e8941b05f6c9502fe8ef482a077aca508b Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Mon, 12 Dec 2022 18:00:41 +0000 Subject: Add extra control flow ERROR_IF tests Signed-off-by: Jeremy Johnson Change-Id: I7276dc686d8d18ba44663b73e35ceca2a1cbaadf --- verif/generator/tosa_error_if.py | 68 ++++++++++++++++++++++++++++++++++++++++ verif/generator/tosa_test_gen.py | 46 ++++++++++++++++++++++----- verif/generator/tosa_utils.py | 3 ++ 3 files changed, 110 insertions(+), 7 deletions(-) diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index a850699..c9d35c7 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -73,6 +73,9 @@ class ErrorIf(object): CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool" U16InputZeroPointNotValid = "U16InputZeroPointNotValid" U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid" + CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool" + CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne" + CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne" class TosaErrorIfArgGen: @@ -2190,6 +2193,47 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evCondIfCondNotMatchingBool(check=False, **kwargs): + error_name = ErrorIf.CondIfCondNotMatchingBool + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Conditional tensor does not match bool type" + + if check: + cond = kwargs["cond"] + if cond.dtype != DType.BOOL: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + + @staticmethod + def evCondIfCondShapeNotSizeOne(check=False, **kwargs): + error_name = ErrorIf.CondIfCondShapeNotSizeOne + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Conditional tensor is not equal to a size of one" + + if check: + cond = kwargs["cond"] + # Size of 1 is equivalent to rank 0 + if len(cond.shape) != 0: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + @staticmethod def evInputListOutputListMismatch(check=False, **kwargs): error_name = ErrorIf.InputListOutputListMismatch @@ -2324,6 +2368,30 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs): + error_name = ErrorIf.CondGraphOutputShapeNotSizeOne + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Cond graph output is not a shape of size one" + + if check: + basicBlocks = kwargs["basicBlocks"] + cond_block = basicBlocks[1] + cond_outputs = cond_block.outputs + cond_tens = cond_block.tensors + # Size of 1 is equivalent to rank 0 + if len(cond_tens[cond_outputs[0]].shape) != 0: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + class TosaInvalidValidator: @staticmethod diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index f3ca512..515e8bb 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -14,6 +14,7 @@ from generator.tosa_error_if import TosaErrorIfArgGen from generator.tosa_error_if import TosaErrorValidator from generator.tosa_error_if import TosaInvalidValidator from generator.tosa_utils import DTYPE_ATTRIBUTES +from generator.tosa_utils import get_wrong_output_type from generator.tosa_utils import MAX_RESIZE_DIMENSION from generator.tosa_utils import usableDTypes from generator.tosa_utils import vect_f32_to_bf16 @@ -1785,15 +1786,32 @@ class TosaTestGen: self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens + def _get_condition_tensor(self, op, cond, error_name): + if error_name == ErrorIf.CondIfCondNotMatchingBool: + cond_type = get_wrong_output_type(op, self.rng, DType.BOOL) + else: + cond_type = DType.BOOL + if error_name == ErrorIf.CondIfCondShapeNotSizeOne: + choice = self.rng.choice([1, 2]) + if choice == 1: + cond_shape = [2] + else: + cond_shape = [1, 2] + else: + # Must be of size 1 (rank 0) + cond_shape = [] + cond_tens = self.ser.addConst(cond_shape, cond_type, [cond]) + return cond_tens + def build_cond_if_const( self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None ): # For cond_if with constants, we're supplied with then/else tensors that we ignore - # (except for the generated shap) and the condition. Build Then/Else blocks + # (except for the generated shape) and the condition. Build Then/Else blocks # and fill them with const nodes for the body. # Condition tensor - cond_tens = self.ser.addConst([], DType.BOOL, [cond]) + cond_tens = self._get_condition_tensor(op, cond, error_name) # Make then/else tensors out_shape = then_tens.shape @@ -1848,6 +1866,7 @@ class TosaTestGen: error_name, op=op, basicBlocks=self.ser.basicBlocks, + cond=cond_tens, ): return None @@ -1860,7 +1879,7 @@ class TosaTestGen: # alternately add or subtract them based on the condition # Condition tensor - cond_tens = self.ser.addConst([], DType.BOOL, [cond]) + cond_tens = self._get_condition_tensor(op, cond, error_name) result_tens = self.ser.addOutput(a.shape, a.dtype) @@ -1930,6 +1949,7 @@ class TosaTestGen: a=a, b=b, basicBlocks=self.ser.basicBlocks, + cond=cond_tens, ): return None @@ -1997,11 +2017,18 @@ class TosaTestGen: zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)]) if error_name == ErrorIf.CondGraphOutputNotMatchingBool: - cond_tens = self.ser.addOutput( - [], self.rng.choice([DType.INT8, DType.INT32, DType.FP32]) - ) + cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32]) + else: + cond_type = DType.BOOL + if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne: + choice = self.rng.choice([1, 2]) + if choice == 1: + cond_shape = [3] + else: + cond_shape = [1, 2] else: - cond_tens = self.ser.addOutput([], DType.BOOL) + cond_shape = [] + cond_tens = self.ser.addOutput(cond_shape, cond_type) self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name]) @@ -3818,6 +3845,8 @@ class TosaTestGen: "error_if_validators": ( TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch, + TosaErrorValidator.evCondIfCondNotMatchingBool, + TosaErrorValidator.evCondIfCondShapeNotSizeOne, ), }, "cond_if_binary": { @@ -3835,6 +3864,8 @@ class TosaTestGen: TosaErrorValidator.evInputListElseGraphMismatch, TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch, + TosaErrorValidator.evCondIfCondNotMatchingBool, + TosaErrorValidator.evCondIfCondShapeNotSizeOne, ), }, # while_loop @@ -3854,6 +3885,7 @@ class TosaTestGen: TosaErrorValidator.evInputListBodyGraphInputMismatch, TosaErrorValidator.evInputListBodyGraphOutputMismatch, TosaErrorValidator.evCondGraphOutputNotMatchingBool, + TosaErrorValidator.evCondGraphOutputShapeNotSizeOne, ), }, } diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index d79ab3c..29ae898 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -142,6 +142,9 @@ def get_wrong_output_type(op_name, rng, input_dtype): DType.INT32, DType.INT48, ) + else: + # Assume all types but the input type are incorrect + incorrect_types = list(usableDTypes(excludes=(input_dtype,))) return rng.choice(a=incorrect_types) -- cgit v1.2.1