diff options
-rw-r--r-- | verif/frameworks/arg_gen.py | 17 | ||||
-rw-r--r-- | verif/frameworks/test_builder.py | 5 | ||||
-rwxr-xr-x | verif/frameworks/tosa_verif_framework_generator.py | 5 |
3 files changed, 20 insertions, 7 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py index 61a1de0..18e8976 100644 --- a/verif/frameworks/arg_gen.py +++ b/verif/frameworks/arg_gen.py @@ -570,12 +570,21 @@ class ArgGen: rank = len(shapes) for left in range(3): for right in range(3): - paddings = np.zeros((rank, 2), dtype=np.int32) + # Padding nothing in tensorflow lite causes the interpreter fail to set + # the input tensor properly due to date type mismatch. + if (left == 0) and (right == 0): + continue + + # A simple way to generate explicit pad_const including zero. + pad_const = (left - right) * rng.integers(0, 5, dtype=np.int32) + padding = np.zeros((rank, 2), dtype=np.int32) for d in range(rank): - paddings[d, 0] = left - paddings[d, 1] = right + padding[d, 0] = left + padding[d, 1] = right - arg_list.append(["_pad{}{}".format(left, right), [paddings]]) + arg_list.append( + ["_pad{}{}".format(left, right), [padding, pad_const]] + ) return arg_list def agFill(op, shapes, rng): diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index 407902a..d995a34 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -793,8 +793,9 @@ class TBuilder: ) class Pad: - def __init__(self, padding, name): + def __init__(self, padding, pad_const, name): self.padding = padding + self.pad_const = pad_const self.result_name = name def eval(self, a): @@ -802,7 +803,7 @@ class TBuilder: a, self.padding, mode="CONSTANT", - constant_values=0, + constant_values=self.pad_const, name=self.result_name, ) diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index 60106a1..931aef6 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -631,7 +631,10 @@ TF_OP_LIST = { "pad": { "operands": (1, 0), "build_fcn": (TBuilder.Pad, TGen.tgBasic, ArgGen.agPad), - "types": TYPE_F, + "types": { + "tf": TYPE_F, + "tflite": list(TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8]), + }, }, "expand_dims": { "operands": (1, 0), |