diff options
Diffstat (limited to 'verif/frameworks/test_builder.py')
-rw-r--r-- | verif/frameworks/test_builder.py | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index 3554e40..7b20cef 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -136,7 +136,11 @@ class TBuilder: self.result_name = name def eval(self, a, b): - return tf.concat([a, b], self.axis, name=self.result_name) + return ( + tf.concat([a, b], self.axis, name=self.result_name) + if a.shape != () + else tf.stack([a, b], name=self.result_name) + ) class BitwiseAnd: def __init__(self, name): @@ -767,7 +771,11 @@ class TBuilder: self.result_name = name def eval(self, a, b, c, d): - return tf.concat([a, b, c, d], axis=self.axis, name=self.result_name) + return ( + tf.concat([a, b, c, d], axis=self.axis, name=self.result_name) + if a.shape != () + else tf.stack([a, b, c, d], name=self.result_name) + ) class Stack: def __init__(self, axis, name): |