diff options
Diffstat (limited to 'verif/frameworks/tensor_gen.py')
-rw-r--r-- | verif/frameworks/tensor_gen.py | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py index 60e17ce..170e5d8 100644 --- a/verif/frameworks/tensor_gen.py +++ b/verif/frameworks/tensor_gen.py @@ -319,3 +319,41 @@ class TGen: return [], [] return TGen.tgBasic(op, shape, dtype, rng) + + @staticmethod + def tgBroadcastTo(op, shape, dtype, rng): + + pl, const = op["operands"] + + assert pl == 1 + assert const == 1 + + tf_placeholders = [] + tf_consts = [] + + shape_list = list(shape) + t_shape_list = [] + s_shape_list = [] + for i in range(len(shape)): + dim = shape_list[i] + if rng.integers(0, 1) == 0: + # append dim in s_shape_list, and 1 in t_shape_list unless it is still empty + s_shape_list.append(dim) + if len(t_shape_list) > 0: + t_shape_list.append(1) + else: + # append 1 in s_shape_list, and dim in t_shape_list + s_shape_list.append(1) + t_shape_list.append(dim) + + # if t_shape_list is empty, then insert 1 + if len(t_shape_list) == 0: + t_shape_list.append(1) + + tf_placeholders.append( + ("placeholder_0", TGen.getRand(tuple(t_shape_list), dtype, rng)) + ) + + tf_consts.append(("shape", tuple(s_shape_list))) + + return tf_placeholders, tf_consts |