From fe36fa9f38824d03250393488fe468b7dacc72ed Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Thu, 1 Jun 2023 21:45:12 +0000 Subject: Add tf broadcast-to testing This patch adds BoradcastTo Op to the tf tests. Did not add tflite testing because the tf.lite.TFLiteConverter converts tf.broadcast-to to tfl.mul by 1. Signed-off-by: Tai Ly Change-Id: Icd372e619c318121c19eca87d5716bcd5fbbbb23 --- verif/frameworks/tensor_gen.py | 38 ++++++++++++++++++++++ verif/frameworks/test_builder.py | 8 +++++ verif/frameworks/tosa_verif_framework_generator.py | 7 ++++ 3 files changed, 53 insertions(+) (limited to 'verif/frameworks') diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py index 60e17ce..170e5d8 100644 --- a/verif/frameworks/tensor_gen.py +++ b/verif/frameworks/tensor_gen.py @@ -319,3 +319,41 @@ class TGen: return [], [] return TGen.tgBasic(op, shape, dtype, rng) + + @staticmethod + def tgBroadcastTo(op, shape, dtype, rng): + + pl, const = op["operands"] + + assert pl == 1 + assert const == 1 + + tf_placeholders = [] + tf_consts = [] + + shape_list = list(shape) + t_shape_list = [] + s_shape_list = [] + for i in range(len(shape)): + dim = shape_list[i] + if rng.integers(0, 1) == 0: + # append dim in s_shape_list, and 1 in t_shape_list unless it is still empty + s_shape_list.append(dim) + if len(t_shape_list) > 0: + t_shape_list.append(1) + else: + # append 1 in s_shape_list, and dim in t_shape_list + s_shape_list.append(1) + t_shape_list.append(dim) + + # if t_shape_list is empty, then insert 1 + if len(t_shape_list) == 0: + t_shape_list.append(1) + + tf_placeholders.append( + ("placeholder_0", TGen.getRand(tuple(t_shape_list), dtype, rng)) + ) + + tf_consts.append(("shape", tuple(s_shape_list))) + + return tf_placeholders, tf_consts diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index f872888..1b681d2 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -1239,3 +1239,11 @@ class TBuilder: def eval(self, a): return tf.math.imag(a, name=self.result_name) + + class BroadcastTo: + def __init__(self, shape, name): + self.shape = shape + self.result_name = name + + def eval(self, a): + return tf.broadcast_to(a, shape=self.shape, name=self.result_name) diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index 58df588..ccbe742 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -860,6 +860,13 @@ TF_OP_LIST = { "tflite": [tf.complex64], }, }, + "broadcastto": { + "operands": (1, 1), + "build_fcn": (TBuilder.BroadcastTo, TGen.tgBroadcastTo, ArgGen.agNone), + "types": { + "tf": TYPE_FIB, + }, + }, } # Shapes to be tested; default can be overwritten -- cgit v1.2.1