aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/tensor_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/tensor_gen.py')
-rw-r--r--verif/frameworks/tensor_gen.py16
1 files changed, 14 insertions, 2 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index f8d50a8..60e17ce 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -56,6 +56,10 @@ class TGen:
return np.uint32(rng.integers(low=0, high=RAND_INT_MAX, size=shape))
if dtype == tf.bool:
return np.bool_(rng.choice(a=[False, True], size=shape))
+ if dtype == tf.complex64:
+ return TGen.getRand(shape, np.float32, rng) + 1j * TGen.getRand(
+ shape, np.float32, rng
+ )
raise Exception("Unsupported type: {}".format(dtype))
@@ -305,5 +309,13 @@ class TGen:
if len(shape) != 3:
return [], []
- tf_placeholders = [("placeholder_0", TGen.getRand(shape, dtype, rng))]
- return tf_placeholders, []
+ return TGen.tgBasic(op, shape, dtype, rng)
+
+ @staticmethod
+ def tgComplexComponents(op, shape, dtype, rng):
+ # Temporarily require up to rank 3 shape, due to
+ # slice maximum rank limitiation.
+ if len(shape) > 3:
+ return [], []
+
+ return TGen.tgBasic(op, shape, dtype, rng)