aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/tensor_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/tensor_gen.py')
-rw-r--r--verif/frameworks/tensor_gen.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index d0c0a0b..4370215 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -41,8 +41,12 @@ class TGen:
RAND_SHIFT_FACTOR = 0.5
if dtype == tf.float32:
- return np.float32(
- (rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR
+ return (
+ np.float32(
+ (rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR
+ )
+ if shape != ()
+ else np.float32(rng.random())
)
if dtype == tf.float16:
return np.float16(