diff options
Diffstat (limited to 'verif/frameworks/tensor_gen.py')
-rw-r--r-- | verif/frameworks/tensor_gen.py | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py index d0c0a0b..4370215 100644 --- a/verif/frameworks/tensor_gen.py +++ b/verif/frameworks/tensor_gen.py @@ -41,8 +41,12 @@ class TGen: RAND_SHIFT_FACTOR = 0.5 if dtype == tf.float32: - return np.float32( - (rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR + return ( + np.float32( + (rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR + ) + if shape != () + else np.float32(rng.random()) ) if dtype == tf.float16: return np.float16( |