aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Haddon <matthew.haddon@arm.com>2021-10-14 15:05:41 +0100
committerEric Kunze <eric.kunze@arm.com>2021-11-09 15:11:29 +0000
commit630c17c5b46aed13edebc60321fcee5659c688bb (patch)
treea6e52eaa5b2794454fbb53334e50fb929a254dcc
parentbb5676f55df0d14be7e07981c39645971a587ed2 (diff)
downloadreference_model-630c17c5b46aed13edebc60321fcee5659c688bb.tar.gz
Add negative testing to cond_if, while_loop
Signed-off-by: Matthew Haddon <matthew.haddon@arm.com> Signed-off-by: Les Bell <les.bell@arm.com> Change-Id: Ie6c8c8653874f9eed6007a54a3ad526601a4a669
-rw-r--r--verif/tosa_error_if.py9
-rw-r--r--verif/tosa_test_gen.py422
2 files changed, 395 insertions, 36 deletions
diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py
index 9fcc374..f0e752f 100644
--- a/verif/tosa_error_if.py
+++ b/verif/tosa_error_if.py
@@ -56,5 +56,14 @@ class ErrorIf(object):
MaxSmallerMin = "MaxSmallerMin"
ConcatInputRankMismatch = "ConcatInputRankMismatch"
ConcatInputDimMismatch = "ConcatInputDimMismatch"
+ CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
+ CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
+ CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
+ CondIfOutputListElseGraphMismatch = "CondIfOutputListElseGraphMismatch"
+ InputListOutputListMismatch = "InputListOutputListMismatch"
+ InputListCondGraphMismatch = "InputListCondGraphMismatch"
+ InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
+ InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
+ CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index cd59898..4e944ea 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -29,6 +29,7 @@ import threading
import traceback
import math
import itertools
+from copy import deepcopy
from enum import IntEnum, Enum, unique
from tosa_ref_run import TosaReturnCode
@@ -48,6 +49,13 @@ DType = tosa.DType.DType()
Op = tosa.Op.Op()
ResizeMode = tosa.ResizeMode.ResizeMode()
+
+def product(shape):
+ value = 1
+ for n in shape:
+ value *= n
+ return value
+
class TosaQuantGen:
"""QuantizedInfo random generator helper functions. Specify with 'qgen': in the operator defintion"""
@@ -185,8 +193,9 @@ class TosaTensorGen:
pl, const = opName["operands"]
shape = testGen.makeShape(rank)
- # Constrict dimension size for large ranks when creating WrongRank tests
- shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
shape_list = []
for i in range(pl + const):
@@ -213,8 +222,9 @@ class TosaTensorGen:
if testGen.args.max_batch_size:
shape[0] = (shape[0] % testGen.args.max_batch_size) + 1
- # Constrict dimension size for large ranks when creating WrongRank tests
- shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
shape_list = []
for i in range(pl + const):
@@ -404,8 +414,9 @@ class TosaTensorGen:
input_shape = testGen.makeShape(rank)
- # Constrict dimension size for large ranks when creating WrongRank tests
- shape = TosaErrorIfArgGen.eiRestrictDimension(input_shape, error_name)
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
filter_oc = testGen.rng.integers(
low=testGen.args.tensor_shape_range[0],
@@ -428,8 +439,9 @@ class TosaTensorGen:
a_shape = testGen.makeShape(rank)
- # Constrict dimension size for large ranks when creating WrongRank tests
- shape = TosaErrorIfArgGen.eiRestrictDimension(a_shape, error_name)
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
# Get a random number for b_oc even if target shape is defined
b_oc = np.int32(
@@ -1405,17 +1417,13 @@ class TosaErrorIfArgGen:
output_list = []
return input_list, output_list
-
@staticmethod
- def eiRestrictDimension(shape, error_name):
- # Restrict dimension size if rank is large for WrongRank Error_If
- # This will keep the test sizes reasonably small
- if error_name == ErrorIf.WrongRank:
- if len(shape) > 4:
- shape[4] = 1
-
- return shape
-
+ def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
+ """Restrict the dimensions and overall size of a shape to max_dim and max_items."""
+ new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
+ while product(new_shape) > max_items:
+ new_shape = [max(d - 1, 1) for d in new_shape]
+ return new_shape
def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
if error_name == ErrorIf.StartSmallerZero:
@@ -2705,6 +2713,243 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def evInputListThenGraphMismatch(check=False, **kwargs):
+ error_name = ErrorIf.CondIfInputListThenGraphMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input list shape does not match then-graph shape"
+
+ if check:
+ a = kwargs['a']
+ b = kwargs['b']
+ basicBlocks = kwargs['basicBlocks']
+ then_block = basicBlocks[1]
+ then_inputs = then_block.inputs
+ then_tens = then_block.tensors
+ if (a.shape != then_tens[then_inputs[0]].shape) or (b.shape != then_tens[then_inputs[1]].shape):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+
+ @staticmethod
+ def evInputListElseGraphMismatch(check=False, **kwargs):
+ error_name = ErrorIf.CondIfInputListElseGraphMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input list shape does not match else-graph shape"
+
+ if check:
+ a = kwargs['a']
+ b = kwargs['b']
+ basicBlocks = kwargs['basicBlocks']
+ else_block = basicBlocks[2]
+ else_inputs = else_block.inputs
+ else_tens = else_block.tensors
+ if (a.shape != else_tens[else_inputs[0]].shape) or (b.shape != else_tens[else_inputs[1]].shape):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+
+ @staticmethod
+ def evOutputListThenGraphMismatch(check=False, **kwargs):
+ error_name = ErrorIf.CondIfOutputListThenGraphMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Output list shape does not match then-graph shape"
+
+ if check:
+ basicBlocks = kwargs['basicBlocks']
+ cond_block = basicBlocks[0]
+ cond_outputs = cond_block.outputs
+ cond_tens = cond_block.tensors
+ then_block = basicBlocks[1]
+ then_outputs = then_block.outputs
+ then_tens = then_block.tensors
+ if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+
+ @staticmethod
+ def evOutputListElseGraphMismatch(check=False, **kwargs):
+ error_name = ErrorIf.CondIfOutputListElseGraphMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Output list shape does not match else-graph shape"
+
+ if check:
+ basicBlocks = kwargs['basicBlocks']
+ cond_block = basicBlocks[0]
+ cond_outputs = cond_block.outputs
+ cond_tens = cond_block.tensors
+ else_block = basicBlocks[2]
+ else_outputs = else_block.outputs
+ else_tens = else_block.tensors
+ if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+
+ @staticmethod
+ def evInputListOutputListMismatch(check=False, **kwargs):
+ error_name = ErrorIf.InputListOutputListMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input list does not match output list"
+
+ if check:
+ basicBlocks = kwargs['basicBlocks']
+ while_block = basicBlocks[0]
+ while_inputs = while_block.inputs
+ while_outputs = while_block.outputs
+ while_tens = while_block.tensors
+ if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+
+ @staticmethod
+ def evInputListCondGraphMismatch(check=False, **kwargs):
+ error_name = ErrorIf.InputListCondGraphMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input list does not match cond graph"
+
+ if check:
+ basicBlocks = kwargs['basicBlocks']
+ while_block = basicBlocks[0]
+ while_inputs = while_block.inputs
+ while_tens = while_block.tensors
+ cond_block = basicBlocks[1]
+ cond_inputs = cond_block.inputs
+ cond_tens = cond_block.tensors
+ if ((while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape) or
+ (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape)):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+
+ @staticmethod
+ def evInputListBodyGraphInputMismatch(check=False, **kwargs):
+ error_name = ErrorIf.InputListBodyGraphInputMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input list does not match body graph input"
+
+ if check:
+ basicBlocks = kwargs['basicBlocks']
+ while_block = basicBlocks[0]
+ while_inputs = while_block.inputs
+ while_tens = while_block.tensors
+ body_block = basicBlocks[2]
+ body_outputs = body_block.inputs
+ body_tens = body_block.tensors
+ if ((while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape) or
+ (while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape)):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+
+ @staticmethod
+ def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
+ error_name = ErrorIf.InputListBodyGraphOutputMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input list does not match body graph output"
+
+ if check:
+ basicBlocks = kwargs['basicBlocks']
+ while_block = basicBlocks[0]
+ while_inputs = while_block.inputs
+ while_tens = while_block.tensors
+ body_block = basicBlocks[2]
+ body_outputs = body_block.outputs
+ body_tens = body_block.tensors
+ if ((while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape) or
+ (while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape)):
+ error_result = True
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+
+ @staticmethod
+ def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
+ error_name = ErrorIf.CondGraphOutputNotMatchingBool
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Cond graph output is not a match list of booleans"
+
+ if check:
+ basicBlocks = kwargs['basicBlocks']
+ cond_block = basicBlocks[1]
+ cond_outputs = cond_block.outputs
+ cond_tens = cond_block.tensors
+ if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
class TosaInvalidValidator:
@@ -4131,7 +4376,7 @@ class TosaTestGen:
self.ser.addOperator(op['op'], input_list, output_list, attr)
return result_tens
- def build_cond_if_const(self, op, then_tens, else_tens, cond):
+ def build_cond_if_const(self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None):
# For cond_if with constants, we're supplied with then/else tensors that we ignore
# (except for the generated shap) and the condition. Build Then/Else blocks
# and fill them with const nodes for the body.
@@ -4141,6 +4386,14 @@ class TosaTestGen:
# Make then/else tensors
out_shape = then_tens.shape
+
+ # Create an incorrect output shape for error_if tests
+ if error_name in [ErrorIf.CondIfOutputListThenGraphMismatch, ErrorIf.CondIfOutputListElseGraphMismatch]:
+ incorrect_shape = deepcopy(then_tens.shape)
+ for i in range(len(incorrect_shape)):
+ incorrect_shape[i] = incorrect_shape[i] + self.rng.choice([-3, -2, 2, 3])
+ incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))
+
then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
@@ -4158,16 +4411,30 @@ class TosaTestGen:
self.ser.startBasicBlock(then_block)
# Build the actual then/else tensors inside their blocks
- then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
+ if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
+ then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
+ else:
+ then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
self.ser.addOutputTensor(then_tens)
self.ser.startBasicBlock(else_block)
- else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
+ if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
+ else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
+ else:
+ else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
self.ser.addOutputTensor(else_tens)
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ basicBlocks=self.ser.basicBlocks
+ )
+
return result_tens
- def build_cond_if_binary(self, op, a, b, cond):
+ def build_cond_if_binary(self, op, a, b, cond, validator_fcns=None, error_name=None):
# For cond_if with a binary op in the then/else blocks, take a and b and
# alternately add or subtract them based on the condition
@@ -4182,6 +4449,15 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.CondIfAttribute(then_block, else_block)
+ if error_name in [ErrorIf.CondIfInputListThenGraphMismatch, ErrorIf.CondIfInputListElseGraphMismatch,
+ ErrorIf.CondIfOutputListElseGraphMismatch, ErrorIf.CondIfOutputListThenGraphMismatch]:
+ incorrect_shape = a.shape.copy()
+ for i in range(len(incorrect_shape)):
+ incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
+ incorrect_block_input = deepcopy(a)
+ incorrect_block_input.shape = incorrect_shape
+
+
# Finally, build the op and the two blocks
self.ser.addOperator(
op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
@@ -4196,14 +4472,35 @@ class TosaTestGen:
for block, op in ((then_block, then_op), (else_block, else_op)):
self.ser.startBasicBlock(block)
- self.ser.addInputTensor(a)
- self.ser.addInputTensor(b)
- tens = self.ser.addOutput(a.shape, a.dtype)
+ if ((error_name == ErrorIf.CondIfInputListThenGraphMismatch and block == then_block) or
+ (error_name == ErrorIf.CondIfInputListElseGraphMismatch and block == else_block)):
+ self.ser.addInputTensor(incorrect_block_input)
+ self.ser.addInputTensor(b)
+ tens = self.ser.addOutput(a.shape, a.dtype)
+ elif ((error_name == ErrorIf.CondIfOutputListThenGraphMismatch and block == then_block) or
+ (error_name == ErrorIf.CondIfOutputListElseGraphMismatch and block == else_block)):
+ self.ser.addInputTensor(a)
+ self.ser.addInputTensor(b)
+ tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
+ else:
+ self.ser.addInputTensor(a)
+ self.ser.addInputTensor(b)
+ tens = self.ser.addOutput(a.shape, a.dtype)
self.ser.addOperator(op, [a.name, b.name], [tens.name])
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ a=a,
+ b=b,
+ basicBlocks=self.ser.basicBlocks
+ )
+
return result_tens
- def build_while_loop(self, op, a, iter_val):
+ def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
cond_block = "COND_BLOCK"
@@ -4220,7 +4517,13 @@ class TosaTestGen:
# Intermediate/output tensors for everything going through the loop
iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
a_out = self.ser.addIntermediate(a.shape, a.dtype)
- acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
+ if error_name == ErrorIf.InputListOutputListMismatch:
+ incorrect_acc = deepcopy(acc)
+ for i in range(len(incorrect_acc.shape)):
+ incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
+ acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
+ else:
+ acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
# While_loop operator
self.ser.addOperator(
@@ -4231,30 +4534,71 @@ class TosaTestGen:
)
self.ser.addOutputTensor(acc_out)
+ if error_name in [ErrorIf.InputListCondGraphMismatch, ErrorIf.InputListBodyGraphInputMismatch, ErrorIf.InputListBodyGraphOutputMismatch]:
+ incorrect_iter = deepcopy(iter)
+ for i in range(len(incorrect_iter.shape)):
+ incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
+ if len(incorrect_iter.shape) == 0:
+ incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
+
+ incorrect_acc = deepcopy(acc)
+ for i in range(len(incorrect_acc.shape)):
+ incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
+
# COND block (input: iter, output: cond_tens )
self.ser.startBasicBlock(cond_block)
- self.ser.addInputTensor(iter)
- self.ser.addInputTensor(a)
- self.ser.addInputTensor(acc)
+ if error_name == ErrorIf.InputListCondGraphMismatch:
+ self.ser.addInputTensor(incorrect_iter)
+ self.ser.addInputTensor(a)
+ self.ser.addInputTensor(incorrect_acc)
+ else:
+ self.ser.addInputTensor(iter)
+ self.ser.addInputTensor(a)
+ self.ser.addInputTensor(acc)
zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
- cond_tens = self.ser.addOutput([], DType.BOOL)
+
+ if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
+ cond_tens = self.ser.addOutput([], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT]))
+ else:
+ cond_tens = self.ser.addOutput([], DType.BOOL)
+
self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
# BODY block (input: a, acc, iter, output: a, acc, iter)
# Note that local intermediate tensors need to be declared here for the outputs
self.ser.startBasicBlock(body_block)
- self.ser.addInputTensor(iter)
- self.ser.addInputTensor(a)
- self.ser.addInputTensor(acc)
+ if error_name == ErrorIf.InputListBodyGraphInputMismatch:
+ self.ser.addInputTensor(incorrect_iter)
+ self.ser.addInputTensor(a)
+ self.ser.addInputTensor(incorrect_acc)
+ else:
+ self.ser.addInputTensor(iter)
+ self.ser.addInputTensor(a)
+ self.ser.addInputTensor(acc)
+
one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
- iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
- acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
+
+ if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
+ iter_body_out = self.ser.addIntermediate(incorrect_iter.shape, incorrect_iter.dtype)
+ acc_body_out = self.ser.addIntermediate(incorrect_acc.shape, incorrect_acc.dtype)
+ else:
+ iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
+ acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
+
self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
self.ser.addOutputTensor(iter_body_out)
self.ser.addOutputTensor(a)
self.ser.addOutputTensor(acc_body_out)
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ basicBlocks=self.ser.basicBlocks
+ )
+
return acc_out
def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
@@ -5445,6 +5789,7 @@ class TosaTestGen:
TosaArgGen.agCondIf,
),
"types": [DType.BOOL],
+ "error_if_validators": (TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch)
},
"cond_if_binary": {
"op": Op.COND_IF,
@@ -5455,6 +5800,8 @@ class TosaTestGen:
TosaArgGen.agCondIf,
),
"types": TYPE_INT_FP,
+ "error_if_validators": (TosaErrorValidator.evInputListThenGraphMismatch, TosaErrorValidator.evInputListElseGraphMismatch,
+ TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch)
},
# while_loop
"while_loop": {
@@ -5466,6 +5813,9 @@ class TosaTestGen:
TosaArgGen.agWhileLoop,
),
"types": [DType.INT32],
+ "error_if_validators": (TosaErrorValidator.evInputListOutputListMismatch, TosaErrorValidator.evInputListCondGraphMismatch,
+ TosaErrorValidator.evInputListBodyGraphInputMismatch, TosaErrorValidator.evInputListBodyGraphOutputMismatch,
+ TosaErrorValidator.evCondGraphOutputNotMatchingBool)
},
}