diff options
author | TatWai Chong <tatwai.chong@arm.com> | 2022-09-09 09:35:40 +0000 |
---|---|---|
committer | TatWai Chong <tatwai.chong@arm.com> | 2022-09-22 17:09:48 -0700 |
commit | f7008da16ed36fce2866e0a4a2595acc8f0a27d6 (patch) | |
tree | 684fbbf6a83726aa3538b36f36f0bd9274034c77 /verif/frameworks/arg_gen.py | |
parent | 3d6de004bfa6469a2f90eb9c8c5856095f96467d (diff) | |
download | reference_model-f7008da16ed36fce2866e0a4a2595acc8f0a27d6.tar.gz |
Add framework test for TF and TFL mirrorpad
Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
Change-Id: Icc9b8f6a65ee54ddbb445c3a999ca49401b660c2
Diffstat (limited to 'verif/frameworks/arg_gen.py')
-rw-r--r-- | verif/frameworks/arg_gen.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py index 5467fa2..d81c3dd 100644 --- a/verif/frameworks/arg_gen.py +++ b/verif/frameworks/arg_gen.py @@ -525,6 +525,43 @@ class ArgGen: axes.append(["_axis{}".format(i), [i]]) return axes + def agMirrorPad(op, shapes, rng): + arg_list = [] + + rank = len(shapes) + for mode in ["REFLECT", "SYMMETRIC"]: + for left in range(3): + for right in range(3): + paddings = np.zeros((rank, 2), dtype=np.int32) + is_valid = True + + # Fill in the padding parameter if the values are valid on each dimension, + # otherwise drop that case. + for d in range(rank): + paddings[d, 0] = left + paddings[d, 1] = right + + # In "REFLECT" mode, paddings must be no greater than tensor dim size - 1. + if mode == "REFLECT": + if (left > shapes[d] - 1) or (right > shapes[d] - 1): + is_valid = False + break + + # In "SYMMETRIC" mode, paddings must be no greater than tensor dim size. + else: + if (left > shapes[d]) or (right > shapes[d]): + is_valid = False + break + + if is_valid: + arg_list.append( + [ + "_pad{}{}_{}".format(left, right, mode[0:3].lower()), + [paddings, mode], + ] + ) + return arg_list + def agPad(op, shapes, rng): arg_list = [] |