From 9e94af8f10f0a21a117b3bc7ea42004844fdc3bb Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Thu, 27 Oct 2022 09:57:00 -0700 Subject: Reference model update for control flow operators support Rationale for making this change: - In the original design, for control flow operators like WhileOp, child blocks couldn't read the tensor variables (global consts) in the root level block, this patch added the machanism for child blocks to access their parent level block's tensors. - This change also relies on another serialization change on adding another layer of abtraction called Region: - Serialization patch: [region] Add TosaSerializationRegion to serialization_lib - Updated the corresponding python version of the serialization code: TosaSerializerRegion to python version of serialization_lib - This change also relies on the TOSA MLIR Translator change: Add RegionBuilder to TOSA MLIR Translator - Added the WhileOp related test cases: While, LSTM, GRU, RNN - Other related fixes Signed-off-by: Jerry Ge Change-Id: I13ae33628ad07e41d248e88652ce1328654694ab --- verif/generator/tosa_test_gen.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) (limited to 'verif/generator/tosa_test_gen.py') diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 515e8bb..d799eb0 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import os from copy import deepcopy @@ -1845,7 +1845,7 @@ class TosaTestGen: # Finally, build the op and the two blocks self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr) - self.ser.startBasicBlock(then_block) + self.ser.addBasicBlock(then_block) # Build the actual then/else tensors inside their blocks if error_name == ErrorIf.CondIfOutputListThenGraphMismatch: then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr) @@ -1853,7 +1853,7 @@ class TosaTestGen: then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr) self.ser.addOutputTensor(then_tens) - self.ser.startBasicBlock(else_block) + self.ser.addBasicBlock(else_block) if error_name == ErrorIf.CondIfOutputListElseGraphMismatch: else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr) else: @@ -1865,7 +1865,7 @@ class TosaTestGen: validator_fcns, error_name, op=op, - basicBlocks=self.ser.basicBlocks, + basicBlocks=self.ser.currRegion.basicBlocks, cond=cond_tens, ): return None @@ -1914,7 +1914,7 @@ class TosaTestGen: assert False, f"No tests for DType: {a.dtype}" for block, op in ((then_block, then_op), (else_block, else_op)): - self.ser.startBasicBlock(block) + self.ser.addBasicBlock(block) if ( error_name == ErrorIf.CondIfInputListThenGraphMismatch and block == then_block @@ -1948,7 +1948,7 @@ class TosaTestGen: op=op, a=a, b=b, - basicBlocks=self.ser.basicBlocks, + basicBlocks=self.ser.currRegion.basicBlocks, cond=cond_tens, ): return None @@ -2005,7 +2005,8 @@ class TosaTestGen: incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3]) # COND block (input: iter, output: cond_tens ) - self.ser.startBasicBlock(cond_block) + self.ser.addBasicBlock(cond_block) + if error_name == ErrorIf.InputListCondGraphMismatch: self.ser.addInputTensor(incorrect_iter) self.ser.addInputTensor(a) @@ -2034,7 +2035,8 @@ class TosaTestGen: # BODY block (input: a, acc, iter, output: a, acc, iter) # Note that local intermediate tensors need to be declared here for the outputs - self.ser.startBasicBlock(body_block) + self.ser.addBasicBlock(body_block) + if error_name == ErrorIf.InputListBodyGraphInputMismatch: self.ser.addInputTensor(incorrect_iter) self.ser.addInputTensor(a) @@ -2068,7 +2070,7 @@ class TosaTestGen: validator_fcns, error_name, op=op, - basicBlocks=self.ser.basicBlocks, + basicBlocks=self.ser.currRegion.basicBlocks, ): return None -- cgit v1.2.1