diff options
Diffstat (limited to 'verif')
-rw-r--r-- | verif/frameworks/arg_gen.py | 4 | ||||
-rw-r--r-- | verif/frameworks/tensor_gen.py | 8 | ||||
-rw-r--r-- | verif/frameworks/test_builder.py | 12 | ||||
-rw-r--r-- | verif/frameworks/test_gen_utils.py | 3 | ||||
-rwxr-xr-x | verif/frameworks/tosa_verif_framework_generator.py | 12 |
5 files changed, 34 insertions, 5 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py index a25c205..c385274 100644 --- a/verif/frameworks/arg_gen.py +++ b/verif/frameworks/arg_gen.py @@ -45,6 +45,10 @@ class ArgGen: @staticmethod def agAxes(op, shapes, rng): axes = [] + if shapes == (): + axes.append(["_axis_0", [0]]) + return axes + for i in range(-len(shapes), len(shapes), 1): if i >= 0: axes.append(["_axis_{}".format(i), [i]]) diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py index d0c0a0b..4370215 100644 --- a/verif/frameworks/tensor_gen.py +++ b/verif/frameworks/tensor_gen.py @@ -41,8 +41,12 @@ class TGen: RAND_SHIFT_FACTOR = 0.5 if dtype == tf.float32: - return np.float32( - (rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR + return ( + np.float32( + (rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR + ) + if shape != () + else np.float32(rng.random()) ) if dtype == tf.float16: return np.float16( diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index 3554e40..7b20cef 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -136,7 +136,11 @@ class TBuilder: self.result_name = name def eval(self, a, b): - return tf.concat([a, b], self.axis, name=self.result_name) + return ( + tf.concat([a, b], self.axis, name=self.result_name) + if a.shape != () + else tf.stack([a, b], name=self.result_name) + ) class BitwiseAnd: def __init__(self, name): @@ -767,7 +771,11 @@ class TBuilder: self.result_name = name def eval(self, a, b, c, d): - return tf.concat([a, b, c, d], axis=self.axis, name=self.result_name) + return ( + tf.concat([a, b, c, d], axis=self.axis, name=self.result_name) + if a.shape != () + else tf.stack([a, b, c, d], name=self.result_name) + ) class Stack: def __init__(self, axis, name): diff --git a/verif/frameworks/test_gen_utils.py b/verif/frameworks/test_gen_utils.py index 6a59848..f31ac63 100644 --- a/verif/frameworks/test_gen_utils.py +++ b/verif/frameworks/test_gen_utils.py @@ -9,6 +9,9 @@ import tensorflow as tf # Get a string name for a given shape def get_shape_str(shape, dtype): shape_name = None + if len(shape) == 0: + shape_name = "0" + for dim in shape: shape_name = (shape_name + "x" + str(dim)) if shape_name else str(dim) diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index ffe373b..9d666ab 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -174,6 +174,11 @@ TF_OP_LIST = { "operands": (2, 0), "build_fcn": (TBuilder.Concat, TGen.tgBasic, ArgGen.agAxes), "types": TYPE_FI, + "rank": (0, 4), + "custom_shapes": { + "custom_shape_only": False, + "shape_list": [()], + }, }, "bitwise_and": { "operands": (2, 0), @@ -635,6 +640,11 @@ TF_OP_LIST = { "operands": (4, 0), "build_fcn": (TBuilder.Concatv2, TGen.tgBasic, ArgGen.agAxes), "types": TYPE_FI, + "rank": (0, 4), + "custom_shapes": { + "custom_shape_only": False, + "shape_list": [()], + }, }, "stack": { "operands": (4, 0), @@ -1473,7 +1483,7 @@ def generate_op_tests(args, op_name, shape_list, result_name, filter, unit_test_ shape_list = custom_shapes["shape_list"] else: shape_list = shape_list.copy() - shape_list.append(custom_shapes["shape_list"]) + shape_list.extend(custom_shapes["shape_list"]) except KeyError: pass |