aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWon Jeon <won.jeon@arm.com>2023-06-10 15:25:54 +0000
committerWon Jeon <won.jeon@arm.com>2023-07-05 21:20:18 +0000
commite2325d12ba6eebeb59d50e7a7ce578a8a32a03ed (patch)
tree24bed603251afc3ecc8b0fd5e1793a6278dda39a
parentedac6abb626993c1005feb48db497064a7d102d0 (diff)
downloadreference_model-e2325d12ba6eebeb59d50e7a7ce578a8a32a03ed.tar.gz
Add a parameter to tensor generation function to disable fuzzing
Signed-off-by: Won Jeon <won.jeon@arm.com> Change-Id: Iff31b33b818a181371904915d5477a169513aa2e
-rw-r--r--verif/frameworks/tensor_gen.py6
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py31
2 files changed, 23 insertions, 14 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index 170e5d8..d50bc74 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -91,7 +91,7 @@ class TGen:
return tf_placeholders, tf_consts
@staticmethod
- def tgBFuzz(op, shape, dtype, rng):
+ def tgBFuzz(op, shape, dtype, rng, fuzzed=[]):
# Build random tensor placeholder node args of a given shape, optionally
# fuzzing the arguments with random 1's to force broadcasting
@@ -105,12 +105,14 @@ class TGen:
tf_placeholders = []
tf_consts = []
for i in range(pl):
- if i == fuzz_arg:
+ if not fuzzed and i == fuzz_arg:
# Insert the broadcast in one dimension index
s_fuzz = list(shape)
s_fuzz[fuzz_idx] = 1
s_fuzz = tuple(s_fuzz)
i_shape = s_fuzz
+ # Record the fuzzed index.
+ fuzzed.append(i)
else:
i_shape = shape
tf_placeholders.append(
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 02ab8aa..12fff68 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -22,7 +22,6 @@ import tensorflow as tf # noqa: E402
from frameworks.write_test_json import write_test_json # noqa: E402
from frameworks.arg_gen import ArgGen # noqa: E402
from frameworks.tensor_gen import TGen # noqa: E402
-from frameworks.tensor_gen import ElemSignedness # noqa: E402
from frameworks.test_builder import TBuilder # noqa: E402
from frameworks.test_gen_utils import ( # noqa: E402
QuantType,
@@ -958,9 +957,16 @@ def run_unit_test(
# Get and seed a random number generator for this test
rng = np.random.default_rng(seed)
+ # For broadcast fuzzing, record the fuzzed index if fuzzing is already done.
+ fuzzed = []
+
# return placeholders=(str: name, np.array: value)
# consts=(str: name, np.array: value)
- placeholders, consts = tensor_gen_fcn(op, curr_shape, dtype, rng)
+ placeholders, consts = (
+ tensor_gen_fcn(op, curr_shape, dtype, rng, fuzzed)
+ if tensor_gen_fcn.__name__ == "tgBFuzz"
+ else tensor_gen_fcn(op, curr_shape, dtype, rng)
+ )
# if test doesn't have any placeholders/consts, terminated
if len(placeholders) == 0 and len(consts) == 0:
@@ -1157,18 +1163,19 @@ def run_unit_test(
if tflite_inference_dtype == tf.int16:
converter.target_spec.supported_ops = [flag]
+ # Generator function for integer quantization of TFLiteConverter
+ # which generates a few hundred input samples with the same order, type, and shape as the inputs,
+ # to calibrate/estimate the range of the floating-point inputs.
+ # For broadcast fuzzing tests, fuzzing needs to be disabled, otherwise, it causes a mismatch of
+ # tensor shapes of inputs.
def input_stats():
- ## Rsqrt can only handle positive numbers
- elem_signedness = ElemSignedness.ALL_RANGE
- if op_name == "rsqrt":
- elem_signedness = ElemSignedness.POSITIVE
-
for i in range(0, args.num_samples):
- a = [
- TGen.getRand(shape, tf.float32, rng, elem_signedness)
- for shape in placeholder_shapes
- ]
- yield a
+ placeholders, _ = (
+ tensor_gen_fcn(op, placeholder_shapes[0], dtype, rng, fuzzed)
+ if tensor_gen_fcn == "tgBFuzz"
+ else tensor_gen_fcn(op, placeholder_shapes[0], dtype, rng)
+ )
+ yield [s[1] for s in placeholders]
converter.representative_dataset = input_stats
converter.inference_input_type = tflite_inference_dtype