aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py17
1 files changed, 15 insertions, 2 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index ce3fa609..729d435a 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -29,6 +29,7 @@ class SupportedOperators:
self.max_pooling_ops = set(("QuantizedMaxPool", "MaxPool", "MaxPoolAct"))
self.avg_pooling_ops = set(("QuantizedAvgPool", "AvgPool", "AvgPoolAct"))
self.pooling_ops = self.max_pooling_ops | self.avg_pooling_ops
+ self.resizing_ops = set(("ResizeBilinear",))
self.fc_vector_products = set(("QuantizedMatMul", "MatMul", "FullyConnectedAct"))
self.mac_main_ops = (
# convolutions
@@ -37,12 +38,12 @@ class SupportedOperators:
| self.depthwise_convolution_ops
# pooling
| self.pooling_ops
+ # resizing/upscaling
+ | self.resizing_ops
# FC layers
| self.fc_vector_products
# RNN/LSTM/GRU
| set(("BlockLSTM"))
- # deconvolution
- | set(("ResizeBilinear",))
)
self.unary_elem_wise_main_ops = set(("LeakyRelu", "Abs"))
self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum"))
@@ -90,6 +91,7 @@ class SupportedOperators:
{op: self.check_depthwise_convolution_restrictions for op in self.depthwise_convolution_ops}
)
self.supported_operator_restrictions.update({op: self.check_pooling_restrictions for op in self.pooling_ops})
+ self.supported_operator_restrictions.update({op: self.check_resize_restrictions for op in self.resizing_ops})
self.supported_operator_restrictions.update(
{op: self.check_vector_product_restrictions for op in self.fc_vector_products}
)
@@ -206,6 +208,17 @@ class SupportedOperators:
return False
return True
+ def check_resize_restrictions(self, op):
+ # check unsupported upscaling factor
+ if op.type == "ResizeBilinear":
+ upscaled_shape = [op.inputs[0].shape[1] * 2, op.inputs[0].shape[2] * 2]
+ out_shape = op.outputs[0].shape[1:3]
+ if not op.attrs["align_corners"] and out_shape != upscaled_shape:
+ return False
+ elif op.attrs["align_corners"] and out_shape != [upscaled_shape[0] - 1, upscaled_shape[1] - 1]:
+ return False
+ return True
+
def check_vector_product_restrictions(self, op):
# check data type
ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()