aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/test_builder.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/test_builder.py')
-rw-r--r--verif/frameworks/test_builder.py38
1 files changed, 37 insertions, 1 deletions
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index fcd72a3..3554e40 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1175,7 +1175,7 @@ class TBuilder:
return result[0]
- class LSTM:
+ class LSTM(tf.Module):
def __init__(self, name):
self.result_name = name
self.lstm = tf.keras.layers.LSTM(
@@ -1191,6 +1191,23 @@ class TBuilder:
def eval(self, a):
return self.lstm(a)
+ class SLSTM(tf.Module):
+ def __init__(self, name):
+ self.result_name = name
+ self.lstm = tf.keras.layers.LSTM(
+ 2,
+ stateful=True,
+ activation="tanh",
+ unroll=False,
+ recurrent_activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ kernel_initializer="ones",
+ )
+
+ def eval(self, a):
+ return self.lstm(a)
+
class GRU:
def __init__(self, name):
self.result_name = name
@@ -1256,3 +1273,22 @@ class TBuilder:
def eval(self, a):
return tf.broadcast_to(a, shape=self.shape, name=self.result_name)
+
+ class CallOnce(tf.Module):
+ def __init__(self, name):
+ print(tf.__version__)
+ self.result_name = name
+ self.var = tf.Variable([1.0])
+
+ @tf.function(
+ input_signature=[
+ tf.TensorSpec(
+ shape=[
+ 1,
+ ],
+ dtype=tf.float32,
+ )
+ ]
+ )
+ def eval(self, a):
+ return self.var.assign([2.0])