aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/tosa_verif_framework_generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/tosa_verif_framework_generator.py')
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 93bdfe0..0741686 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -22,6 +22,7 @@ import tensorflow as tf # noqa: E402
from frameworks.write_test_json import write_test_json # noqa: E402
from frameworks.arg_gen import ArgGen # noqa: E402
from frameworks.tensor_gen import TGen # noqa: E402
+from frameworks.tensor_gen import ElemSignedness # noqa: E402
from frameworks.test_builder import TBuilder # noqa: E402
from frameworks.test_gen_utils import ( # noqa: E402
QuantType,
@@ -303,8 +304,11 @@ TF_OP_LIST = {
},
"rsqrt": {
"operands": (1, 0),
- "build_fcn": (TBuilder.Rsqrt, TGen.tgBasic, ArgGen.agNone),
- "types": TYPE_F,
+ "build_fcn": (TBuilder.Rsqrt, TGen.tgBasicPositive, ArgGen.agNone),
+ "types": {
+ "tf": TYPE_F,
+ "tflite": list(TYPE_F + [QuantType.ALL_I8]),
+ },
},
"sign": {
"operands": (1, 0),
@@ -1121,9 +1125,14 @@ def run_unit_test(
converter.target_spec.supported_ops = [flag]
def input_stats():
+ ## Rsqrt can only handle positive numbers
+ elem_signedness = ElemSignedness.ALL_RANGE
+ if op_name == "rsqrt":
+ elem_signedness = ElemSignedness.POSITIVE
+
for i in range(0, args.num_samples):
a = [
- TGen.getRand(shape, tf.float32, rng)
+ TGen.getRand(shape, tf.float32, rng, elem_signedness)
for shape in placeholder_shapes
]
yield a