aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/test_builder.py
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-09-07 20:49:09 +0000
committerEric Kunze <eric.kunze@arm.com>2023-09-15 18:10:01 +0000
commitcf84bc9cccbd5dc2fceae1a81c579e41be3c9a06 (patch)
treeaff6bab02c36c095a62381ac8f68d185bdccbe73 /verif/frameworks/test_builder.py
parent00f55bf46fe36bebe44e1365becbeb1e0d9e90c9 (diff)
downloadreference_model-cf84bc9cccbd5dc2fceae1a81c579e41be3c9a06.tar.gz
[reference_model] Support StatefulOps and the tests for CallOnceOp
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I03cb878736ccd7e1f5e1f780d7171949a19a9de2
Diffstat (limited to 'verif/frameworks/test_builder.py')
-rw-r--r--verif/frameworks/test_builder.py38
1 files changed, 37 insertions, 1 deletions
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index fcd72a3..3554e40 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1175,7 +1175,7 @@ class TBuilder:
return result[0]
- class LSTM:
+ class LSTM(tf.Module):
def __init__(self, name):
self.result_name = name
self.lstm = tf.keras.layers.LSTM(
@@ -1191,6 +1191,23 @@ class TBuilder:
def eval(self, a):
return self.lstm(a)
+ class SLSTM(tf.Module):
+ def __init__(self, name):
+ self.result_name = name
+ self.lstm = tf.keras.layers.LSTM(
+ 2,
+ stateful=True,
+ activation="tanh",
+ unroll=False,
+ recurrent_activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ kernel_initializer="ones",
+ )
+
+ def eval(self, a):
+ return self.lstm(a)
+
class GRU:
def __init__(self, name):
self.result_name = name
@@ -1256,3 +1273,22 @@ class TBuilder:
def eval(self, a):
return tf.broadcast_to(a, shape=self.shape, name=self.result_name)
+
+ class CallOnce(tf.Module):
+ def __init__(self, name):
+ print(tf.__version__)
+ self.result_name = name
+ self.var = tf.Variable([1.0])
+
+ @tf.function(
+ input_signature=[
+ tf.TensorSpec(
+ shape=[
+ 1,
+ ],
+ dtype=tf.float32,
+ )
+ ]
+ )
+ def eval(self, a):
+ return self.var.assign([2.0])