aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-06-01 21:45:12 +0000
committerTai Ly <tai.ly@arm.com>2023-06-21 16:48:16 +0000
commitfe36fa9f38824d03250393488fe468b7dacc72ed (patch)
tree62a9aa96b6207113525c5eba401301e7a5d52b3e
parentc8da1d2687cfcff90629c2cf770bb5f406002701 (diff)
downloadreference_model-fe36fa9f38824d03250393488fe468b7dacc72ed.tar.gz
Add tf broadcast-to testing
This patch adds BoradcastTo Op to the tf tests. Did not add tflite testing because the tf.lite.TFLiteConverter converts tf.broadcast-to to tfl.mul by 1. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Icd372e619c318121c19eca87d5716bcd5fbbbb23
-rw-r--r--verif/frameworks/tensor_gen.py38
-rw-r--r--verif/frameworks/test_builder.py8
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py7
3 files changed, 53 insertions, 0 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index 60e17ce..170e5d8 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -319,3 +319,41 @@ class TGen:
return [], []
return TGen.tgBasic(op, shape, dtype, rng)
+
+ @staticmethod
+ def tgBroadcastTo(op, shape, dtype, rng):
+
+ pl, const = op["operands"]
+
+ assert pl == 1
+ assert const == 1
+
+ tf_placeholders = []
+ tf_consts = []
+
+ shape_list = list(shape)
+ t_shape_list = []
+ s_shape_list = []
+ for i in range(len(shape)):
+ dim = shape_list[i]
+ if rng.integers(0, 1) == 0:
+ # append dim in s_shape_list, and 1 in t_shape_list unless it is still empty
+ s_shape_list.append(dim)
+ if len(t_shape_list) > 0:
+ t_shape_list.append(1)
+ else:
+ # append 1 in s_shape_list, and dim in t_shape_list
+ s_shape_list.append(1)
+ t_shape_list.append(dim)
+
+ # if t_shape_list is empty, then insert 1
+ if len(t_shape_list) == 0:
+ t_shape_list.append(1)
+
+ tf_placeholders.append(
+ ("placeholder_0", TGen.getRand(tuple(t_shape_list), dtype, rng))
+ )
+
+ tf_consts.append(("shape", tuple(s_shape_list)))
+
+ return tf_placeholders, tf_consts
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index f872888..1b681d2 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1239,3 +1239,11 @@ class TBuilder:
def eval(self, a):
return tf.math.imag(a, name=self.result_name)
+
+ class BroadcastTo:
+ def __init__(self, shape, name):
+ self.shape = shape
+ self.result_name = name
+
+ def eval(self, a):
+ return tf.broadcast_to(a, shape=self.shape, name=self.result_name)
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 58df588..ccbe742 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -860,6 +860,13 @@ TF_OP_LIST = {
"tflite": [tf.complex64],
},
},
+ "broadcastto": {
+ "operands": (1, 1),
+ "build_fcn": (TBuilder.BroadcastTo, TGen.tgBroadcastTo, ArgGen.agNone),
+ "types": {
+ "tf": TYPE_FIB,
+ },
+ },
}
# Shapes to be tested; default can be overwritten