diff options
author | Tai Ly <tai.ly@arm.com> | 2023-09-07 20:49:09 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-09-15 18:10:01 +0000 |
commit | cf84bc9cccbd5dc2fceae1a81c579e41be3c9a06 (patch) | |
tree | aff6bab02c36c095a62381ac8f68d185bdccbe73 /verif/frameworks/tosa_verif_framework_generator.py | |
parent | 00f55bf46fe36bebe44e1365becbeb1e0d9e90c9 (diff) | |
download | reference_model-cf84bc9cccbd5dc2fceae1a81c579e41be3c9a06.tar.gz |
[reference_model] Support StatefulOps and the tests for CallOnceOp
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: I03cb878736ccd7e1f5e1f780d7171949a19a9de2
Diffstat (limited to 'verif/frameworks/tosa_verif_framework_generator.py')
-rwxr-xr-x | verif/frameworks/tosa_verif_framework_generator.py | 33 |
1 files changed, 30 insertions, 3 deletions
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index ec009c6..ffe373b 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -28,6 +28,7 @@ from frameworks.test_gen_utils import ( # noqa: E402 get_tf_dtype, get_shape_str, ) # noqa: E402 + from tensorflow.lite.python.interpreter import OpResolverType # noqa: E402 # All of the supported frameworks @@ -829,6 +830,15 @@ TF_OP_LIST = { ] }, }, + "lstm_stateful": { + "operands": (1, 0), + "build_fcn": (TBuilder.SLSTM, TGen.tgRecurrent, ArgGen.agNone), + "types": { + "tflite": [ + tf.float32, + ] + }, + }, "gru": { "operands": (1, 0), "build_fcn": (TBuilder.GRU, TGen.tgRecurrent, ArgGen.agNone), @@ -848,6 +858,17 @@ TF_OP_LIST = { ] }, }, + "callonce": { + "operands": (1, 0), + "build_fcn": (TBuilder.CallOnce, TGen.tgBasic, ArgGen.agNone), + "types": { + "tflite": [tf.float32], + }, + "custom_shapes": { + "custom_shape_only": True, + "shape_list": [(1,)], + }, + }, "rfft2d": { "operands": (1, 0), "build_fcn": (TBuilder.RFFT2d, TGen.tgRFFT2d, ArgGen.agRFFT2d), @@ -1219,9 +1240,15 @@ def run_unit_test( if "tflite" not in excluded_framework_list: # Convert the model to TFLite flatbuffer module = tf.Module() - converter = tf.lite.TFLiteConverter.from_concrete_functions( - [concrete_function], module - ) + + if op_name == "callonce" or op_name == "lstm_stateful": + converter = tf.lite.TFLiteConverter.from_concrete_functions( + [concrete_function], fcn_node + ) + else: + converter = tf.lite.TFLiteConverter.from_concrete_functions( + [concrete_function], module + ) converter.experimental_new_converter = True |