diff options
Diffstat (limited to 'verif/frameworks')
-rw-r--r-- | verif/frameworks/tensor_gen.py | 24 | ||||
-rw-r--r-- | verif/frameworks/test_builder.py | 81 | ||||
-rwxr-xr-x | verif/frameworks/tosa_verif_framework_generator.py | 36 |
3 files changed, 139 insertions, 2 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py index 90bda34..767989e 100644 --- a/verif/frameworks/tensor_gen.py +++ b/verif/frameworks/tensor_gen.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import numpy as np import tensorflow as tf @@ -252,3 +252,25 @@ class TGen: tf_placeholders.append(("placeholder_2", TGen.getRand(shape, dtype, rng))) return tf_placeholders, tf_consts + + @staticmethod + def tgRecurrent(op, ifm_shape, dtype, rng): + # Require rank 3 shape for recurrent networks + if len(ifm_shape) != 3: + return [], [] + pl, const = op["operands"] + + tf_placeholders = [] + tf_consts = [] + + for i in range(pl): + tf_placeholders.append( + ("placeholder_{}".format(i), TGen.getRand(ifm_shape, dtype, rng)) + ) + + for i in range(const): + tf_consts.append( + ("const_{}".format(i), TGen.getRand(ifm_shape, dtype, rng)) + ) + + return tf_placeholders, tf_consts diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index c7c5cd7..8870f41 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import numpy as np import tensorflow as tf @@ -1164,3 +1164,82 @@ class TBuilder: def eval(self, a): return tf.bitwise.right_shift(a, self.shift, name=self.result_name) + + class While: + def __init__(self, name): + self.result_name = name + + def while_cond(self, x): + return tf.reduce_sum(x) < self.cap + + def while_body(self, x): + return tf.add(x, tf.math.sigmoid(x)) + + def eval(self, a): + self.cap = tf.cast( + tf.constant( + 2.0, + shape=[ + 1, + ], + ), + a.dtype, + ) + + result = tf.while_loop( + self.while_cond, self.while_body, [a], name=self.result_name + ) + + return result[0] + + class LSTM: + def __init__(self, name): + self.result_name = name + self.lstm = tf.keras.layers.LSTM( + 2, + activation="tanh", + unroll=False, + recurrent_activation="sigmoid", + use_bias=True, + recurrent_initializer="ones", + kernel_initializer="ones", + ) + + def eval(self, a): + return self.lstm(a) + + class GRU: + def __init__(self, name): + self.result_name = name + self.lstm = tf.keras.layers.GRU( + 2, + recurrent_activation="sigmoid", + use_bias=True, + recurrent_initializer="ones", + kernel_initializer="ones", + ) + + def eval(self, a): + return self.lstm(a) + + class RNN: + def __init__(self, name): + self.result_name = name + basic_cell = tf.keras.layers.SimpleRNNCell( + units=2, + activation="sigmoid", + use_bias=True, + recurrent_initializer="ones", + ) + self.rnn = tf.keras.layers.RNN(basic_cell, unroll=False) + + def eval(self, a): + return self.rnn(a) + + class FullyConnected: + def __init__(self, name): + self.result_name = name + self.dense = tf.keras.layers.Dense(2) + + def eval(self, a): + return self.dense(a) diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index 760def6..26af5dd 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -807,6 +807,42 @@ TF_OP_LIST = { ] }, }, + "while": { + "operands": (1, 0), + "build_fcn": (TBuilder.While, TGen.tgBasic, ArgGen.agNone), + "types": { + "tflite": list(TYPE_F), + }, + }, + "lstm": { + "operands": (1, 0), + "build_fcn": (TBuilder.LSTM, TGen.tgRecurrent, ArgGen.agNone), + "types": { + "tflite": [ + tf.float32, + # tf.int32 + ] + }, + }, + "gru": { + "operands": (1, 0), + "build_fcn": (TBuilder.GRU, TGen.tgRecurrent, ArgGen.agNone), + "types": { + "tflite": [ + tf.float32, + # tf.int32 + ] + }, + }, + "rnn": { + "operands": (1, 0), + "build_fcn": (TBuilder.RNN, TGen.tgRecurrent, ArgGen.agNone), + "types": { + "tflite": [ + tf.float32, + ] + }, + }, } # Shapes to be tested; default can be overwritten |