aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/arg_gen.py')
-rw-r--r--verif/frameworks/arg_gen.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py
index 8feb9b2..fa4a652 100644
--- a/verif/frameworks/arg_gen.py
+++ b/verif/frameworks/arg_gen.py
@@ -92,6 +92,16 @@ class ArgGen:
):
continue
+ if (
+ (shapes[1] - 1 - (filter_h - 1) * dilation_h) % stride_h
+ != 0
+ ) or (
+ (shapes[2] - 1 - (filter_w - 1) * dilation_w) % stride_w
+ != 0
+ ):
+ # Not an exact integer output
+ continue
+
arg_list.append(
[
"_st{}{}_pad{}_dilat{}{}".format(
@@ -147,6 +157,14 @@ class ArgGen:
if shapes[1] % dilation_h != 0 or shapes[2] % dilation_w != 0:
continue
+ if (
+ (shapes[1] - 1 - (filter_h - 1) * dilation_h) % stride != 0
+ ) or (
+ (shapes[2] - 1 - (filter_w - 1) * dilation_w) % stride != 0
+ ):
+ # Not an exact integer output
+ continue
+
arg_list.append(
[
"_st{}{}_pad{}_dilat{}{}".format(
@@ -217,6 +235,12 @@ class ArgGen:
):
continue
+ if ((shapes[1] - kernel_h) % stride_h != 0) or (
+ (shapes[2] - kernel_w) % stride_w != 0
+ ):
+ # Not an exact integer output
+ continue
+
arg_list.append(
[
"_st{}{}_pad{}_kern{}{}".format(