diff options
Diffstat (limited to 'verif/frameworks/tosa_verif_framework_generator.py')
-rwxr-xr-x | verif/frameworks/tosa_verif_framework_generator.py | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index 93bdfe0..0741686 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -22,6 +22,7 @@ import tensorflow as tf # noqa: E402 from frameworks.write_test_json import write_test_json # noqa: E402 from frameworks.arg_gen import ArgGen # noqa: E402 from frameworks.tensor_gen import TGen # noqa: E402 +from frameworks.tensor_gen import ElemSignedness # noqa: E402 from frameworks.test_builder import TBuilder # noqa: E402 from frameworks.test_gen_utils import ( # noqa: E402 QuantType, @@ -303,8 +304,11 @@ TF_OP_LIST = { }, "rsqrt": { "operands": (1, 0), - "build_fcn": (TBuilder.Rsqrt, TGen.tgBasic, ArgGen.agNone), - "types": TYPE_F, + "build_fcn": (TBuilder.Rsqrt, TGen.tgBasicPositive, ArgGen.agNone), + "types": { + "tf": TYPE_F, + "tflite": list(TYPE_F + [QuantType.ALL_I8]), + }, }, "sign": { "operands": (1, 0), @@ -1121,9 +1125,14 @@ def run_unit_test( converter.target_spec.supported_ops = [flag] def input_stats(): + ## Rsqrt can only handle positive numbers + elem_signedness = ElemSignedness.ALL_RANGE + if op_name == "rsqrt": + elem_signedness = ElemSignedness.POSITIVE + for i in range(0, args.num_samples): a = [ - TGen.getRand(shape, tf.float32, rng) + TGen.getRand(shape, tf.float32, rng, elem_signedness) for shape in placeholder_shapes ] yield a |