diff options
Diffstat (limited to 'verif/frameworks/tensor_gen.py')
-rw-r--r-- | verif/frameworks/tensor_gen.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py index 170e5d8..d50bc74 100644 --- a/verif/frameworks/tensor_gen.py +++ b/verif/frameworks/tensor_gen.py @@ -91,7 +91,7 @@ class TGen: return tf_placeholders, tf_consts @staticmethod - def tgBFuzz(op, shape, dtype, rng): + def tgBFuzz(op, shape, dtype, rng, fuzzed=[]): # Build random tensor placeholder node args of a given shape, optionally # fuzzing the arguments with random 1's to force broadcasting @@ -105,12 +105,14 @@ class TGen: tf_placeholders = [] tf_consts = [] for i in range(pl): - if i == fuzz_arg: + if not fuzzed and i == fuzz_arg: # Insert the broadcast in one dimension index s_fuzz = list(shape) s_fuzz[fuzz_idx] = 1 s_fuzz = tuple(s_fuzz) i_shape = s_fuzz + # Record the fuzzed index. + fuzzed.append(i) else: i_shape = shape tf_placeholders.append( |