aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/test_builder.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/test_builder.py')
-rw-r--r--verif/frameworks/test_builder.py12
1 files changed, 10 insertions, 2 deletions
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 3554e40..7b20cef 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -136,7 +136,11 @@ class TBuilder:
self.result_name = name
def eval(self, a, b):
- return tf.concat([a, b], self.axis, name=self.result_name)
+ return (
+ tf.concat([a, b], self.axis, name=self.result_name)
+ if a.shape != ()
+ else tf.stack([a, b], name=self.result_name)
+ )
class BitwiseAnd:
def __init__(self, name):
@@ -767,7 +771,11 @@ class TBuilder:
self.result_name = name
def eval(self, a, b, c, d):
- return tf.concat([a, b, c, d], axis=self.axis, name=self.result_name)
+ return (
+ tf.concat([a, b, c, d], axis=self.axis, name=self.result_name)
+ if a.shape != ()
+ else tf.stack([a, b, c, d], name=self.result_name)
+ )
class Stack:
def __init__(self, axis, name):