diff options
Diffstat (limited to 'ethosu/vela/graph_optimiser_util.py')
-rw-r--r-- | ethosu/vela/graph_optimiser_util.py | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py index 220ba1a9..c0099ffa 100644 --- a/ethosu/vela/graph_optimiser_util.py +++ b/ethosu/vela/graph_optimiser_util.py @@ -26,6 +26,7 @@ from .debug_database import DebugDatabase from .errors import UnsupportedFeatureError from .errors import VelaError from .operation import Op +from .operation import Operation from .operation_util import create_avgpool_nop from .shape4d import Shape4D from .tensor import Tensor @@ -192,8 +193,8 @@ def needed_total_padding(input_size, stride, filter_size): return max(filter_size - (input_size % stride), 0) -# Set input/output tensor equivalence to the same id for memory operations -def set_tensor_equivalence(op, arch, nng): +def set_tensor_equivalence(op: Operation, arch, nng) -> Operation: + """Set input/output tensor equivalence to the same id for memory operations.""" if op.type in memory_only_ops: eid = op.outputs[0].equivalence_id for inp in op.inputs: @@ -300,16 +301,16 @@ def bypass_memory_only_ops(op, arch, nng): return op -def convert_depthwise_to_conv(op, arch, nng): - # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and - # the ofm depth equals the depth multipler. - # If those conditions are true, then we can perform a simple - # switch of the operator type (and weight order) - +def convert_depthwise_to_conv(op: Operation, arch, nng) -> Operation: + """Convert DepthwiseConv2DBias to Conv2D to allow support for DepthwiseConv2DBias ops with 'depth multiplier' > 1, + as long as IFM depth = 1 and OFM depth is equal to the depth multiplier. + """ if op.type == Op.DepthwiseConv2DBias and (op.attrs["depth_multiplier"] != 1): ifm_shape = op.ifm_shapes[0] weight_tensor = op.inputs[1] ofm_shape = op.ofm_shapes[0] + # Depthwise is equivalent to a single conv2d if the ifm depth is 1 and + # the ofm depth equals the depth multipler. if (ifm_shape.depth == 1) and (ofm_shape.depth == op.attrs["depth_multiplier"]): # Change op type to Conv2d op.type = Op.Conv2DBias @@ -321,8 +322,8 @@ def convert_depthwise_to_conv(op, arch, nng): DebugDatabase.add_optimised(op, op) else: raise UnsupportedFeatureError( - f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']},", - f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}", + f"Unsupported 'DEPTHWISE_CONV_2D' with depth_multiplier = {op.attrs['depth_multiplier']}," + f" ifm channels = {ifm_shape.depth}, ofm channels = {ofm_shape.depth}" ) return op |