diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 515e8bb..d799eb0 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import os from copy import deepcopy @@ -1845,7 +1845,7 @@ class TosaTestGen: # Finally, build the op and the two blocks self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr) - self.ser.startBasicBlock(then_block) + self.ser.addBasicBlock(then_block) # Build the actual then/else tensors inside their blocks if error_name == ErrorIf.CondIfOutputListThenGraphMismatch: then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr) @@ -1853,7 +1853,7 @@ class TosaTestGen: then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr) self.ser.addOutputTensor(then_tens) - self.ser.startBasicBlock(else_block) + self.ser.addBasicBlock(else_block) if error_name == ErrorIf.CondIfOutputListElseGraphMismatch: else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr) else: @@ -1865,7 +1865,7 @@ class TosaTestGen: validator_fcns, error_name, op=op, - basicBlocks=self.ser.basicBlocks, + basicBlocks=self.ser.currRegion.basicBlocks, cond=cond_tens, ): return None @@ -1914,7 +1914,7 @@ class TosaTestGen: assert False, f"No tests for DType: {a.dtype}" for block, op in ((then_block, then_op), (else_block, else_op)): - self.ser.startBasicBlock(block) + self.ser.addBasicBlock(block) if ( error_name == ErrorIf.CondIfInputListThenGraphMismatch and block == then_block @@ -1948,7 +1948,7 @@ class TosaTestGen: op=op, a=a, b=b, - basicBlocks=self.ser.basicBlocks, + basicBlocks=self.ser.currRegion.basicBlocks, cond=cond_tens, ): return None @@ -2005,7 +2005,8 @@ class TosaTestGen: incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3]) # COND block (input: iter, output: cond_tens ) - self.ser.startBasicBlock(cond_block) + self.ser.addBasicBlock(cond_block) + if error_name == ErrorIf.InputListCondGraphMismatch: self.ser.addInputTensor(incorrect_iter) self.ser.addInputTensor(a) @@ -2034,7 +2035,8 @@ class TosaTestGen: # BODY block (input: a, acc, iter, output: a, acc, iter) # Note that local intermediate tensors need to be declared here for the outputs - self.ser.startBasicBlock(body_block) + self.ser.addBasicBlock(body_block) + if error_name == ErrorIf.InputListBodyGraphInputMismatch: self.ser.addInputTensor(incorrect_iter) self.ser.addInputTensor(a) @@ -2068,7 +2070,7 @@ class TosaTestGen: validator_fcns, error_name, op=op, - basicBlocks=self.ser.basicBlocks, + basicBlocks=self.ser.currRegion.basicBlocks, ): return None |