aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--verif/frameworks/arg_gen.py37
-rw-r--r--verif/frameworks/test_builder.py15
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py5
3 files changed, 57 insertions, 0 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py
index 5467fa2..d81c3dd 100644
--- a/verif/frameworks/arg_gen.py
+++ b/verif/frameworks/arg_gen.py
@@ -525,6 +525,43 @@ class ArgGen:
axes.append(["_axis{}".format(i), [i]])
return axes
+ def agMirrorPad(op, shapes, rng):
+ arg_list = []
+
+ rank = len(shapes)
+ for mode in ["REFLECT", "SYMMETRIC"]:
+ for left in range(3):
+ for right in range(3):
+ paddings = np.zeros((rank, 2), dtype=np.int32)
+ is_valid = True
+
+ # Fill in the padding parameter if the values are valid on each dimension,
+ # otherwise drop that case.
+ for d in range(rank):
+ paddings[d, 0] = left
+ paddings[d, 1] = right
+
+ # In "REFLECT" mode, paddings must be no greater than tensor dim size - 1.
+ if mode == "REFLECT":
+ if (left > shapes[d] - 1) or (right > shapes[d] - 1):
+ is_valid = False
+ break
+
+ # In "SYMMETRIC" mode, paddings must be no greater than tensor dim size.
+ else:
+ if (left > shapes[d]) or (right > shapes[d]):
+ is_valid = False
+ break
+
+ if is_valid:
+ arg_list.append(
+ [
+ "_pad{}{}_{}".format(left, right, mode[0:3].lower()),
+ [paddings, mode],
+ ]
+ )
+ return arg_list
+
def agPad(op, shapes, rng):
arg_list = []
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 0468518..cd7831d 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -737,6 +737,21 @@ class TBuilder:
)
return tf.stack(sums, 0, name=self.result_name)
+ class MirrorPad:
+ def __init__(self, padding, mode, name):
+ self.padding = padding
+ self.mode = mode
+ self.result_name = name
+
+ def eval(self, a):
+ return tf.pad(
+ a,
+ self.padding,
+ mode=self.mode,
+ constant_values=0,
+ name=self.result_name,
+ )
+
class Pad:
def __init__(self, padding, name):
self.padding = padding
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index fb7f35a..4c710bd 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -592,6 +592,11 @@ TF_OP_LIST = {
"build_fcn": (TBuilder.Unstack, TGen.tgPooling, ArgGen.agAxes),
"types": TYPE_F,
},
+ "mirrorpad": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.MirrorPad, TGen.tgBasic, ArgGen.agMirrorPad),
+ "types": TYPE_FI,
+ },
"pad": {
"operands": (1, 0),
"build_fcn": (TBuilder.Pad, TGen.tgBasic, ArgGen.agPad),