From f732609a51630c98bc3f448937988fbcf20dc854 Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Wed, 8 Jun 2022 12:17:14 -0700 Subject: Update TOSA resize to match specification Attribute stride and shift are removed, and has new scale and border. Also add tests in the generator to test tf.resize with all option combinations. Signed-off-by: TatWai Chong Signed-off-by: Jeremy Johnson Change-Id: If0f330d04395762d2d907863235eda1532f5e1ff --- verif/frameworks/test_builder.py | 39 ++++++++++++++++++++++ verif/frameworks/tosa_verif_framework_generator.py | 24 +++++++++++++ 2 files changed, 63 insertions(+) (limited to 'verif/frameworks') diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index 97b9085..8677559 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -1021,6 +1021,45 @@ class TBuilder: ) return tf.identity(resize, name=self.result_name) + # New tf resize set (align_corners, half_pixel_centers) = (false, true) by default. + # Test the rest option combinations here. + # Note that (align_corners, half_pixel_centers) = (true, true) is NOT valid. + class ResizeBilinearV1AlignCorners: + def __init__(self, name): + self.result_name = name + + def eval(self, a): + out_shape = [] + out_shape.append(a.shape[1] * 2) + out_shape.append(a.shape[2] * 2) + + resize = tf.compat.v1.image.resize_bilinear( + a, + out_shape, + align_corners=True, + name="resize", + half_pixel_centers=False, + ) + return tf.identity(resize, name=self.result_name) + + class ResizeBilinearV1None: + def __init__(self, name): + self.result_name = name + + def eval(self, a): + out_shape = [] + out_shape.append(a.shape[1] * 2) + out_shape.append(a.shape[2] * 2) + + resize = tf.compat.v1.image.resize_bilinear( + a, + out_shape, + align_corners=False, + name="resize", + half_pixel_centers=False, + ) + return tf.identity(resize, name=self.result_name) + class LeftShift: def __init__(self, shift, name): self.shift = shift diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index 8d8b155..4167227 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -698,6 +698,30 @@ TF_OP_LIST = { ), }, }, + "resize_bilinear_v1_align_corners": { + "operands": (1, 0), + "build_fcn": ( + TBuilder.ResizeBilinearV1AlignCorners, + TGen.tgPooling, + ArgGen.agNone, + ), + "types": { + "tf": TYPE_F, + "tflite": list( + TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16] + ), + }, + }, + "resize_bilinear_v1_none": { + "operands": (1, 0), + "build_fcn": (TBuilder.ResizeBilinearV1None, TGen.tgPooling, ArgGen.agNone), + "types": { + "tf": TYPE_F, + "tflite": list( + TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16] + ), + }, + }, "left_shift": { "operands": (1, 0), "build_fcn": (TBuilder.LeftShift, TGen.tgBasic, ArgGen.agShift), -- cgit v1.2.1