aboutsummaryrefslogtreecommitdiff
path: root/verif/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r--verif/tosa_test_gen.py24
1 files changed, 13 insertions, 11 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 928ac0e..a03c66f 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -2934,17 +2934,19 @@ class TosaTestGen:
op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
)
- self.ser.startBasicBlock(then_block)
- self.ser.addInputTensor(a)
- self.ser.addInputTensor(b)
- then_tens = self.ser.addOutput(a.shape, a.dtype)
- self.ser.addOperator(Op.ADD, [a.name, b.name], [then_tens.name])
+ if a.dtype in (DType.FLOAT, DType.INT32):
+ then_op, else_op = Op.ADD, Op.SUB
+ elif a.dtype in (DType.INT8, DType.INT16):
+ then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
+ else:
+ assert False, f"No tests for DType: {a.dtype}"
- self.ser.startBasicBlock(else_block)
- self.ser.addInputTensor(a)
- self.ser.addInputTensor(b)
- else_tens = self.ser.addOutput(a.shape, a.dtype)
- self.ser.addOperator(Op.SUB, [a.name, b.name], [else_tens.name])
+ for block, op in ((then_block, then_op), (else_block, else_op)):
+ self.ser.startBasicBlock(block)
+ self.ser.addInputTensor(a)
+ self.ser.addInputTensor(b)
+ tens = self.ser.addOutput(a.shape, a.dtype)
+ self.ser.addOperator(op, [a.name, b.name], [tens.name])
return result_tens
@@ -4117,7 +4119,7 @@ class TosaTestGen:
TosaTensorGen.tgBasic,
TosaArgGen.agCondIf,
),
- "types": TYPE_FI32,
+ "types": TYPE_INT_FP,
},
# while_loop
"while_loop": {