aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py20
1 files changed, 11 insertions, 9 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 515e8bb..d799eb0 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import os
from copy import deepcopy
@@ -1845,7 +1845,7 @@ class TosaTestGen:
# Finally, build the op and the two blocks
self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)
- self.ser.startBasicBlock(then_block)
+ self.ser.addBasicBlock(then_block)
# Build the actual then/else tensors inside their blocks
if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
@@ -1853,7 +1853,7 @@ class TosaTestGen:
then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
self.ser.addOutputTensor(then_tens)
- self.ser.startBasicBlock(else_block)
+ self.ser.addBasicBlock(else_block)
if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
else:
@@ -1865,7 +1865,7 @@ class TosaTestGen:
validator_fcns,
error_name,
op=op,
- basicBlocks=self.ser.basicBlocks,
+ basicBlocks=self.ser.currRegion.basicBlocks,
cond=cond_tens,
):
return None
@@ -1914,7 +1914,7 @@ class TosaTestGen:
assert False, f"No tests for DType: {a.dtype}"
for block, op in ((then_block, then_op), (else_block, else_op)):
- self.ser.startBasicBlock(block)
+ self.ser.addBasicBlock(block)
if (
error_name == ErrorIf.CondIfInputListThenGraphMismatch
and block == then_block
@@ -1948,7 +1948,7 @@ class TosaTestGen:
op=op,
a=a,
b=b,
- basicBlocks=self.ser.basicBlocks,
+ basicBlocks=self.ser.currRegion.basicBlocks,
cond=cond_tens,
):
return None
@@ -2005,7 +2005,8 @@ class TosaTestGen:
incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
# COND block (input: iter, output: cond_tens )
- self.ser.startBasicBlock(cond_block)
+ self.ser.addBasicBlock(cond_block)
+
if error_name == ErrorIf.InputListCondGraphMismatch:
self.ser.addInputTensor(incorrect_iter)
self.ser.addInputTensor(a)
@@ -2034,7 +2035,8 @@ class TosaTestGen:
# BODY block (input: a, acc, iter, output: a, acc, iter)
# Note that local intermediate tensors need to be declared here for the outputs
- self.ser.startBasicBlock(body_block)
+ self.ser.addBasicBlock(body_block)
+
if error_name == ErrorIf.InputListBodyGraphInputMismatch:
self.ser.addInputTensor(incorrect_iter)
self.ser.addInputTensor(a)
@@ -2068,7 +2070,7 @@ class TosaTestGen:
validator_fcns,
error_name,
op=op,
- basicBlocks=self.ser.basicBlocks,
+ basicBlocks=self.ser.currRegion.basicBlocks,
):
return None