aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
Diffstat (limited to 'verif')
-rw-r--r--verif/frameworks/tensor_gen.py24
-rw-r--r--verif/frameworks/test_builder.py81
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py36
-rw-r--r--verif/generator/tosa_test_gen.py20
4 files changed, 150 insertions, 11 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index 90bda34..767989e 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import tensorflow as tf
@@ -252,3 +252,25 @@ class TGen:
tf_placeholders.append(("placeholder_2", TGen.getRand(shape, dtype, rng)))
return tf_placeholders, tf_consts
+
+ @staticmethod
+ def tgRecurrent(op, ifm_shape, dtype, rng):
+ # Require rank 3 shape for recurrent networks
+ if len(ifm_shape) != 3:
+ return [], []
+ pl, const = op["operands"]
+
+ tf_placeholders = []
+ tf_consts = []
+
+ for i in range(pl):
+ tf_placeholders.append(
+ ("placeholder_{}".format(i), TGen.getRand(ifm_shape, dtype, rng))
+ )
+
+ for i in range(const):
+ tf_consts.append(
+ ("const_{}".format(i), TGen.getRand(ifm_shape, dtype, rng))
+ )
+
+ return tf_placeholders, tf_consts
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index c7c5cd7..8870f41 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import tensorflow as tf
@@ -1164,3 +1164,82 @@ class TBuilder:
def eval(self, a):
return tf.bitwise.right_shift(a, self.shift, name=self.result_name)
+
+ class While:
+ def __init__(self, name):
+ self.result_name = name
+
+ def while_cond(self, x):
+ return tf.reduce_sum(x) < self.cap
+
+ def while_body(self, x):
+ return tf.add(x, tf.math.sigmoid(x))
+
+ def eval(self, a):
+ self.cap = tf.cast(
+ tf.constant(
+ 2.0,
+ shape=[
+ 1,
+ ],
+ ),
+ a.dtype,
+ )
+
+ result = tf.while_loop(
+ self.while_cond, self.while_body, [a], name=self.result_name
+ )
+
+ return result[0]
+
+ class LSTM:
+ def __init__(self, name):
+ self.result_name = name
+ self.lstm = tf.keras.layers.LSTM(
+ 2,
+ activation="tanh",
+ unroll=False,
+ recurrent_activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ kernel_initializer="ones",
+ )
+
+ def eval(self, a):
+ return self.lstm(a)
+
+ class GRU:
+ def __init__(self, name):
+ self.result_name = name
+ self.lstm = tf.keras.layers.GRU(
+ 2,
+ recurrent_activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ kernel_initializer="ones",
+ )
+
+ def eval(self, a):
+ return self.lstm(a)
+
+ class RNN:
+ def __init__(self, name):
+ self.result_name = name
+ basic_cell = tf.keras.layers.SimpleRNNCell(
+ units=2,
+ activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ )
+ self.rnn = tf.keras.layers.RNN(basic_cell, unroll=False)
+
+ def eval(self, a):
+ return self.rnn(a)
+
+ class FullyConnected:
+ def __init__(self, name):
+ self.result_name = name
+ self.dense = tf.keras.layers.Dense(2)
+
+ def eval(self, a):
+ return self.dense(a)
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 760def6..26af5dd 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -807,6 +807,42 @@ TF_OP_LIST = {
]
},
},
+ "while": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.While, TGen.tgBasic, ArgGen.agNone),
+ "types": {
+ "tflite": list(TYPE_F),
+ },
+ },
+ "lstm": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.LSTM, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ # tf.int32
+ ]
+ },
+ },
+ "gru": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.GRU, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ # tf.int32
+ ]
+ },
+ },
+ "rnn": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.RNN, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ ]
+ },
+ },
}
# Shapes to be tested; default can be overwritten
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 515e8bb..d799eb0 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import os
from copy import deepcopy
@@ -1845,7 +1845,7 @@ class TosaTestGen:
# Finally, build the op and the two blocks
self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)
- self.ser.startBasicBlock(then_block)
+ self.ser.addBasicBlock(then_block)
# Build the actual then/else tensors inside their blocks
if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
@@ -1853,7 +1853,7 @@ class TosaTestGen:
then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
self.ser.addOutputTensor(then_tens)
- self.ser.startBasicBlock(else_block)
+ self.ser.addBasicBlock(else_block)
if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
else:
@@ -1865,7 +1865,7 @@ class TosaTestGen:
validator_fcns,
error_name,
op=op,
- basicBlocks=self.ser.basicBlocks,
+ basicBlocks=self.ser.currRegion.basicBlocks,
cond=cond_tens,
):
return None
@@ -1914,7 +1914,7 @@ class TosaTestGen:
assert False, f"No tests for DType: {a.dtype}"
for block, op in ((then_block, then_op), (else_block, else_op)):
- self.ser.startBasicBlock(block)
+ self.ser.addBasicBlock(block)
if (
error_name == ErrorIf.CondIfInputListThenGraphMismatch
and block == then_block
@@ -1948,7 +1948,7 @@ class TosaTestGen:
op=op,
a=a,
b=b,
- basicBlocks=self.ser.basicBlocks,
+ basicBlocks=self.ser.currRegion.basicBlocks,
cond=cond_tens,
):
return None
@@ -2005,7 +2005,8 @@ class TosaTestGen:
incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
# COND block (input: iter, output: cond_tens )
- self.ser.startBasicBlock(cond_block)
+ self.ser.addBasicBlock(cond_block)
+
if error_name == ErrorIf.InputListCondGraphMismatch:
self.ser.addInputTensor(incorrect_iter)
self.ser.addInputTensor(a)
@@ -2034,7 +2035,8 @@ class TosaTestGen:
# BODY block (input: a, acc, iter, output: a, acc, iter)
# Note that local intermediate tensors need to be declared here for the outputs
- self.ser.startBasicBlock(body_block)
+ self.ser.addBasicBlock(body_block)
+
if error_name == ErrorIf.InputListBodyGraphInputMismatch:
self.ser.addInputTensor(incorrect_iter)
self.ser.addInputTensor(a)
@@ -2068,7 +2070,7 @@ class TosaTestGen:
validator_fcns,
error_name,
op=op,
- basicBlocks=self.ser.basicBlocks,
+ basicBlocks=self.ser.currRegion.basicBlocks,
):
return None