aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/test_builder.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/test_builder.py')
-rw-r--r--verif/frameworks/test_builder.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 84e4d46..0468518 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -479,6 +479,47 @@ class TBuilder:
)
return bias_add_op
+ class Conv3d:
+ def __init__(self, weight, strides, padding, dilations, name):
+ self.weight = weight
+ self.strides = strides
+ self.padding = padding
+ self.dilations = dilations
+ self.result_name = name
+
+ def eval(self, input):
+ return tf.nn.conv3d(
+ input,
+ self.weight,
+ self.strides,
+ self.padding,
+ data_format="NDHWC",
+ dilations=self.dilations,
+ name=self.result_name,
+ )
+
+ class Conv3dWithBias:
+ def __init__(self, weight, bias, strides, padding, dilations, name):
+ self.weight = weight
+ self.bias = bias
+ self.strides = strides
+ self.padding = padding
+ self.dilations = dilations
+ self.result_name = name
+
+ def eval(self, input):
+ conv3d_op = tf.nn.conv3d(
+ input,
+ self.weight,
+ self.strides,
+ self.padding,
+ data_format="NDHWC",
+ dilations=self.dilations,
+ name="conv3d",
+ )
+ bias_add_op = tf.nn.bias_add(conv3d_op, self.bias, name=self.result_name)
+ return bias_add_op
+
class DepthwiseConv2d:
def __init__(self, weight, strides, padding, dilations, name):
self.weight = weight