diff options
Diffstat (limited to 'verif/frameworks/tensor_gen.py')
-rw-r--r-- | verif/frameworks/tensor_gen.py | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py index d50bc74..d0c0a0b 100644 --- a/verif/frameworks/tensor_gen.py +++ b/verif/frameworks/tensor_gen.py @@ -91,7 +91,7 @@ class TGen: return tf_placeholders, tf_consts @staticmethod - def tgBFuzz(op, shape, dtype, rng, fuzzed=[]): + def tgBFuzz(op, shape, dtype, rng, for_tflite_converter=True): # Build random tensor placeholder node args of a given shape, optionally # fuzzing the arguments with random 1's to force broadcasting @@ -99,22 +99,23 @@ class TGen: assert const == 0 - fuzz_arg = rng.integers(0, pl + const) - fuzz_idx = rng.integers(0, len(shape)) + if not for_tflite_converter: + fuzz_arg = rng.integers(0, pl + const) + fuzz_idx = rng.integers(0, len(shape)) tf_placeholders = [] tf_consts = [] + for i in range(pl): - if not fuzzed and i == fuzz_arg: + if not for_tflite_converter and i == fuzz_arg: # Insert the broadcast in one dimension index s_fuzz = list(shape) s_fuzz[fuzz_idx] = 1 s_fuzz = tuple(s_fuzz) i_shape = s_fuzz - # Record the fuzzed index. - fuzzed.append(i) else: i_shape = shape + tf_placeholders.append( ("placeholder_{}".format(i), TGen.getRand(i_shape, dtype, rng)) ) |