diff options
Diffstat (limited to 'verif/frameworks/tensor_gen.py')
-rw-r--r-- | verif/frameworks/tensor_gen.py | 24 |
1 files changed, 23 insertions, 1 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py index 90bda34..767989e 100644 --- a/verif/frameworks/tensor_gen.py +++ b/verif/frameworks/tensor_gen.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import numpy as np import tensorflow as tf @@ -252,3 +252,25 @@ class TGen: tf_placeholders.append(("placeholder_2", TGen.getRand(shape, dtype, rng))) return tf_placeholders, tf_consts + + @staticmethod + def tgRecurrent(op, ifm_shape, dtype, rng): + # Require rank 3 shape for recurrent networks + if len(ifm_shape) != 3: + return [], [] + pl, const = op["operands"] + + tf_placeholders = [] + tf_consts = [] + + for i in range(pl): + tf_placeholders.append( + ("placeholder_{}".format(i), TGen.getRand(ifm_shape, dtype, rng)) + ) + + for i in range(const): + tf_consts.append( + ("const_{}".format(i), TGen.getRand(ifm_shape, dtype, rng)) + ) + + return tf_placeholders, tf_consts |