aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py46
1 files changed, 39 insertions, 7 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index f3ca512..515e8bb 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -14,6 +14,7 @@ from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_error_if import TosaErrorValidator
from generator.tosa_error_if import TosaInvalidValidator
from generator.tosa_utils import DTYPE_ATTRIBUTES
+from generator.tosa_utils import get_wrong_output_type
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import usableDTypes
from generator.tosa_utils import vect_f32_to_bf16
@@ -1785,15 +1786,32 @@ class TosaTestGen:
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
+ def _get_condition_tensor(self, op, cond, error_name):
+ if error_name == ErrorIf.CondIfCondNotMatchingBool:
+ cond_type = get_wrong_output_type(op, self.rng, DType.BOOL)
+ else:
+ cond_type = DType.BOOL
+ if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
+ choice = self.rng.choice([1, 2])
+ if choice == 1:
+ cond_shape = [2]
+ else:
+ cond_shape = [1, 2]
+ else:
+ # Must be of size 1 (rank 0)
+ cond_shape = []
+ cond_tens = self.ser.addConst(cond_shape, cond_type, [cond])
+ return cond_tens
+
def build_cond_if_const(
self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
):
# For cond_if with constants, we're supplied with then/else tensors that we ignore
- # (except for the generated shap) and the condition. Build Then/Else blocks
+ # (except for the generated shape) and the condition. Build Then/Else blocks
# and fill them with const nodes for the body.
# Condition tensor
- cond_tens = self.ser.addConst([], DType.BOOL, [cond])
+ cond_tens = self._get_condition_tensor(op, cond, error_name)
# Make then/else tensors
out_shape = then_tens.shape
@@ -1848,6 +1866,7 @@ class TosaTestGen:
error_name,
op=op,
basicBlocks=self.ser.basicBlocks,
+ cond=cond_tens,
):
return None
@@ -1860,7 +1879,7 @@ class TosaTestGen:
# alternately add or subtract them based on the condition
# Condition tensor
- cond_tens = self.ser.addConst([], DType.BOOL, [cond])
+ cond_tens = self._get_condition_tensor(op, cond, error_name)
result_tens = self.ser.addOutput(a.shape, a.dtype)
@@ -1930,6 +1949,7 @@ class TosaTestGen:
a=a,
b=b,
basicBlocks=self.ser.basicBlocks,
+ cond=cond_tens,
):
return None
@@ -1997,11 +2017,18 @@ class TosaTestGen:
zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
- cond_tens = self.ser.addOutput(
- [], self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
- )
+ cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
+ else:
+ cond_type = DType.BOOL
+ if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
+ choice = self.rng.choice([1, 2])
+ if choice == 1:
+ cond_shape = [3]
+ else:
+ cond_shape = [1, 2]
else:
- cond_tens = self.ser.addOutput([], DType.BOOL)
+ cond_shape = []
+ cond_tens = self.ser.addOutput(cond_shape, cond_type)
self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
@@ -3818,6 +3845,8 @@ class TosaTestGen:
"error_if_validators": (
TosaErrorValidator.evOutputListThenGraphMismatch,
TosaErrorValidator.evOutputListElseGraphMismatch,
+ TosaErrorValidator.evCondIfCondNotMatchingBool,
+ TosaErrorValidator.evCondIfCondShapeNotSizeOne,
),
},
"cond_if_binary": {
@@ -3835,6 +3864,8 @@ class TosaTestGen:
TosaErrorValidator.evInputListElseGraphMismatch,
TosaErrorValidator.evOutputListThenGraphMismatch,
TosaErrorValidator.evOutputListElseGraphMismatch,
+ TosaErrorValidator.evCondIfCondNotMatchingBool,
+ TosaErrorValidator.evCondIfCondShapeNotSizeOne,
),
},
# while_loop
@@ -3854,6 +3885,7 @@ class TosaTestGen:
TosaErrorValidator.evInputListBodyGraphInputMismatch,
TosaErrorValidator.evInputListBodyGraphOutputMismatch,
TosaErrorValidator.evCondGraphOutputNotMatchingBool,
+ TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
),
},
}