aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-03-03 11:33:51 -0800
committerEric Kunze <eric.kunze@arm.com>2023-03-30 15:02:26 +0000
commitb1f25015d4be6c9b8cd399d7e14fea98cd2f01f5 (patch)
treea47887b4ec78783c280663d82fbe0dc67093c619
parent4ca8f64c601137095cb1780f1b86bc305f4db0bc (diff)
downloadreference_model-b1f25015d4be6c9b8cd399d7e14fea98cd2f01f5.tar.gz
Add positive/negative random number generator for Rsqrt
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I1e9e97ead447295e1252785106931b261df7bcea
-rw-r--r--verif/frameworks/tensor_gen.py32
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py15
2 files changed, 40 insertions, 7 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index c534a58..f8d50a8 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -1,5 +1,7 @@
# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
+import enum
+
import numpy as np
import tensorflow as tf
@@ -17,6 +19,12 @@ RAND_INT_MIN = -128
RAND_INT_MAX = 128
+class ElemSignedness(enum.Enum):
+ ALL_RANGE = 1
+ POSITIVE = 2
+ NEGATIVE = 3
+
+
class TGen:
"""A collection of functions to build tensor value arguments for an operator"""
@@ -24,7 +32,14 @@ class TGen:
pass
@staticmethod
- def getRand(shape, dtype, rng):
+ def getRand(shape, dtype, rng, elem_signedness=ElemSignedness.ALL_RANGE):
+ if elem_signedness == ElemSignedness.POSITIVE:
+ RAND_SHIFT_FACTOR = 0
+ elif elem_signedness == ElemSignedness.NEGATIVE:
+ RAND_SHIFT_FACTOR = 1
+ else:
+ RAND_SHIFT_FACTOR = 0.5
+
if dtype == tf.float32:
return np.float32(
(rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR
@@ -45,7 +60,11 @@ class TGen:
raise Exception("Unsupported type: {}".format(dtype))
@staticmethod
- def tgBasic(op, shape, dtype, rng):
+ def tgBasicPositive(op, shape, dtype, rng, elem_signedness=ElemSignedness.POSITIVE):
+ return TGen.tgBasic(op, shape, dtype, rng, elem_signedness)
+
+ @staticmethod
+ def tgBasic(op, shape, dtype, rng, elem_signedness=ElemSignedness.ALL_RANGE):
# Build random tensor placeholder node args of a given shape
pl, const = op["operands"]
@@ -54,11 +73,16 @@ class TGen:
for i in range(pl):
tf_placeholders.append(
- ("placeholder_{}".format(i), TGen.getRand(shape, dtype, rng))
+ (
+ "placeholder_{}".format(i),
+ TGen.getRand(shape, dtype, rng, elem_signedness),
+ )
)
for i in range(const):
- tf_consts.append(("const_{}".format(i), TGen.getRand(shape, dtype, rng)))
+ tf_consts.append(
+ ("const_{}".format(i), TGen.getRand(shape, dtype, rng, elem_signedness))
+ )
return tf_placeholders, tf_consts
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 93bdfe0..0741686 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -22,6 +22,7 @@ import tensorflow as tf # noqa: E402
from frameworks.write_test_json import write_test_json # noqa: E402
from frameworks.arg_gen import ArgGen # noqa: E402
from frameworks.tensor_gen import TGen # noqa: E402
+from frameworks.tensor_gen import ElemSignedness # noqa: E402
from frameworks.test_builder import TBuilder # noqa: E402
from frameworks.test_gen_utils import ( # noqa: E402
QuantType,
@@ -303,8 +304,11 @@ TF_OP_LIST = {
},
"rsqrt": {
"operands": (1, 0),
- "build_fcn": (TBuilder.Rsqrt, TGen.tgBasic, ArgGen.agNone),
- "types": TYPE_F,
+ "build_fcn": (TBuilder.Rsqrt, TGen.tgBasicPositive, ArgGen.agNone),
+ "types": {
+ "tf": TYPE_F,
+ "tflite": list(TYPE_F + [QuantType.ALL_I8]),
+ },
},
"sign": {
"operands": (1, 0),
@@ -1121,9 +1125,14 @@ def run_unit_test(
converter.target_spec.supported_ops = [flag]
def input_stats():
+ ## Rsqrt can only handle positive numbers
+ elem_signedness = ElemSignedness.ALL_RANGE
+ if op_name == "rsqrt":
+ elem_signedness = ElemSignedness.POSITIVE
+
for i in range(0, args.num_samples):
a = [
- TGen.getRand(shape, tf.float32, rng)
+ TGen.getRand(shape, tf.float32, rng, elem_signedness)
for shape in placeholder_shapes
]
yield a