aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/tensor_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/tensor_gen.py')
-rw-r--r--verif/frameworks/tensor_gen.py24
1 files changed, 23 insertions, 1 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index 90bda34..767989e 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import tensorflow as tf
@@ -252,3 +252,25 @@ class TGen:
tf_placeholders.append(("placeholder_2", TGen.getRand(shape, dtype, rng)))
return tf_placeholders, tf_consts
+
+ @staticmethod
+ def tgRecurrent(op, ifm_shape, dtype, rng):
+ # Require rank 3 shape for recurrent networks
+ if len(ifm_shape) != 3:
+ return [], []
+ pl, const = op["operands"]
+
+ tf_placeholders = []
+ tf_consts = []
+
+ for i in range(pl):
+ tf_placeholders.append(
+ ("placeholder_{}".format(i), TGen.getRand(ifm_shape, dtype, rng))
+ )
+
+ for i in range(const):
+ tf_consts.append(
+ ("const_{}".format(i), TGen.getRand(ifm_shape, dtype, rng))
+ )
+
+ return tf_placeholders, tf_consts