aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/tosa_verif_framework_generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/tosa_verif_framework_generator.py')
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py33
1 files changed, 30 insertions, 3 deletions
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index ec009c6..ffe373b 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -28,6 +28,7 @@ from frameworks.test_gen_utils import ( # noqa: E402
get_tf_dtype,
get_shape_str,
) # noqa: E402
+
from tensorflow.lite.python.interpreter import OpResolverType # noqa: E402
# All of the supported frameworks
@@ -829,6 +830,15 @@ TF_OP_LIST = {
]
},
},
+ "lstm_stateful": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.SLSTM, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ ]
+ },
+ },
"gru": {
"operands": (1, 0),
"build_fcn": (TBuilder.GRU, TGen.tgRecurrent, ArgGen.agNone),
@@ -848,6 +858,17 @@ TF_OP_LIST = {
]
},
},
+ "callonce": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.CallOnce, TGen.tgBasic, ArgGen.agNone),
+ "types": {
+ "tflite": [tf.float32],
+ },
+ "custom_shapes": {
+ "custom_shape_only": True,
+ "shape_list": [(1,)],
+ },
+ },
"rfft2d": {
"operands": (1, 0),
"build_fcn": (TBuilder.RFFT2d, TGen.tgRFFT2d, ArgGen.agRFFT2d),
@@ -1219,9 +1240,15 @@ def run_unit_test(
if "tflite" not in excluded_framework_list:
# Convert the model to TFLite flatbuffer
module = tf.Module()
- converter = tf.lite.TFLiteConverter.from_concrete_functions(
- [concrete_function], module
- )
+
+ if op_name == "callonce" or op_name == "lstm_stateful":
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(
+ [concrete_function], fcn_node
+ )
+ else:
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(
+ [concrete_function], module
+ )
converter.experimental_new_converter = True