aboutsummaryrefslogtreecommitdiff
path: root/verif/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r--verif/tosa_test_gen.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index e375a2a..5c25f8e 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -923,8 +923,9 @@ class TosaTestGen:
if dtype == DType.BOOL:
np_dt = np.bool
return np.bool_(self.rng.choice(a=[False, True], size=shape))
+ # TOSA specific INT4 weight range from -7 to 7
elif dtype == DType.INT4:
- return np.int32(self.rng.integers(low=-8, high=8, size=shape))
+ return np.int32(self.rng.integers(low=-7, high=8, size=shape))
elif dtype == DType.INT8:
return np.int32(self.rng.integers(low=-128, high=128, size=shape))
elif dtype == DType.UINT8:
@@ -988,8 +989,9 @@ class TosaTestGen:
return self.rng.random()
elif dtype == DType.BOOL:
return self.rng.choice([False, True])
+ # TOSA specific INT4 weight range from -7 to 7
elif dtype == DType.INT4:
- low, high = (-8, 8)
+ low, high = (-7, 8)
elif dtype == DType.INT8:
low, high = (-128, 128)
elif dtype == DType.INT16:
@@ -1977,6 +1979,7 @@ class TosaTestGen:
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
TYPE_CONV2D = [
+ [DType.INT8, DType.INT4, DType.INT32],
[DType.INT8, DType.INT8, DType.INT32],
[DType.INT16, DType.INT8, DType.INT48],
DType.FLOAT,