From cf84bc9cccbd5dc2fceae1a81c579e41be3c9a06 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Thu, 7 Sep 2023 20:49:09 +0000 Subject: [reference_model] Support StatefulOps and the tests for CallOnceOp Signed-off-by: Jerry Ge Change-Id: I03cb878736ccd7e1f5e1f780d7171949a19a9de2 --- verif/frameworks/test_builder.py | 38 +++++++++++++++++++++- verif/frameworks/tosa_verif_framework_generator.py | 33 +++++++++++++++++-- 2 files changed, 67 insertions(+), 4 deletions(-) (limited to 'verif/frameworks') diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index fcd72a3..3554e40 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -1175,7 +1175,7 @@ class TBuilder: return result[0] - class LSTM: + class LSTM(tf.Module): def __init__(self, name): self.result_name = name self.lstm = tf.keras.layers.LSTM( @@ -1191,6 +1191,23 @@ class TBuilder: def eval(self, a): return self.lstm(a) + class SLSTM(tf.Module): + def __init__(self, name): + self.result_name = name + self.lstm = tf.keras.layers.LSTM( + 2, + stateful=True, + activation="tanh", + unroll=False, + recurrent_activation="sigmoid", + use_bias=True, + recurrent_initializer="ones", + kernel_initializer="ones", + ) + + def eval(self, a): + return self.lstm(a) + class GRU: def __init__(self, name): self.result_name = name @@ -1256,3 +1273,22 @@ class TBuilder: def eval(self, a): return tf.broadcast_to(a, shape=self.shape, name=self.result_name) + + class CallOnce(tf.Module): + def __init__(self, name): + print(tf.__version__) + self.result_name = name + self.var = tf.Variable([1.0]) + + @tf.function( + input_signature=[ + tf.TensorSpec( + shape=[ + 1, + ], + dtype=tf.float32, + ) + ] + ) + def eval(self, a): + return self.var.assign([2.0]) diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index ec009c6..ffe373b 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -28,6 +28,7 @@ from frameworks.test_gen_utils import ( # noqa: E402 get_tf_dtype, get_shape_str, ) # noqa: E402 + from tensorflow.lite.python.interpreter import OpResolverType # noqa: E402 # All of the supported frameworks @@ -829,6 +830,15 @@ TF_OP_LIST = { ] }, }, + "lstm_stateful": { + "operands": (1, 0), + "build_fcn": (TBuilder.SLSTM, TGen.tgRecurrent, ArgGen.agNone), + "types": { + "tflite": [ + tf.float32, + ] + }, + }, "gru": { "operands": (1, 0), "build_fcn": (TBuilder.GRU, TGen.tgRecurrent, ArgGen.agNone), @@ -848,6 +858,17 @@ TF_OP_LIST = { ] }, }, + "callonce": { + "operands": (1, 0), + "build_fcn": (TBuilder.CallOnce, TGen.tgBasic, ArgGen.agNone), + "types": { + "tflite": [tf.float32], + }, + "custom_shapes": { + "custom_shape_only": True, + "shape_list": [(1,)], + }, + }, "rfft2d": { "operands": (1, 0), "build_fcn": (TBuilder.RFFT2d, TGen.tgRFFT2d, ArgGen.agRFFT2d), @@ -1219,9 +1240,15 @@ def run_unit_test( if "tflite" not in excluded_framework_list: # Convert the model to TFLite flatbuffer module = tf.Module() - converter = tf.lite.TFLiteConverter.from_concrete_functions( - [concrete_function], module - ) + + if op_name == "callonce" or op_name == "lstm_stateful": + converter = tf.lite.TFLiteConverter.from_concrete_functions( + [concrete_function], fcn_node + ) + else: + converter = tf.lite.TFLiteConverter.from_concrete_functions( + [concrete_function], module + ) converter.experimental_new_converter = True -- cgit v1.2.1