diff options
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r-- | verif/tosa_test_gen.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index e375a2a..5c25f8e 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -923,8 +923,9 @@ class TosaTestGen: if dtype == DType.BOOL: np_dt = np.bool return np.bool_(self.rng.choice(a=[False, True], size=shape)) + # TOSA specific INT4 weight range from -7 to 7 elif dtype == DType.INT4: - return np.int32(self.rng.integers(low=-8, high=8, size=shape)) + return np.int32(self.rng.integers(low=-7, high=8, size=shape)) elif dtype == DType.INT8: return np.int32(self.rng.integers(low=-128, high=128, size=shape)) elif dtype == DType.UINT8: @@ -988,8 +989,9 @@ class TosaTestGen: return self.rng.random() elif dtype == DType.BOOL: return self.rng.choice([False, True]) + # TOSA specific INT4 weight range from -7 to 7 elif dtype == DType.INT4: - low, high = (-8, 8) + low, high = (-7, 8) elif dtype == DType.INT8: low, high = (-128, 128) elif dtype == DType.INT16: @@ -1977,6 +1979,7 @@ class TosaTestGen: TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT] TYPE_CONV2D = [ + [DType.INT8, DType.INT4, DType.INT32], [DType.INT8, DType.INT8, DType.INT32], [DType.INT16, DType.INT8, DType.INT48], DType.FLOAT, |