diff options
Diffstat (limited to 'verif/frameworks/tensor_gen.py')
-rw-r--r-- | verif/frameworks/tensor_gen.py | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py index f8d50a8..60e17ce 100644 --- a/verif/frameworks/tensor_gen.py +++ b/verif/frameworks/tensor_gen.py @@ -56,6 +56,10 @@ class TGen: return np.uint32(rng.integers(low=0, high=RAND_INT_MAX, size=shape)) if dtype == tf.bool: return np.bool_(rng.choice(a=[False, True], size=shape)) + if dtype == tf.complex64: + return TGen.getRand(shape, np.float32, rng) + 1j * TGen.getRand( + shape, np.float32, rng + ) raise Exception("Unsupported type: {}".format(dtype)) @@ -305,5 +309,13 @@ class TGen: if len(shape) != 3: return [], [] - tf_placeholders = [("placeholder_0", TGen.getRand(shape, dtype, rng))] - return tf_placeholders, [] + return TGen.tgBasic(op, shape, dtype, rng) + + @staticmethod + def tgComplexComponents(op, shape, dtype, rng): + # Temporarily require up to rank 3 shape, due to + # slice maximum rank limitiation. + if len(shape) > 3: + return [], [] + + return TGen.tgBasic(op, shape, dtype, rng) |