aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/arg_gen.py')
-rw-r--r--verif/frameworks/arg_gen.py89
1 files changed, 89 insertions, 0 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py
index fa4a652..5467fa2 100644
--- a/verif/frameworks/arg_gen.py
+++ b/verif/frameworks/arg_gen.py
@@ -120,6 +120,95 @@ class ArgGen:
)
return arg_list
+ # conv3d argument generators build the TF constants
+ def agConv3d(op, shapes, rng):
+ arg_list = []
+
+ # input shape = [OC, KD, KH, KW, IC]
+ # Must be rank 5
+ if len(shapes) != 5:
+ return arg_list
+
+ if len(op["filter"]) < 3:
+ return arg_list
+
+ filter_d, filter_h, filter_w = op["filter"]
+
+ # strides, padding, dilations,
+ for stride_d in [1, 2]:
+ for stride_h in [1, 2]:
+ for stride_w in [1, 2]:
+ for padding in ["SAME", "VALID"]:
+ for dilation_d in [1, 2]:
+ for dilation_h in [1, 2]:
+ for dilation_w in [1, 2]:
+
+ # Disqualify argument combinations that would cause
+ # an illegal convolution
+ # fmt: off
+ if (padding == "VALID") and (
+ (shapes[1] - (filter_d - 1) * 2 - dilation_d) <= 0
+ or (shapes[2] - (filter_h - 1) * 2 - dilation_h) <= 0
+ or (shapes[3] - (filter_w - 1) * 2 - dilation_w) <= 0
+ ):
+ continue
+
+ if (
+ (shapes[1] - 1 - (filter_d - 1) * dilation_d) % stride_d
+ != 0
+ ) or (
+ (shapes[2] - 1 - (filter_h - 1) * dilation_h) % stride_h
+ != 0
+ ) or (
+ (shapes[3] - 1 - (filter_w - 1) * dilation_w) % stride_w
+ != 0
+ ):
+ # Not an exact integer output
+ continue
+ # fmt: on
+
+ # TODO investigate the error of `CPU implementation of Conv3D
+ # currently only supports dilated rates of 1.` from Tensorflow.
+ # Only test dilations = [1, 1, 1, 1, 1] for now.
+ if (
+ (dilation_d != 1)
+ or (dilation_h != 1)
+ or (dilation_w != 1)
+ ):
+ continue
+
+ # Tensorflow expects strides is a list of ints that has length >= 5.
+ # Strides and dilations in the batch and depth dimensions must be 1.
+ arg_list.append(
+ [
+ "_st{}{}{}{}{}_pad{}_dilat{}{}{}{}{}".format(
+ 1,
+ stride_d,
+ stride_h,
+ stride_w,
+ 1,
+ padding,
+ 1,
+ dilation_d,
+ dilation_h,
+ dilation_w,
+ 1,
+ ),
+ [
+ [1, stride_d, stride_h, stride_w, 1],
+ padding,
+ [
+ 1,
+ dilation_d,
+ dilation_h,
+ dilation_w,
+ 1,
+ ],
+ ],
+ ]
+ )
+ return arg_list
+
# conv2d argument generators build the TF constants
def agDepthwiseConv2d(op, shapes, rng):
arg_list = []