diff options
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tosa_graph_optimiser.py | 93 |
1 files changed, 92 insertions, 1 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index 2d1245b0..49fc997d 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -35,6 +35,7 @@ from .operation_util import create_add_nop from .operation_util import create_avgpool_nop from .shape4d import Shape4D from .tensor import create_const_tensor +from .tensor import create_equivalence_id def replace_rescale_with_avg_pool(rescale_op): @@ -417,6 +418,96 @@ def rewrite_rescale(op, arch, nng): return op +# TODO modified copy of TFLite, solution for TOSA PAD will change so reuse has not been considered +def convert_pad(op, arch, nng): + """ + Rewrites PAD operator to an add that copies the IFM to the OFM + + up to 4 add operators that fill the OFM with zeros at the borders. + """ + + if op.type != Op.Pad: + return op + + # TODO assuming rank <= 4 and N = 1 for rank ==4 + # This is checked in tosa_supported_operators + ifm = op.ifm + assert ifm is not None + ifm_shape = Shape4D(ifm.shape) + ofm = op.ofm + assert ofm is not None + ofm.ops = [] + ofm_shape = op.ofm_shapes[0] + + rank = len(ifm.shape) + padding = op.inputs[1].values + pad_depth = padding[-1] + if not (pad_depth == 0).all(): + print("Warning: For PAD, padding in depth not supported yet") + assert False + + top, bottom = 0, 0 + left, right = 0, 0 + if rank > 1: + left, right = padding[-2][0], padding[-2][1] + if rank > 2: + top, bottom = padding[-3][0], padding[-3][1] + if rank == 4 and not (padding[-4] == 0).all(): + print("Warning: For PAD, padding not supported in first dimension when rank == 4 yet") + assert False + + # Add op that copies IFM to the right place inside the OFM + shp0 = Shape4D(0, 0, 0, 0) + shp_top = shp0.with_height(top) + add_op = create_add_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left)) + add_op.activation = op.activation + + quant = ofm.quantization + pad_value = ifm.quantization.zero_point + # Add operations that fill the borders of the OFM + if top > 0: + shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth) + zero_tens = create_const_tensor( + op.name + "_top", + shape.as_list(), + ofm.dtype, + shape.elements() * [pad_value], + np.uint8, + quantization=quant, # TODO + ) + # If top/bottom or left/right are equal, the const tensors can be allocated to the same address + zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values)) + create_add_for_concat(op, op.name + "_top", zero_tens, shape, shp0) + if bottom > 0: + shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth) + zero_tens = create_const_tensor( + op.name + "_bottom", + shape.as_list(), + ofm.dtype, + shape.elements() * [pad_value], + np.uint8, + quantization=quant, + ) + zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values)) + create_add_for_concat(op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)) + if left > 0: + shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth) + zero_tens = create_const_tensor( + op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant + ) + zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values)) + create_add_for_concat(op, op.name + "_left", zero_tens, shape, shp_top) + if right > 0: + shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth) + zero_tens = create_const_tensor( + op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant + ) + zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values)) + create_add_for_concat(op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)) + + op.type = Op.ConcatTFLite + return add_op + + def fixup_quantization(op, arch, nng): if op.ifm and op.ifm.quantization.zero_point is None: op.ifm.quantization.zero_point = 0 @@ -484,7 +575,7 @@ def tosa_optimise_graph(nng, arch): # Post-processing step 1 for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [], [rewrite_activation, add_padding_fields], + nng, sg, arch, [], [rewrite_activation, convert_pad, add_padding_fields], ) # Removal of Slice, need to be done after optimisation has been performed, |