aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2022-10-27 09:57:00 -0700
committerEric Kunze <eric.kunze@arm.com>2023-01-13 19:09:21 +0000
commit9e94af8f10f0a21a117b3bc7ea42004844fdc3bb (patch)
tree868ab73bb67d4827963a4b43f28d8a8a49f50307 /verif
parentdd8d9c251db0fece6453d86116052ad7f3e2d697 (diff)
downloadreference_model-9e94af8f10f0a21a117b3bc7ea42004844fdc3bb.tar.gz
Reference model update for control flow operators support
Rationale for making this change: - In the original design, for control flow operators like WhileOp, child blocks couldn't read the tensor variables (global consts) in the root level block, this patch added the machanism for child blocks to access their parent level block's tensors. - This change also relies on another serialization change on adding another layer of abtraction called Region: - Serialization patch: [region] Add TosaSerializationRegion to serialization_lib - Updated the corresponding python version of the serialization code: TosaSerializerRegion to python version of serialization_lib - This change also relies on the TOSA MLIR Translator change: Add RegionBuilder to TOSA MLIR Translator - Added the WhileOp related test cases: While, LSTM, GRU, RNN - Other related fixes Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I13ae33628ad07e41d248e88652ce1328654694ab
Diffstat (limited to 'verif')
-rw-r--r--verif/frameworks/tensor_gen.py24
-rw-r--r--verif/frameworks/test_builder.py81
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py36
-rw-r--r--verif/generator/tosa_test_gen.py20
4 files changed, 150 insertions, 11 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index 90bda34..767989e 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import tensorflow as tf
@@ -252,3 +252,25 @@ class TGen:
tf_placeholders.append(("placeholder_2", TGen.getRand(shape, dtype, rng)))
return tf_placeholders, tf_consts
+
+ @staticmethod
+ def tgRecurrent(op, ifm_shape, dtype, rng):
+ # Require rank 3 shape for recurrent networks
+ if len(ifm_shape) != 3:
+ return [], []
+ pl, const = op["operands"]
+
+ tf_placeholders = []
+ tf_consts = []
+
+ for i in range(pl):
+ tf_placeholders.append(
+ ("placeholder_{}".format(i), TGen.getRand(ifm_shape, dtype, rng))
+ )
+
+ for i in range(const):
+ tf_consts.append(
+ ("const_{}".format(i), TGen.getRand(ifm_shape, dtype, rng))
+ )
+
+ return tf_placeholders, tf_consts
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index c7c5cd7..8870f41 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import tensorflow as tf
@@ -1164,3 +1164,82 @@ class TBuilder:
def eval(self, a):
return tf.bitwise.right_shift(a, self.shift, name=self.result_name)
+
+ class While:
+ def __init__(self, name):
+ self.result_name = name
+
+ def while_cond(self, x):
+ return tf.reduce_sum(x) < self.cap
+
+ def while_body(self, x):
+ return tf.add(x, tf.math.sigmoid(x))
+
+ def eval(self, a):
+ self.cap = tf.cast(
+ tf.constant(
+ 2.0,
+ shape=[
+ 1,
+ ],
+ ),
+ a.dtype,
+ )
+
+ result = tf.while_loop(
+ self.while_cond, self.while_body, [a], name=self.result_name
+ )
+
+ return result[0]
+
+ class LSTM:
+ def __init__(self, name):
+ self.result_name = name
+ self.lstm = tf.keras.layers.LSTM(
+ 2,
+ activation="tanh",
+ unroll=False,
+ recurrent_activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ kernel_initializer="ones",
+ )
+
+ def eval(self, a):
+ return self.lstm(a)
+
+ class GRU:
+ def __init__(self, name):
+ self.result_name = name
+ self.lstm = tf.keras.layers.GRU(
+ 2,
+ recurrent_activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ kernel_initializer="ones",
+ )
+
+ def eval(self, a):
+ return self.lstm(a)
+
+ class RNN:
+ def __init__(self, name):
+ self.result_name = name
+ basic_cell = tf.keras.layers.SimpleRNNCell(
+ units=2,
+ activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ )
+ self.rnn = tf.keras.layers.RNN(basic_cell, unroll=False)
+
+ def eval(self, a):
+ return self.rnn(a)
+
+ class FullyConnected:
+ def __init__(self, name):
+ self.result_name = name
+ self.dense = tf.keras.layers.Dense(2)
+
+ def eval(self, a):
+ return self.dense(a)
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 760def6..26af5dd 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -807,6 +807,42 @@ TF_OP_LIST = {
]
},
},
+ "while": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.While, TGen.tgBasic, ArgGen.agNone),
+ "types": {
+ "tflite": list(TYPE_F),
+ },
+ },
+ "lstm": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.LSTM, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ # tf.int32
+ ]
+ },
+ },
+ "gru": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.GRU, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ # tf.int32
+ ]
+ },
+ },
+ "rnn": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.RNN, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ ]
+ },
+ },
}
# Shapes to be tested; default can be overwritten
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 515e8bb..d799eb0 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import os
from copy import deepcopy
@@ -1845,7 +1845,7 @@ class TosaTestGen:
# Finally, build the op and the two blocks
self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)
- self.ser.startBasicBlock(then_block)
+ self.ser.addBasicBlock(then_block)
# Build the actual then/else tensors inside their blocks
if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
@@ -1853,7 +1853,7 @@ class TosaTestGen:
then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
self.ser.addOutputTensor(then_tens)
- self.ser.startBasicBlock(else_block)
+ self.ser.addBasicBlock(else_block)
if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
else:
@@ -1865,7 +1865,7 @@ class TosaTestGen:
validator_fcns,
error_name,
op=op,
- basicBlocks=self.ser.basicBlocks,
+ basicBlocks=self.ser.currRegion.basicBlocks,
cond=cond_tens,
):
return None
@@ -1914,7 +1914,7 @@ class TosaTestGen:
assert False, f"No tests for DType: {a.dtype}"
for block, op in ((then_block, then_op), (else_block, else_op)):
- self.ser.startBasicBlock(block)
+ self.ser.addBasicBlock(block)
if (
error_name == ErrorIf.CondIfInputListThenGraphMismatch
and block == then_block
@@ -1948,7 +1948,7 @@ class TosaTestGen:
op=op,
a=a,
b=b,
- basicBlocks=self.ser.basicBlocks,
+ basicBlocks=self.ser.currRegion.basicBlocks,
cond=cond_tens,
):
return None
@@ -2005,7 +2005,8 @@ class TosaTestGen:
incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
# COND block (input: iter, output: cond_tens )
- self.ser.startBasicBlock(cond_block)
+ self.ser.addBasicBlock(cond_block)
+
if error_name == ErrorIf.InputListCondGraphMismatch:
self.ser.addInputTensor(incorrect_iter)
self.ser.addInputTensor(a)
@@ -2034,7 +2035,8 @@ class TosaTestGen:
# BODY block (input: a, acc, iter, output: a, acc, iter)
# Note that local intermediate tensors need to be declared here for the outputs
- self.ser.startBasicBlock(body_block)
+ self.ser.addBasicBlock(body_block)
+
if error_name == ErrorIf.InputListBodyGraphInputMismatch:
self.ser.addInputTensor(incorrect_iter)
self.ser.addInputTensor(a)
@@ -2068,7 +2070,7 @@ class TosaTestGen:
validator_fcns,
error_name,
op=op,
- basicBlocks=self.ser.basicBlocks,
+ basicBlocks=self.ser.currRegion.basicBlocks,
):
return None