aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 3e5aee8..9dc3199 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -176,8 +176,11 @@ class TosaTestGen:
# Inclusive range: low <= range <= high
return (rng[0], rng[1] - 1)
- def getRandTensor(self, shape, dtype):
- low, high = self.getDTypeRange(dtype)
+ def getRandTensor(self, shape, dtype, data_range=None):
+ if data_range is None:
+ low, high = self.getDTypeRange(dtype)
+ else:
+ low, high = data_range
if dtype == DType.BOOL:
return np.bool_(self.rng.choice(a=[False, True], size=shape))