aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/tensor_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/tensor_gen.py')
-rw-r--r--verif/frameworks/tensor_gen.py32
1 files changed, 28 insertions, 4 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index c534a58..f8d50a8 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -1,5 +1,7 @@
# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
+import enum
+
import numpy as np
import tensorflow as tf
@@ -17,6 +19,12 @@ RAND_INT_MIN = -128
RAND_INT_MAX = 128
+class ElemSignedness(enum.Enum):
+ ALL_RANGE = 1
+ POSITIVE = 2
+ NEGATIVE = 3
+
+
class TGen:
"""A collection of functions to build tensor value arguments for an operator"""
@@ -24,7 +32,14 @@ class TGen:
pass
@staticmethod
- def getRand(shape, dtype, rng):
+ def getRand(shape, dtype, rng, elem_signedness=ElemSignedness.ALL_RANGE):
+ if elem_signedness == ElemSignedness.POSITIVE:
+ RAND_SHIFT_FACTOR = 0
+ elif elem_signedness == ElemSignedness.NEGATIVE:
+ RAND_SHIFT_FACTOR = 1
+ else:
+ RAND_SHIFT_FACTOR = 0.5
+
if dtype == tf.float32:
return np.float32(
(rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR
@@ -45,7 +60,11 @@ class TGen:
raise Exception("Unsupported type: {}".format(dtype))
@staticmethod
- def tgBasic(op, shape, dtype, rng):
+ def tgBasicPositive(op, shape, dtype, rng, elem_signedness=ElemSignedness.POSITIVE):
+ return TGen.tgBasic(op, shape, dtype, rng, elem_signedness)
+
+ @staticmethod
+ def tgBasic(op, shape, dtype, rng, elem_signedness=ElemSignedness.ALL_RANGE):
# Build random tensor placeholder node args of a given shape
pl, const = op["operands"]
@@ -54,11 +73,16 @@ class TGen:
for i in range(pl):
tf_placeholders.append(
- ("placeholder_{}".format(i), TGen.getRand(shape, dtype, rng))
+ (
+ "placeholder_{}".format(i),
+ TGen.getRand(shape, dtype, rng, elem_signedness),
+ )
)
for i in range(const):
- tf_consts.append(("const_{}".format(i), TGen.getRand(shape, dtype, rng)))
+ tf_consts.append(
+ ("const_{}".format(i), TGen.getRand(shape, dtype, rng, elem_signedness))
+ )
return tf_placeholders, tf_consts