aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks')
-rw-r--r--verif/frameworks/tensor_gen.py24
-rw-r--r--verif/frameworks/test_builder.py81
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py36
3 files changed, 139 insertions, 2 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index 90bda34..767989e 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import tensorflow as tf
@@ -252,3 +252,25 @@ class TGen:
tf_placeholders.append(("placeholder_2", TGen.getRand(shape, dtype, rng)))
return tf_placeholders, tf_consts
+
+ @staticmethod
+ def tgRecurrent(op, ifm_shape, dtype, rng):
+ # Require rank 3 shape for recurrent networks
+ if len(ifm_shape) != 3:
+ return [], []
+ pl, const = op["operands"]
+
+ tf_placeholders = []
+ tf_consts = []
+
+ for i in range(pl):
+ tf_placeholders.append(
+ ("placeholder_{}".format(i), TGen.getRand(ifm_shape, dtype, rng))
+ )
+
+ for i in range(const):
+ tf_consts.append(
+ ("const_{}".format(i), TGen.getRand(ifm_shape, dtype, rng))
+ )
+
+ return tf_placeholders, tf_consts
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index c7c5cd7..8870f41 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import tensorflow as tf
@@ -1164,3 +1164,82 @@ class TBuilder:
def eval(self, a):
return tf.bitwise.right_shift(a, self.shift, name=self.result_name)
+
+ class While:
+ def __init__(self, name):
+ self.result_name = name
+
+ def while_cond(self, x):
+ return tf.reduce_sum(x) < self.cap
+
+ def while_body(self, x):
+ return tf.add(x, tf.math.sigmoid(x))
+
+ def eval(self, a):
+ self.cap = tf.cast(
+ tf.constant(
+ 2.0,
+ shape=[
+ 1,
+ ],
+ ),
+ a.dtype,
+ )
+
+ result = tf.while_loop(
+ self.while_cond, self.while_body, [a], name=self.result_name
+ )
+
+ return result[0]
+
+ class LSTM:
+ def __init__(self, name):
+ self.result_name = name
+ self.lstm = tf.keras.layers.LSTM(
+ 2,
+ activation="tanh",
+ unroll=False,
+ recurrent_activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ kernel_initializer="ones",
+ )
+
+ def eval(self, a):
+ return self.lstm(a)
+
+ class GRU:
+ def __init__(self, name):
+ self.result_name = name
+ self.lstm = tf.keras.layers.GRU(
+ 2,
+ recurrent_activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ kernel_initializer="ones",
+ )
+
+ def eval(self, a):
+ return self.lstm(a)
+
+ class RNN:
+ def __init__(self, name):
+ self.result_name = name
+ basic_cell = tf.keras.layers.SimpleRNNCell(
+ units=2,
+ activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ )
+ self.rnn = tf.keras.layers.RNN(basic_cell, unroll=False)
+
+ def eval(self, a):
+ return self.rnn(a)
+
+ class FullyConnected:
+ def __init__(self, name):
+ self.result_name = name
+ self.dense = tf.keras.layers.Dense(2)
+
+ def eval(self, a):
+ return self.dense(a)
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 760def6..26af5dd 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -807,6 +807,42 @@ TF_OP_LIST = {
]
},
},
+ "while": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.While, TGen.tgBasic, ArgGen.agNone),
+ "types": {
+ "tflite": list(TYPE_F),
+ },
+ },
+ "lstm": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.LSTM, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ # tf.int32
+ ]
+ },
+ },
+ "gru": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.GRU, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ # tf.int32
+ ]
+ },
+ },
+ "rnn": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.RNN, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ ]
+ },
+ },
}
# Shapes to be tested; default can be overwritten