aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/tosa_verif_framework_generator.py
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-09-07 20:49:09 +0000
committerEric Kunze <eric.kunze@arm.com>2023-09-15 18:10:01 +0000
commitcf84bc9cccbd5dc2fceae1a81c579e41be3c9a06 (patch)
treeaff6bab02c36c095a62381ac8f68d185bdccbe73 /verif/frameworks/tosa_verif_framework_generator.py
parent00f55bf46fe36bebe44e1365becbeb1e0d9e90c9 (diff)
downloadreference_model-cf84bc9cccbd5dc2fceae1a81c579e41be3c9a06.tar.gz
[reference_model] Support StatefulOps and the tests for CallOnceOp
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I03cb878736ccd7e1f5e1f780d7171949a19a9de2
Diffstat (limited to 'verif/frameworks/tosa_verif_framework_generator.py')
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py33
1 files changed, 30 insertions, 3 deletions
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index ec009c6..ffe373b 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -28,6 +28,7 @@ from frameworks.test_gen_utils import ( # noqa: E402
get_tf_dtype,
get_shape_str,
) # noqa: E402
+
from tensorflow.lite.python.interpreter import OpResolverType # noqa: E402
# All of the supported frameworks
@@ -829,6 +830,15 @@ TF_OP_LIST = {
]
},
},
+ "lstm_stateful": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.SLSTM, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ ]
+ },
+ },
"gru": {
"operands": (1, 0),
"build_fcn": (TBuilder.GRU, TGen.tgRecurrent, ArgGen.agNone),
@@ -848,6 +858,17 @@ TF_OP_LIST = {
]
},
},
+ "callonce": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.CallOnce, TGen.tgBasic, ArgGen.agNone),
+ "types": {
+ "tflite": [tf.float32],
+ },
+ "custom_shapes": {
+ "custom_shape_only": True,
+ "shape_list": [(1,)],
+ },
+ },
"rfft2d": {
"operands": (1, 0),
"build_fcn": (TBuilder.RFFT2d, TGen.tgRFFT2d, ArgGen.agRFFT2d),
@@ -1219,9 +1240,15 @@ def run_unit_test(
if "tflite" not in excluded_framework_list:
# Convert the model to TFLite flatbuffer
module = tf.Module()
- converter = tf.lite.TFLiteConverter.from_concrete_functions(
- [concrete_function], module
- )
+
+ if op_name == "callonce" or op_name == "lstm_stateful":
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(
+ [concrete_function], fcn_node
+ )
+ else:
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(
+ [concrete_function], module
+ )
converter.experimental_new_converter = True