aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2023-02-27 13:22:52 -0800
committerTatWai Chong <tatwai.chong@arm.com>2023-03-03 15:15:38 -0800
commit0cef07eca8c91738093df86903f3584daaf12b9a (patch)
tree6b849112ffc3657dc69ef02a4248fd0b729566e6
parent57bc0796cd85115684219cf373db04c172848306 (diff)
downloadreference_model-0cef07eca8c91738093df86903f3584daaf12b9a.tar.gz
Refactor resize test builder
Also add input size = 1 in the shape list, and extend scaling to 1x, 2x and 3x, so that the cases of broadcasting, power-of-two scaling, no-scaling (e.g. 1x1 -> 1x1), scaling accuracy (3x) can be tested. Since the scalar tensor is tiny, should not noticeably impact the execution time of the framework test. Change-Id: Iec53da3cbb60e087077d6e2d8eb205e76e6c1313 Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
-rw-r--r--verif/frameworks/arg_gen.py24
-rw-r--r--verif/frameworks/test_builder.py81
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py39
3 files changed, 42 insertions, 102 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py
index 18e8976..8604f0b 100644
--- a/verif/frameworks/arg_gen.py
+++ b/verif/frameworks/arg_gen.py
@@ -587,6 +587,30 @@ class ArgGen:
)
return arg_list
+ def agResize(op, shapes, rng):
+ args = []
+ for mode in ["nearest", "bilinear"]:
+ for align_corners in [True, False]:
+ for half_pixel in [True, False]:
+ # If half_pixel_centers is True, align_corners must be False.
+ if (
+ (mode == "bilinear")
+ and (align_corners is True)
+ and (half_pixel is True)
+ ):
+ continue
+
+ for i in range(1, 4):
+ args.append(
+ [
+ "_{}_align{}_half{}_scale{}".format(
+ mode, int(align_corners), int(half_pixel), i
+ ),
+ [mode, align_corners, half_pixel, i],
+ ]
+ )
+ return args
+
def agFill(op, shapes, rng):
values = []
for i in range(4):
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index d995a34..c7ba9a9 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1086,81 +1086,30 @@ class TBuilder:
name=self.result_name,
)
- class ResizeNearest:
- def __init__(self, name):
- self.result_name = name
-
- def eval(self, a):
- out_shape = []
- out_shape.append(a.shape[1] * 2)
- out_shape.append(a.shape[2] * 2)
-
- # tf.image.resize() will overwrite the node name with result_name +
- # '/BILINEAR' need to add extra identity to force output tensor name to
- # result_name return tf.image.resize(a, out_shape,
- # method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, name=result_name)
- resize = tf.image.resize(
- a,
- out_shape,
- method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
- name="resize",
- )
- return tf.identity(resize, name=self.result_name)
-
- class ResizeBilinear:
- def __init__(self, name):
- self.result_name = name
-
- def eval(self, a):
- out_shape = []
- out_shape.append(a.shape[1] * 2)
- out_shape.append(a.shape[2] * 2)
-
- # tf.image.resize() will overwrite the node name with result_name +
- # '/BILINEAR' need to add extra identity to force output tensor name to
- # result_name return tf.image.resize(a, out_shape,
- # method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, name=result_name)
- resize = tf.image.resize(
- a, out_shape, method=tf.image.ResizeMethod.BILINEAR, name="resize"
- )
- return tf.identity(resize, name=self.result_name)
-
- # New tf resize set (align_corners, half_pixel_centers) = (false, true) by default.
- # Test the rest option combinations here.
- # Note that (align_corners, half_pixel_centers) = (true, true) is NOT valid.
- class ResizeBilinearV1AlignCorners:
- def __init__(self, name):
+ class Resize:
+ def __init__(self, mode, align, half, scale, name):
self.result_name = name
+ self.mode = mode
+ self.align = align
+ self.half = half
+ self.scale = scale
def eval(self, a):
out_shape = []
- out_shape.append(a.shape[1] * 2)
- out_shape.append(a.shape[2] * 2)
+ out_shape.append(a.shape[1] * self.scale)
+ out_shape.append(a.shape[2] * self.scale)
- resize = tf.compat.v1.image.resize_bilinear(
- a,
- out_shape,
- align_corners=True,
- name="resize",
- half_pixel_centers=False,
+ tf_resize_dict = (
+ {"tf_resize_func": tf.compat.v1.image.resize_nearest_neighbor}
+ if (self.mode == "nearest")
+ else {"tf_resize_func": tf.compat.v1.image.resize_bilinear}
)
- return tf.identity(resize, name=self.result_name)
-
- class ResizeBilinearV1None:
- def __init__(self, name):
- self.result_name = name
-
- def eval(self, a):
- out_shape = []
- out_shape.append(a.shape[1] * 2)
- out_shape.append(a.shape[2] * 2)
-
- resize = tf.compat.v1.image.resize_bilinear(
+ resize = tf_resize_dict["tf_resize_func"](
a,
out_shape,
- align_corners=False,
+ align_corners=self.align,
name="resize",
- half_pixel_centers=False,
+ half_pixel_centers=self.half,
)
return tf.identity(resize, name=self.result_name)
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index f8518f1..68c7d5a 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -762,43 +762,9 @@ TF_OP_LIST = {
),
"types": {"tf": TYPE_F},
},
- "resize_nearest": {
+ "resize": {
"operands": (1, 0),
- "build_fcn": (TBuilder.ResizeNearest, TGen.tgPooling, ArgGen.agNone),
- "types": {
- "tf": TYPE_F,
- "tflite": list(
- TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
- ),
- },
- },
- "resize_bilinear": {
- "operands": (1, 0),
- "build_fcn": (TBuilder.ResizeBilinear, TGen.tgPooling, ArgGen.agNone),
- "types": {
- "tf": TYPE_F,
- "tflite": list(
- TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
- ),
- },
- },
- "resize_bilinear_v1_align_corners": {
- "operands": (1, 0),
- "build_fcn": (
- TBuilder.ResizeBilinearV1AlignCorners,
- TGen.tgPooling,
- ArgGen.agNone,
- ),
- "types": {
- "tf": TYPE_F,
- "tflite": list(
- TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
- ),
- },
- },
- "resize_bilinear_v1_none": {
- "operands": (1, 0),
- "build_fcn": (TBuilder.ResizeBilinearV1None, TGen.tgPooling, ArgGen.agNone),
+ "build_fcn": (TBuilder.Resize, TGen.tgPooling, ArgGen.agResize),
"types": {
"tf": TYPE_F,
"tflite": list(
@@ -878,6 +844,7 @@ shape_list = [
(1, 32, 32, 8),
(1, 7, 7, 9),
(1, 7, 7, 479),
+ (3, 1, 1, 7),
(2, 2, 7, 7, 2),
(1, 4, 8, 21, 17),
(3, 32, 16, 16, 5),