aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/tensor_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/tensor_gen.py')
-rw-r--r--verif/frameworks/tensor_gen.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index 170e5d8..d50bc74 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -91,7 +91,7 @@ class TGen:
return tf_placeholders, tf_consts
@staticmethod
- def tgBFuzz(op, shape, dtype, rng):
+ def tgBFuzz(op, shape, dtype, rng, fuzzed=[]):
# Build random tensor placeholder node args of a given shape, optionally
# fuzzing the arguments with random 1's to force broadcasting
@@ -105,12 +105,14 @@ class TGen:
tf_placeholders = []
tf_consts = []
for i in range(pl):
- if i == fuzz_arg:
+ if not fuzzed and i == fuzz_arg:
# Insert the broadcast in one dimension index
s_fuzz = list(shape)
s_fuzz[fuzz_idx] = 1
s_fuzz = tuple(s_fuzz)
i_shape = s_fuzz
+ # Record the fuzzed index.
+ fuzzed.append(i)
else:
i_shape = shape
tf_placeholders.append(