aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--verif/frameworks/arg_gen.py17
-rw-r--r--verif/frameworks/test_builder.py5
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py5
3 files changed, 20 insertions, 7 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py
index 61a1de0..18e8976 100644
--- a/verif/frameworks/arg_gen.py
+++ b/verif/frameworks/arg_gen.py
@@ -570,12 +570,21 @@ class ArgGen:
rank = len(shapes)
for left in range(3):
for right in range(3):
- paddings = np.zeros((rank, 2), dtype=np.int32)
+ # Padding nothing in tensorflow lite causes the interpreter fail to set
+ # the input tensor properly due to date type mismatch.
+ if (left == 0) and (right == 0):
+ continue
+
+ # A simple way to generate explicit pad_const including zero.
+ pad_const = (left - right) * rng.integers(0, 5, dtype=np.int32)
+ padding = np.zeros((rank, 2), dtype=np.int32)
for d in range(rank):
- paddings[d, 0] = left
- paddings[d, 1] = right
+ padding[d, 0] = left
+ padding[d, 1] = right
- arg_list.append(["_pad{}{}".format(left, right), [paddings]])
+ arg_list.append(
+ ["_pad{}{}".format(left, right), [padding, pad_const]]
+ )
return arg_list
def agFill(op, shapes, rng):
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 407902a..d995a34 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -793,8 +793,9 @@ class TBuilder:
)
class Pad:
- def __init__(self, padding, name):
+ def __init__(self, padding, pad_const, name):
self.padding = padding
+ self.pad_const = pad_const
self.result_name = name
def eval(self, a):
@@ -802,7 +803,7 @@ class TBuilder:
a,
self.padding,
mode="CONSTANT",
- constant_values=0,
+ constant_values=self.pad_const,
name=self.result_name,
)
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 60106a1..931aef6 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -631,7 +631,10 @@ TF_OP_LIST = {
"pad": {
"operands": (1, 0),
"build_fcn": (TBuilder.Pad, TGen.tgBasic, ArgGen.agPad),
- "types": TYPE_F,
+ "types": {
+ "tf": TYPE_F,
+ "tflite": list(TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8]),
+ },
},
"expand_dims": {
"operands": (1, 0),