aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/tensor_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/tensor_gen.py')
-rw-r--r--verif/frameworks/tensor_gen.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index d50bc74..d0c0a0b 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -91,7 +91,7 @@ class TGen:
return tf_placeholders, tf_consts
@staticmethod
- def tgBFuzz(op, shape, dtype, rng, fuzzed=[]):
+ def tgBFuzz(op, shape, dtype, rng, for_tflite_converter=True):
# Build random tensor placeholder node args of a given shape, optionally
# fuzzing the arguments with random 1's to force broadcasting
@@ -99,22 +99,23 @@ class TGen:
assert const == 0
- fuzz_arg = rng.integers(0, pl + const)
- fuzz_idx = rng.integers(0, len(shape))
+ if not for_tflite_converter:
+ fuzz_arg = rng.integers(0, pl + const)
+ fuzz_idx = rng.integers(0, len(shape))
tf_placeholders = []
tf_consts = []
+
for i in range(pl):
- if not fuzzed and i == fuzz_arg:
+ if not for_tflite_converter and i == fuzz_arg:
# Insert the broadcast in one dimension index
s_fuzz = list(shape)
s_fuzz[fuzz_idx] = 1
s_fuzz = tuple(s_fuzz)
i_shape = s_fuzz
- # Record the fuzzed index.
- fuzzed.append(i)
else:
i_shape = shape
+
tf_placeholders.append(
("placeholder_{}".format(i), TGen.getRand(i_shape, dtype, rng))
)