diff options
Diffstat (limited to 'ethosu/vela/high_level_command_stream_generator.py')
-rw-r--r-- | ethosu/vela/high_level_command_stream_generator.py | 27 |
1 files changed, 21 insertions, 6 deletions
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py index ef21e06c..0cc70a7f 100644 --- a/ethosu/vela/high_level_command_stream_generator.py +++ b/ethosu/vela/high_level_command_stream_generator.py @@ -24,17 +24,18 @@ from .high_level_command_stream import NpuStripe from .nn_graph import PassPlacement from .nn_graph import SchedulingStrategy from .operation import NpuBlockType +from .tensor import TensorPurpose def need_dma(tens): return len(tens.ops) == 1 and tens.ops[0].type == "DMA" -def dma_weights_if_necessary(ps, box, weight_tensor): - if need_dma(weight_tensor): - dma_op = weight_tensor.ops[0] +def dma_if_necessary(ps, box, tensor): + if need_dma(tensor): + dma_op = tensor.ops[0] in_tensor = dma_op.inputs[0] - yield DMA(in_tensor, weight_tensor, box) + yield DMA(in_tensor, tensor, box) def generate_high_level_command_stream_for_pass(strat, passes, block_configs, idx): @@ -115,6 +116,13 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id else: ifm2_box = Box([], []) + for intermediate in ps.intermediates: + if intermediate != None and intermediate.shape != [] and intermediate.purpose == TensorPurpose.FeatureMap: + intermediate_box, _, _ = ofm_box.transform_with_strides_and_skirt( + strides, skirt, intermediate.shape, npu_block_type, concat_axis, concat_offset, split_offsets[0] + ) + yield from dma_if_necessary(ps, intermediate_box, intermediate) + weight_box = None if weight_tensor is not None: weight_oc_start = start @@ -130,7 +138,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id weight_oc_end, weight_tensor.weight_transpose_depthwise, ) - yield from dma_weights_if_necessary(ps, weight_box, weight_tensor) + yield from dma_if_necessary(ps, weight_box, weight_tensor) yield NpuStripe( ps, @@ -201,6 +209,13 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id strides, skirt, ifm_tensor.shape, npu_block_type, concat_axis, concat_offset, split_offsets[0], k_height ) + for intermediate in ps.intermediates: + if intermediate != None and intermediate.shape != [] and intermediate.purpose == TensorPurpose.FeatureMap: + intermediate_box, _, _ = ofm_box.transform_with_strides_and_skirt( + strides, skirt, intermediate.shape, npu_block_type, concat_axis, concat_offset, split_offsets[0] + ) + yield from dma_if_necessary(ps, intermediate_box, intermediate) + ifm_y_needed = 1 if len(ifm_box.end_coord) >= 3: ifm_y_needed = ifm_box.end_coord[-3] @@ -217,7 +232,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id weight_box = Box.make_weight_box( weight_tensor.shape, npu_block_type, weights_transposed=weight_tensor.weight_transpose_depthwise ) - yield from dma_weights_if_necessary(ps, weight_box, weight_tensor) + yield from dma_if_necessary(ps, weight_box, weight_tensor) # Check if first/last stripe in pass is_first_h_stripe = start == y_start |