diff options
Diffstat (limited to 'verif/frameworks/tosa_verif_framework_generator.py')
-rwxr-xr-x | verif/frameworks/tosa_verif_framework_generator.py | 33 |
1 files changed, 30 insertions, 3 deletions
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index ec009c6..ffe373b 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -28,6 +28,7 @@ from frameworks.test_gen_utils import ( # noqa: E402 get_tf_dtype, get_shape_str, ) # noqa: E402 + from tensorflow.lite.python.interpreter import OpResolverType # noqa: E402 # All of the supported frameworks @@ -829,6 +830,15 @@ TF_OP_LIST = { ] }, }, + "lstm_stateful": { + "operands": (1, 0), + "build_fcn": (TBuilder.SLSTM, TGen.tgRecurrent, ArgGen.agNone), + "types": { + "tflite": [ + tf.float32, + ] + }, + }, "gru": { "operands": (1, 0), "build_fcn": (TBuilder.GRU, TGen.tgRecurrent, ArgGen.agNone), @@ -848,6 +858,17 @@ TF_OP_LIST = { ] }, }, + "callonce": { + "operands": (1, 0), + "build_fcn": (TBuilder.CallOnce, TGen.tgBasic, ArgGen.agNone), + "types": { + "tflite": [tf.float32], + }, + "custom_shapes": { + "custom_shape_only": True, + "shape_list": [(1,)], + }, + }, "rfft2d": { "operands": (1, 0), "build_fcn": (TBuilder.RFFT2d, TGen.tgRFFT2d, ArgGen.agRFFT2d), @@ -1219,9 +1240,15 @@ def run_unit_test( if "tflite" not in excluded_framework_list: # Convert the model to TFLite flatbuffer module = tf.Module() - converter = tf.lite.TFLiteConverter.from_concrete_functions( - [concrete_function], module - ) + + if op_name == "callonce" or op_name == "lstm_stateful": + converter = tf.lite.TFLiteConverter.from_concrete_functions( + [concrete_function], fcn_node + ) + else: + converter = tf.lite.TFLiteConverter.from_concrete_functions( + [concrete_function], module + ) converter.experimental_new_converter = True |