aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2022-11-10 13:54:28 -0800
committerEric Kunze <eric.kunze@arm.com>2023-01-27 00:50:52 +0000
commitd713a4d841ad09a9466e43aa1bf0be09fe54ea22 (patch)
tree065fd1787bdee724d22e53af143c143fe96d269a
parentfac5c310df50ed36874f53ea2b25b55a57eaec51 (diff)
downloadreference_model-d713a4d841ad09a9466e43aa1bf0be09fe54ea22.tar.gz
Add framework test for math.sign
The result comparison between Tensorflow runtime and the reference model hasn't been checked as the sign operator is not supported by the native TFLite runtime. That said, since the generated tosa ops for tf.sign and tfl.sign is identical, the correctness presumably can be proved by the result from tf.sign. Change-Id: I72eb415df7fb6ca4dc9103f9ddc7104b0ba39234 Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
-rw-r--r--verif/frameworks/test_builder.py7
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py7
2 files changed, 14 insertions, 0 deletions
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 6e7b6a5..e86d5fe 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -317,6 +317,13 @@ class TBuilder:
def eval(self, a):
return tf.math.rsqrt(a, name=self.result_name)
+ class Sign:
+ def __init__(self, name):
+ self.result_name = name
+
+ def eval(self, a):
+ return tf.math.sign(a, name=self.result_name)
+
class Sigmoid:
def __init__(self, name):
self.result_name = name
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 36ddda5..742f991 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -298,6 +298,13 @@ TF_OP_LIST = {
"build_fcn": (TBuilder.Rsqrt, TGen.tgBasic, ArgGen.agNone),
"types": TYPE_F,
},
+ "sign": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.Sign, TGen.tgBasic, ArgGen.agNone),
+ "types": {
+ "tf": TYPE_F,
+ },
+ },
"sigmoid": {
"operands": (1, 0),
"build_fcn": (TBuilder.Sigmoid, TGen.tgBasic, ArgGen.agNone),