diff options
Diffstat (limited to 'verif/frameworks/tensor_gen.py')
-rw-r--r-- | verif/frameworks/tensor_gen.py | 32 |
1 files changed, 28 insertions, 4 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py index c534a58..f8d50a8 100644 --- a/verif/frameworks/tensor_gen.py +++ b/verif/frameworks/tensor_gen.py @@ -1,5 +1,7 @@ # Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 +import enum + import numpy as np import tensorflow as tf @@ -17,6 +19,12 @@ RAND_INT_MIN = -128 RAND_INT_MAX = 128 +class ElemSignedness(enum.Enum): + ALL_RANGE = 1 + POSITIVE = 2 + NEGATIVE = 3 + + class TGen: """A collection of functions to build tensor value arguments for an operator""" @@ -24,7 +32,14 @@ class TGen: pass @staticmethod - def getRand(shape, dtype, rng): + def getRand(shape, dtype, rng, elem_signedness=ElemSignedness.ALL_RANGE): + if elem_signedness == ElemSignedness.POSITIVE: + RAND_SHIFT_FACTOR = 0 + elif elem_signedness == ElemSignedness.NEGATIVE: + RAND_SHIFT_FACTOR = 1 + else: + RAND_SHIFT_FACTOR = 0.5 + if dtype == tf.float32: return np.float32( (rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR @@ -45,7 +60,11 @@ class TGen: raise Exception("Unsupported type: {}".format(dtype)) @staticmethod - def tgBasic(op, shape, dtype, rng): + def tgBasicPositive(op, shape, dtype, rng, elem_signedness=ElemSignedness.POSITIVE): + return TGen.tgBasic(op, shape, dtype, rng, elem_signedness) + + @staticmethod + def tgBasic(op, shape, dtype, rng, elem_signedness=ElemSignedness.ALL_RANGE): # Build random tensor placeholder node args of a given shape pl, const = op["operands"] @@ -54,11 +73,16 @@ class TGen: for i in range(pl): tf_placeholders.append( - ("placeholder_{}".format(i), TGen.getRand(shape, dtype, rng)) + ( + "placeholder_{}".format(i), + TGen.getRand(shape, dtype, rng, elem_signedness), + ) ) for i in range(const): - tf_consts.append(("const_{}".format(i), TGen.getRand(shape, dtype, rng))) + tf_consts.append( + ("const_{}".format(i), TGen.getRand(shape, dtype, rng, elem_signedness)) + ) return tf_placeholders, tf_consts |