aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py30
1 files changed, 29 insertions, 1 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 729d435a..cbd5d6cc 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -22,10 +22,11 @@ class SupportedOperators:
def __init__(self):
# Categorised lists of supported operators
self.npu_pre_ops = set(("QuantizedResizeBilinear", "SplitSliceRead"))
- self.convolution_ops = set(("Conv2DBiasAct", "Conv2D", "QuantizedConv2D", "Conv2DBackpropInputSwitched"))
+ self.convolution_ops = set(("Conv2DBiasAct", "Conv2D", "QuantizedConv2D"))
self.depthwise_convolution_ops = set(
("DepthwiseConv2dBiasAct", "DepthwiseConv2dNative", "QuantizedDepthwiseConv2D")
)
+ self.transpose_convolution_ops = set(("Conv2DBackpropInput",))
self.max_pooling_ops = set(("QuantizedMaxPool", "MaxPool", "MaxPoolAct"))
self.avg_pooling_ops = set(("QuantizedAvgPool", "AvgPool", "AvgPoolAct"))
self.pooling_ops = self.max_pooling_ops | self.avg_pooling_ops
@@ -36,6 +37,8 @@ class SupportedOperators:
self.convolution_ops
# depth-wise convolutions
| self.depthwise_convolution_ops
+ # transpose convolutions
+ | self.transpose_convolution_ops
# pooling
| self.pooling_ops
# resizing/upscaling
@@ -90,6 +93,9 @@ class SupportedOperators:
self.supported_operator_restrictions.update(
{op: self.check_depthwise_convolution_restrictions for op in self.depthwise_convolution_ops}
)
+ self.supported_operator_restrictions.update(
+ {op: self.check_transpose_convolution_restrictions for op in self.transpose_convolution_ops}
+ )
self.supported_operator_restrictions.update({op: self.check_pooling_restrictions for op in self.pooling_ops})
self.supported_operator_restrictions.update({op: self.check_resize_restrictions for op in self.resizing_ops})
self.supported_operator_restrictions.update(
@@ -180,6 +186,28 @@ class SupportedOperators:
return False
return self.check_convolution_restrictions(op)
+ def check_transpose_convolution_restrictions(self, op):
+ # check stride
+ stride_h, stride_w = op.attrs["stride_h"], op.attrs["stride_w"]
+ if stride_h != stride_w != 2:
+ return False
+
+ # check output dimensions
+ ifm_tensor, weight_tensor, _, ofm_tensor = op.get_ifm_weights_biases_ofm()
+ ifm_h, ifm_w = ifm_tensor.shape[1], ifm_tensor.shape[2]
+ ofm_h, ofm_w = ofm_tensor.shape[1], ofm_tensor.shape[2]
+ if op.attrs["padding"] == b"SAME":
+ if (ofm_h != ifm_h * stride_h) or (ofm_w != ifm_w * stride_w):
+ return False
+ elif op.attrs["padding"] == b"VALID":
+ kernel_h, kernel_w = weight_tensor.shape[0], weight_tensor.shape[1]
+ if ((ofm_h != (ifm_h) * stride_h + max(kernel_h - stride_h, 0))
+ or (ofm_w != (ifm_w) * stride_w + max(kernel_w - stride_w, 0))):
+ return False
+
+ return self.check_convolution_restrictions(op)
+
+
def check_pooling_restrictions(self, op):
# check stride
if op.attrs["stride_w"] > 3 or op.attrs["stride_h"] > 3: