diff options
author | Tai Ly <tai.ly@arm.com> | 2023-09-07 20:49:09 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-09-15 18:10:01 +0000 |
commit | cf84bc9cccbd5dc2fceae1a81c579e41be3c9a06 (patch) | |
tree | aff6bab02c36c095a62381ac8f68d185bdccbe73 /verif/frameworks/test_builder.py | |
parent | 00f55bf46fe36bebe44e1365becbeb1e0d9e90c9 (diff) | |
download | reference_model-cf84bc9cccbd5dc2fceae1a81c579e41be3c9a06.tar.gz |
[reference_model] Support StatefulOps and the tests for CallOnceOp
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: I03cb878736ccd7e1f5e1f780d7171949a19a9de2
Diffstat (limited to 'verif/frameworks/test_builder.py')
-rw-r--r-- | verif/frameworks/test_builder.py | 38 |
1 files changed, 37 insertions, 1 deletions
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index fcd72a3..3554e40 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -1175,7 +1175,7 @@ class TBuilder: return result[0] - class LSTM: + class LSTM(tf.Module): def __init__(self, name): self.result_name = name self.lstm = tf.keras.layers.LSTM( @@ -1191,6 +1191,23 @@ class TBuilder: def eval(self, a): return self.lstm(a) + class SLSTM(tf.Module): + def __init__(self, name): + self.result_name = name + self.lstm = tf.keras.layers.LSTM( + 2, + stateful=True, + activation="tanh", + unroll=False, + recurrent_activation="sigmoid", + use_bias=True, + recurrent_initializer="ones", + kernel_initializer="ones", + ) + + def eval(self, a): + return self.lstm(a) + class GRU: def __init__(self, name): self.result_name = name @@ -1256,3 +1273,22 @@ class TBuilder: def eval(self, a): return tf.broadcast_to(a, shape=self.shape, name=self.result_name) + + class CallOnce(tf.Module): + def __init__(self, name): + print(tf.__version__) + self.result_name = name + self.var = tf.Variable([1.0]) + + @tf.function( + input_signature=[ + tf.TensorSpec( + shape=[ + 1, + ], + dtype=tf.float32, + ) + ] + ) + def eval(self, a): + return self.var.assign([2.0]) |