diff options
Diffstat (limited to 'verif/frameworks')
-rw-r--r-- | verif/frameworks/test_builder.py | 39 | ||||
-rwxr-xr-x | verif/frameworks/tosa_verif_framework_generator.py | 24 |
2 files changed, 63 insertions, 0 deletions
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index 97b9085..8677559 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -1021,6 +1021,45 @@ class TBuilder: ) return tf.identity(resize, name=self.result_name) + # New tf resize set (align_corners, half_pixel_centers) = (false, true) by default. + # Test the rest option combinations here. + # Note that (align_corners, half_pixel_centers) = (true, true) is NOT valid. + class ResizeBilinearV1AlignCorners: + def __init__(self, name): + self.result_name = name + + def eval(self, a): + out_shape = [] + out_shape.append(a.shape[1] * 2) + out_shape.append(a.shape[2] * 2) + + resize = tf.compat.v1.image.resize_bilinear( + a, + out_shape, + align_corners=True, + name="resize", + half_pixel_centers=False, + ) + return tf.identity(resize, name=self.result_name) + + class ResizeBilinearV1None: + def __init__(self, name): + self.result_name = name + + def eval(self, a): + out_shape = [] + out_shape.append(a.shape[1] * 2) + out_shape.append(a.shape[2] * 2) + + resize = tf.compat.v1.image.resize_bilinear( + a, + out_shape, + align_corners=False, + name="resize", + half_pixel_centers=False, + ) + return tf.identity(resize, name=self.result_name) + class LeftShift: def __init__(self, shift, name): self.shift = shift diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index 8d8b155..4167227 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -698,6 +698,30 @@ TF_OP_LIST = { ), }, }, + "resize_bilinear_v1_align_corners": { + "operands": (1, 0), + "build_fcn": ( + TBuilder.ResizeBilinearV1AlignCorners, + TGen.tgPooling, + ArgGen.agNone, + ), + "types": { + "tf": TYPE_F, + "tflite": list( + TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16] + ), + }, + }, + "resize_bilinear_v1_none": { + "operands": (1, 0), + "build_fcn": (TBuilder.ResizeBilinearV1None, TGen.tgPooling, ArgGen.agNone), + "types": { + "tf": TYPE_F, + "tflite": list( + TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16] + ), + }, + }, "left_shift": { "operands": (1, 0), "build_fcn": (TBuilder.LeftShift, TGen.tgBasic, ArgGen.agShift), |