aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2022-09-09 09:35:40 +0000
committerTatWai Chong <tatwai.chong@arm.com>2022-09-22 17:09:48 -0700
commitf7008da16ed36fce2866e0a4a2595acc8f0a27d6 (patch)
tree684fbbf6a83726aa3538b36f36f0bd9274034c77
parent3d6de004bfa6469a2f90eb9c8c5856095f96467d (diff)
downloadreference_model-f7008da16ed36fce2866e0a4a2595acc8f0a27d6.tar.gz
Add framework test for TF and TFL mirrorpad
Signed-off-by: TatWai Chong <tatwai.chong@arm.com> Change-Id: Icc9b8f6a65ee54ddbb445c3a999ca49401b660c2
-rw-r--r--verif/frameworks/arg_gen.py37
-rw-r--r--verif/frameworks/test_builder.py15
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py5
3 files changed, 57 insertions, 0 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py
index 5467fa2..d81c3dd 100644
--- a/verif/frameworks/arg_gen.py
+++ b/verif/frameworks/arg_gen.py
@@ -525,6 +525,43 @@ class ArgGen:
axes.append(["_axis{}".format(i), [i]])
return axes
+ def agMirrorPad(op, shapes, rng):
+ arg_list = []
+
+ rank = len(shapes)
+ for mode in ["REFLECT", "SYMMETRIC"]:
+ for left in range(3):
+ for right in range(3):
+ paddings = np.zeros((rank, 2), dtype=np.int32)
+ is_valid = True
+
+ # Fill in the padding parameter if the values are valid on each dimension,
+ # otherwise drop that case.
+ for d in range(rank):
+ paddings[d, 0] = left
+ paddings[d, 1] = right
+
+ # In "REFLECT" mode, paddings must be no greater than tensor dim size - 1.
+ if mode == "REFLECT":
+ if (left > shapes[d] - 1) or (right > shapes[d] - 1):
+ is_valid = False
+ break
+
+ # In "SYMMETRIC" mode, paddings must be no greater than tensor dim size.
+ else:
+ if (left > shapes[d]) or (right > shapes[d]):
+ is_valid = False
+ break
+
+ if is_valid:
+ arg_list.append(
+ [
+ "_pad{}{}_{}".format(left, right, mode[0:3].lower()),
+ [paddings, mode],
+ ]
+ )
+ return arg_list
+
def agPad(op, shapes, rng):
arg_list = []
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 0468518..cd7831d 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -737,6 +737,21 @@ class TBuilder:
)
return tf.stack(sums, 0, name=self.result_name)
+ class MirrorPad:
+ def __init__(self, padding, mode, name):
+ self.padding = padding
+ self.mode = mode
+ self.result_name = name
+
+ def eval(self, a):
+ return tf.pad(
+ a,
+ self.padding,
+ mode=self.mode,
+ constant_values=0,
+ name=self.result_name,
+ )
+
class Pad:
def __init__(self, padding, name):
self.padding = padding
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index fb7f35a..4c710bd 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -592,6 +592,11 @@ TF_OP_LIST = {
"build_fcn": (TBuilder.Unstack, TGen.tgPooling, ArgGen.agAxes),
"types": TYPE_F,
},
+ "mirrorpad": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.MirrorPad, TGen.tgBasic, ArgGen.agMirrorPad),
+ "types": TYPE_FI,
+ },
"pad": {
"operands": (1, 0),
"build_fcn": (TBuilder.Pad, TGen.tgBasic, ArgGen.agPad),