diff options
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tosa_graph_optimiser.py | 77 |
1 files changed, 56 insertions, 21 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index 5cd9d210..d32955d5 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -149,7 +149,7 @@ def remove_const_transpose(op, arch, nng): # TODO can we change to add for both TFLite and TOSA? -def insert_add_copy_op_after_tens(tens): +def insert_add_copy_op_after_tens(tens, ifm_ofm_shape): tens_cons_list_copy = tens.consumer_list.copy() copy_tens = tens.clone() @@ -166,7 +166,9 @@ def insert_add_copy_op_after_tens(tens): copy_op.add_input_tensor(tens) copy_op.add_input_tensor(ifm2) copy_op.set_output_tensor(copy_tens) - copy_op.set_ifm_ofm_shapes() + copy_op.ifm_shapes.append(ifm_ofm_shape) + copy_op.ifm_shapes.append(Shape4D(ifm2.shape)) + copy_op.ofm_shapes.append(ifm_ofm_shape) copy_op.run_on_npu = True # Set copy_ifm consumers @@ -200,7 +202,29 @@ def fix_sg_input_output_tosa(op, arch, nng): if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed): # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape - insert_add_copy_op_after_tens(op.ifm) + + # Decide on ifm/ofm shapes for the copy op based on ifm + shape = op.ifm.shape.copy() + # remove dimensions that are set to 1 + new_shape = [] + for dim in shape: + if dim != 1: + new_shape.append(dim) + if not new_shape: + new_shape = [1] + + rank = len(new_shape) + if rank > 3: + # Reshape so that batch becomes 1, by moving elements to H dimension + n = rank - 2 + h = 1 + for i in range(n): + h *= shape[i] + new_shape = Shape4D(new_shape[n:]).with_height(h) + else: + new_shape = Shape4D(new_shape) + + insert_add_copy_op_after_tens(op.ifm, new_shape) return op @@ -435,16 +459,12 @@ def convert_pad(op, arch, nng): quant = ofm.quantization pad_value = ifm.quantization.zero_point + ifm.quantization.zero_point = 0 # Add operations that fill the borders of the OFM if top > 0: shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth) zero_tens = create_const_tensor( - op.name + "_top", - shape.as_list(), - ofm.dtype, - shape.elements() * [pad_value], - np.uint8, - quantization=quant, # TODO + op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant, ) # If top/bottom or left/right are equal, the const tensors can be allocated to the same address zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values)) @@ -569,6 +589,16 @@ def get_nhwc_stride(shape): return Shape4D(stride_n, stride_y, stride_x, 1) +def pad_to_rank(shape, rank): + """ + Pads a shape to the given rank + """ + while len(shape) < rank: + shape = [1] + shape + + return shape + + def get_elem_shapes_removed_singles(op): """ Returns the shapes of ifm(s)/ofms after removing all the dimensions that are 1 for all ifm(s)/ofm @@ -579,7 +609,12 @@ def get_elem_shapes_removed_singles(op): if binary: ifm2_shape = op.ifm_shapes[1].as_list() if len(op.ofm_shapes) else op.ifm2.shape - rank = len(ofm_shape) + rank = max(len(ofm_shape), len(ifm_shape), len(ifm2_shape) if binary else 0) + ofm_shape = pad_to_rank(ofm_shape, rank) + ifm_shape = pad_to_rank(ifm_shape, rank) + if binary: + ifm2_shape = pad_to_rank(ifm2_shape, rank) + new_ofm_shape = [] new_ifm_shape = [] new_ifm2_shape = [] @@ -777,6 +812,17 @@ def tosa_optimise_graph(nng, arch): nng, sg, arch, [decomp_rewrite_concat], [], rewrite_unsupported=False ) + # Handle sg input output + for idx, sg in enumerate(nng.subgraphs): + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False, + ) + + # Removal of reshapes + for sg in nng.subgraphs: + rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes]) + sg.refresh_after_modification() + # Decomposing of elementwise for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( @@ -794,17 +840,6 @@ def tosa_optimise_graph(nng, arch): nng, sg, arch, [], [remove_const_transpose], rewrite_unsupported=False, ) - # Handle sg input output - for idx, sg in enumerate(nng.subgraphs): - nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False, - ) - - # Removal of reshapes - for sg in nng.subgraphs: - rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes]) - sg.refresh_after_modification() - # TODO, when and where to best handle calc_scaling_avgpool for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( |