diff options
author | Jerry Ge <jerry.ge@arm.com> | 2023-03-03 11:33:51 -0800 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-03-30 15:02:26 +0000 |
commit | b1f25015d4be6c9b8cd399d7e14fea98cd2f01f5 (patch) | |
tree | a47887b4ec78783c280663d82fbe0dc67093c619 /verif/frameworks/tosa_verif_framework_generator.py | |
parent | 4ca8f64c601137095cb1780f1b86bc305f4db0bc (diff) | |
download | reference_model-b1f25015d4be6c9b8cd399d7e14fea98cd2f01f5.tar.gz |
Add positive/negative random number generator for Rsqrt
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: I1e9e97ead447295e1252785106931b261df7bcea
Diffstat (limited to 'verif/frameworks/tosa_verif_framework_generator.py')
-rwxr-xr-x | verif/frameworks/tosa_verif_framework_generator.py | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index 93bdfe0..0741686 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -22,6 +22,7 @@ import tensorflow as tf # noqa: E402 from frameworks.write_test_json import write_test_json # noqa: E402 from frameworks.arg_gen import ArgGen # noqa: E402 from frameworks.tensor_gen import TGen # noqa: E402 +from frameworks.tensor_gen import ElemSignedness # noqa: E402 from frameworks.test_builder import TBuilder # noqa: E402 from frameworks.test_gen_utils import ( # noqa: E402 QuantType, @@ -303,8 +304,11 @@ TF_OP_LIST = { }, "rsqrt": { "operands": (1, 0), - "build_fcn": (TBuilder.Rsqrt, TGen.tgBasic, ArgGen.agNone), - "types": TYPE_F, + "build_fcn": (TBuilder.Rsqrt, TGen.tgBasicPositive, ArgGen.agNone), + "types": { + "tf": TYPE_F, + "tflite": list(TYPE_F + [QuantType.ALL_I8]), + }, }, "sign": { "operands": (1, 0), @@ -1121,9 +1125,14 @@ def run_unit_test( converter.target_spec.supported_ops = [flag] def input_stats(): + ## Rsqrt can only handle positive numbers + elem_signedness = ElemSignedness.ALL_RANGE + if op_name == "rsqrt": + elem_signedness = ElemSignedness.POSITIVE + for i in range(0, args.num_samples): a = [ - TGen.getRand(shape, tf.float32, rng) + TGen.getRand(shape, tf.float32, rng, elem_signedness) for shape in placeholder_shapes ] yield a |