diff options
author | Jerry Ge <jerry.ge@arm.com> | 2022-10-27 09:57:00 -0700 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-01-13 19:09:21 +0000 |
commit | 9e94af8f10f0a21a117b3bc7ea42004844fdc3bb (patch) | |
tree | 868ab73bb67d4827963a4b43f28d8a8a49f50307 /verif/frameworks/test_builder.py | |
parent | dd8d9c251db0fece6453d86116052ad7f3e2d697 (diff) | |
download | reference_model-9e94af8f10f0a21a117b3bc7ea42004844fdc3bb.tar.gz |
Reference model update for control flow operators support
Rationale for making this change:
- In the original design, for control flow operators like WhileOp,
child blocks couldn't read the tensor variables (global consts) in the root level block,
this patch added the machanism for child blocks to access their parent
level block's tensors.
- This change also relies on another serialization change on adding
another layer of abtraction called Region:
- Serialization patch: [region] Add TosaSerializationRegion to serialization_lib
- Updated the corresponding python version of the serialization code: TosaSerializerRegion to python version of serialization_lib
- This change also relies on the TOSA MLIR Translator change: Add RegionBuilder to TOSA MLIR Translator
- Added the WhileOp related test cases: While, LSTM, GRU, RNN
- Other related fixes
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: I13ae33628ad07e41d248e88652ce1328654694ab
Diffstat (limited to 'verif/frameworks/test_builder.py')
-rw-r--r-- | verif/frameworks/test_builder.py | 81 |
1 files changed, 80 insertions, 1 deletions
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index c7c5cd7..8870f41 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import numpy as np import tensorflow as tf @@ -1164,3 +1164,82 @@ class TBuilder: def eval(self, a): return tf.bitwise.right_shift(a, self.shift, name=self.result_name) + + class While: + def __init__(self, name): + self.result_name = name + + def while_cond(self, x): + return tf.reduce_sum(x) < self.cap + + def while_body(self, x): + return tf.add(x, tf.math.sigmoid(x)) + + def eval(self, a): + self.cap = tf.cast( + tf.constant( + 2.0, + shape=[ + 1, + ], + ), + a.dtype, + ) + + result = tf.while_loop( + self.while_cond, self.while_body, [a], name=self.result_name + ) + + return result[0] + + class LSTM: + def __init__(self, name): + self.result_name = name + self.lstm = tf.keras.layers.LSTM( + 2, + activation="tanh", + unroll=False, + recurrent_activation="sigmoid", + use_bias=True, + recurrent_initializer="ones", + kernel_initializer="ones", + ) + + def eval(self, a): + return self.lstm(a) + + class GRU: + def __init__(self, name): + self.result_name = name + self.lstm = tf.keras.layers.GRU( + 2, + recurrent_activation="sigmoid", + use_bias=True, + recurrent_initializer="ones", + kernel_initializer="ones", + ) + + def eval(self, a): + return self.lstm(a) + + class RNN: + def __init__(self, name): + self.result_name = name + basic_cell = tf.keras.layers.SimpleRNNCell( + units=2, + activation="sigmoid", + use_bias=True, + recurrent_initializer="ones", + ) + self.rnn = tf.keras.layers.RNN(basic_cell, unroll=False) + + def eval(self, a): + return self.rnn(a) + + class FullyConnected: + def __init__(self, name): + self.result_name = name + self.dense = tf.keras.layers.Dense(2) + + def eval(self, a): + return self.dense(a) |