From 2226f90d5a6c48a975045bc9e0419113ce764aaf Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Wed, 22 Feb 2023 18:38:01 -0800 Subject: Extend pad tests to include quantized type and explicit pad value. Signed-off-by: TatWai Chong Change-Id: I4a49f45aa73044aff5b0a8b3dba58c1f52c1ae21 --- verif/frameworks/arg_gen.py | 17 +++++++++++++---- verif/frameworks/test_builder.py | 5 +++-- verif/frameworks/tosa_verif_framework_generator.py | 5 ++++- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py index 61a1de0..18e8976 100644 --- a/verif/frameworks/arg_gen.py +++ b/verif/frameworks/arg_gen.py @@ -570,12 +570,21 @@ class ArgGen: rank = len(shapes) for left in range(3): for right in range(3): - paddings = np.zeros((rank, 2), dtype=np.int32) + # Padding nothing in tensorflow lite causes the interpreter fail to set + # the input tensor properly due to date type mismatch. + if (left == 0) and (right == 0): + continue + + # A simple way to generate explicit pad_const including zero. + pad_const = (left - right) * rng.integers(0, 5, dtype=np.int32) + padding = np.zeros((rank, 2), dtype=np.int32) for d in range(rank): - paddings[d, 0] = left - paddings[d, 1] = right + padding[d, 0] = left + padding[d, 1] = right - arg_list.append(["_pad{}{}".format(left, right), [paddings]]) + arg_list.append( + ["_pad{}{}".format(left, right), [padding, pad_const]] + ) return arg_list def agFill(op, shapes, rng): diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index 407902a..d995a34 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -793,8 +793,9 @@ class TBuilder: ) class Pad: - def __init__(self, padding, name): + def __init__(self, padding, pad_const, name): self.padding = padding + self.pad_const = pad_const self.result_name = name def eval(self, a): @@ -802,7 +803,7 @@ class TBuilder: a, self.padding, mode="CONSTANT", - constant_values=0, + constant_values=self.pad_const, name=self.result_name, ) diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index 60106a1..931aef6 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -631,7 +631,10 @@ TF_OP_LIST = { "pad": { "operands": (1, 0), "build_fcn": (TBuilder.Pad, TGen.tgBasic, ArgGen.agPad), - "types": TYPE_F, + "types": { + "tf": TYPE_F, + "tflite": list(TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8]), + }, }, "expand_dims": { "operands": (1, 0), -- cgit v1.2.1