diff options
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r-- | ethosu/vela/supported_operators.py | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index ce3fa609..729d435a 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -29,6 +29,7 @@ class SupportedOperators: self.max_pooling_ops = set(("QuantizedMaxPool", "MaxPool", "MaxPoolAct")) self.avg_pooling_ops = set(("QuantizedAvgPool", "AvgPool", "AvgPoolAct")) self.pooling_ops = self.max_pooling_ops | self.avg_pooling_ops + self.resizing_ops = set(("ResizeBilinear",)) self.fc_vector_products = set(("QuantizedMatMul", "MatMul", "FullyConnectedAct")) self.mac_main_ops = ( # convolutions @@ -37,12 +38,12 @@ class SupportedOperators: | self.depthwise_convolution_ops # pooling | self.pooling_ops + # resizing/upscaling + | self.resizing_ops # FC layers | self.fc_vector_products # RNN/LSTM/GRU | set(("BlockLSTM")) - # deconvolution - | set(("ResizeBilinear",)) ) self.unary_elem_wise_main_ops = set(("LeakyRelu", "Abs")) self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum")) @@ -90,6 +91,7 @@ class SupportedOperators: {op: self.check_depthwise_convolution_restrictions for op in self.depthwise_convolution_ops} ) self.supported_operator_restrictions.update({op: self.check_pooling_restrictions for op in self.pooling_ops}) + self.supported_operator_restrictions.update({op: self.check_resize_restrictions for op in self.resizing_ops}) self.supported_operator_restrictions.update( {op: self.check_vector_product_restrictions for op in self.fc_vector_products} ) @@ -206,6 +208,17 @@ class SupportedOperators: return False return True + def check_resize_restrictions(self, op): + # check unsupported upscaling factor + if op.type == "ResizeBilinear": + upscaled_shape = [op.inputs[0].shape[1] * 2, op.inputs[0].shape[2] * 2] + out_shape = op.outputs[0].shape[1:3] + if not op.attrs["align_corners"] and out_shape != upscaled_shape: + return False + elif op.attrs["align_corners"] and out_shape != [upscaled_shape[0] - 1, upscaled_shape[1] - 1]: + return False + return True + def check_vector_product_restrictions(self, op): # check data type ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm() |