diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 3e5aee8..9dc3199 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -176,8 +176,11 @@ class TosaTestGen: # Inclusive range: low <= range <= high return (rng[0], rng[1] - 1) - def getRandTensor(self, shape, dtype): - low, high = self.getDTypeRange(dtype) + def getRandTensor(self, shape, dtype, data_range=None): + if data_range is None: + low, high = self.getDTypeRange(dtype) + else: + low, high = data_range if dtype == DType.BOOL: return np.bool_(self.rng.choice(a=[False, True], size=shape)) |