aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-09-07 20:49:09 +0000
committerEric Kunze <eric.kunze@arm.com>2023-11-28 16:38:53 -0800
commit4762564da970eb1883a54aa66582e05c0dbd2b81 (patch)
tree657e14aa711f1c9e55a5fde15fac3f7f9f77e536 /verif
parent09ae449db8a45ab7c48af4541b43cb3dc80f9a30 (diff)
downloadreference_model-4762564da970eb1883a54aa66582e05c0dbd2b81.tar.gz
[reference_model] Support StatefulOps and the tests for CallOnceOp
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I03cb878736ccd7e1f5e1f780d7171949a19a9de2
Diffstat (limited to 'verif')
-rw-r--r--verif/frameworks/test_builder.py38
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py33
2 files changed, 67 insertions, 4 deletions
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index fcd72a3..3554e40 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1175,7 +1175,7 @@ class TBuilder:
return result[0]
- class LSTM:
+ class LSTM(tf.Module):
def __init__(self, name):
self.result_name = name
self.lstm = tf.keras.layers.LSTM(
@@ -1191,6 +1191,23 @@ class TBuilder:
def eval(self, a):
return self.lstm(a)
+ class SLSTM(tf.Module):
+ def __init__(self, name):
+ self.result_name = name
+ self.lstm = tf.keras.layers.LSTM(
+ 2,
+ stateful=True,
+ activation="tanh",
+ unroll=False,
+ recurrent_activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ kernel_initializer="ones",
+ )
+
+ def eval(self, a):
+ return self.lstm(a)
+
class GRU:
def __init__(self, name):
self.result_name = name
@@ -1256,3 +1273,22 @@ class TBuilder:
def eval(self, a):
return tf.broadcast_to(a, shape=self.shape, name=self.result_name)
+
+ class CallOnce(tf.Module):
+ def __init__(self, name):
+ print(tf.__version__)
+ self.result_name = name
+ self.var = tf.Variable([1.0])
+
+ @tf.function(
+ input_signature=[
+ tf.TensorSpec(
+ shape=[
+ 1,
+ ],
+ dtype=tf.float32,
+ )
+ ]
+ )
+ def eval(self, a):
+ return self.var.assign([2.0])
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index ec009c6..ffe373b 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -28,6 +28,7 @@ from frameworks.test_gen_utils import ( # noqa: E402
get_tf_dtype,
get_shape_str,
) # noqa: E402
+
from tensorflow.lite.python.interpreter import OpResolverType # noqa: E402
# All of the supported frameworks
@@ -829,6 +830,15 @@ TF_OP_LIST = {
]
},
},
+ "lstm_stateful": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.SLSTM, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ ]
+ },
+ },
"gru": {
"operands": (1, 0),
"build_fcn": (TBuilder.GRU, TGen.tgRecurrent, ArgGen.agNone),
@@ -848,6 +858,17 @@ TF_OP_LIST = {
]
},
},
+ "callonce": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.CallOnce, TGen.tgBasic, ArgGen.agNone),
+ "types": {
+ "tflite": [tf.float32],
+ },
+ "custom_shapes": {
+ "custom_shape_only": True,
+ "shape_list": [(1,)],
+ },
+ },
"rfft2d": {
"operands": (1, 0),
"build_fcn": (TBuilder.RFFT2d, TGen.tgRFFT2d, ArgGen.agRFFT2d),
@@ -1219,9 +1240,15 @@ def run_unit_test(
if "tflite" not in excluded_framework_list:
# Convert the model to TFLite flatbuffer
module = tf.Module()
- converter = tf.lite.TFLiteConverter.from_concrete_functions(
- [concrete_function], module
- )
+
+ if op_name == "callonce" or op_name == "lstm_stateful":
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(
+ [concrete_function], fcn_node
+ )
+ else:
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(
+ [concrete_function], module
+ )
converter.experimental_new_converter = True