diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2020-12-01 16:02:29 +0100 |
---|---|---|
committer | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2020-12-18 16:33:32 +0100 |
commit | 2349d429d926e258e9a61d34c7fd97660ab9fb98 (patch) | |
tree | b5151d0f12428e47d64b1fb2ce4f2f8c19304a0d /ethosu/vela/high_level_command_stream_generator.py | |
parent | 528a56df829b65f7a2c61953650b123c461095f7 (diff) | |
download | ethos-u-vela-2349d429d926e258e9a61d34c7fd97660ab9fb98.tar.gz |
MLBEDSW-3654 Add/use op ifm/ofm shapes
Add ifm/ofm shapes to the op.
Change the code to rely on these shapes instead of the tensors' own shapes.
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I571535a1dcadc2bdb04a3c727a8e1c49703b174d
Diffstat (limited to 'ethosu/vela/high_level_command_stream_generator.py')
-rw-r--r-- | ethosu/vela/high_level_command_stream_generator.py | 60 |
1 files changed, 32 insertions, 28 deletions
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 905263d6..18a419c0 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -56,6 +56,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
         # Ensure correct ifm and ifm2 order
         if match_tensor(ps.inputs[0], ps.primary_op.inputs[1]) and match_tensor(ps.inputs[1], ps.primary_op.inputs[0]):
             ps.ifm_tensor, ps.ifm2_tensor = ps.ifm2_tensor, ps.ifm_tensor
+            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
 
     for op in ps.ops:
         if op.type == Op.SplitSliceRead:
@@ -77,13 +78,20 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
             ifm_idx += 1
 
     ifm_tensor = ps.ifm_tensor
+    ifm_shape = None
+    if ifm_tensor.shape != []:
+        ifm_shape = ps.ifm_shapes[0]
     ifm2_tensor = ps.ifm2_tensor
+    ifm2_shape = None
+    if ifm2_tensor is not None and ifm2_tensor.shape != []:
+        ifm2_shape = ps.ifm_shapes[1]
     ofm_tensor = ps.ofm_tensor
+    ofm_shape = ps.ofm_shapes[0]
     weight_tensor = ps.weight_tensor
     scale_tensor = ps.scale_tensor
 
-    ofm_start = [0] * len(ofm_tensor.shape)
-    ofm_end = list(ofm_tensor.shape)
+    ofm_start = [0] * len(ofm_shape)
+    ofm_end = list(ofm_shape)
 
     strides = None
     skirt = None
@@ -92,9 +100,9 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
         strides = ps.primary_op.attrs.get("strides", None)
         skirt = ps.primary_op.attrs.get("skirt", None)
         if ps.primary_op.type == Op.Conv2DBackpropInputSwitchedBias:
-            upscaling = ofm_tensor.shape[-3] // ifm_tensor.shape[-3]
+            upscaling = ofm_shape[-3] // ifm_shape[-3]
         elif ps.primary_op.type == Op.ResizeBilinear:
-            upscaling = round_up_divide(ofm_tensor.shape[-3], ifm_tensor.shape[-3])
+            upscaling = round_up_divide(ofm_shape[-3], ifm_shape[-3])
 
     concat_axis = 0
     concat_offset = 0
@@ -125,7 +133,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
             ifm_box = None
             ifm2_box = None
 
-            if ifm_tensor.shape != []:
+            if ifm_shape is not None:
                 ifm_box, _, _ = ofm_box.transform_with_strides_and_skirt(
                     strides,
                     skirt,
@@ -138,16 +146,9 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
                 )
             else:
                 ifm_box = Box([], [])
-            if ifm2_tensor is not None and ifm2_tensor.shape != []:
+            if ifm2_shape is not None:
                 ifm2_box, _, _ = ofm_box.transform_with_strides_and_skirt(
-                    strides,
-                    skirt,
-                    ifm2_tensor.shape,
-                    npu_block_type,
-                    concat_axis,
-                    concat_offset,
-                    split_offsets[1],
-                    upscaling,
+                    strides, skirt, ifm2_shape, npu_block_type, concat_axis, concat_offset, split_offsets[1], upscaling,
                 )
             else:
                 ifm2_box = Box([], [])
@@ -212,19 +213,17 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
 
     elif strat == SchedulingStrategy.IfmStream:
         y_step = block_config[0]
-        y_start = 0
-        y_dim = 1
-        if len(ofm_tensor.shape) >= 3:
-            y_start = ofm_start[-3]
-            y_dim = ofm_end[-3]
+        y_start = ofm_start[-3]
+        y_dim = ofm_end[-3]
+
         if idx > 0:
             ifm_y_present = 0
             prev_pass = passes[idx - 1]
             prev_pass_gen = generate_high_level_command_stream_for_pass(strat, passes, block_configs, idx - 1)
         else:
             ifm_y_present = 1
-            if len(ifm_tensor.shape) >= 3:
-                ifm_y_present = ifm_tensor.shape[-3]
+            if len(ifm_shape) >= 3:
+                ifm_y_present = ifm_shape[-3]
             prev_pass_gen = []
             prev_pass = None
 
@@ -243,9 +242,8 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
 
         for start in range(y_start, y_dim, y_step):
             end = min(start + y_step, y_dim)
-            if len(ofm_tensor.shape) >= 3:
-                ofm_start[-3] = start
-                ofm_end[-3] = end
+            ofm_start[-3] = start
+            ofm_end[-3] = end
             ofm_box = Box(ofm_start, ofm_end)
 
             k_height = 1
@@ -259,7 +257,7 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id
             ifm_box, pad_top, pad_bottom = ofm_box.transform_with_strides_and_skirt(
                 strides,
                 skirt,
-                ifm_tensor.shape,
+                ifm_shape,
                 npu_block_type,
                 concat_axis,
                 concat_offset,
@@ -381,11 +379,15 @@ def calc_allowed_ofm_ifm_overlap_for_pass_list(strat, passes, block_configs):
     for cmd in generate_high_level_command_stream_for_pass_list(strat, passes, block_configs):
         if cmd.is_npu_pass_command():
             if cmd.is_first:
-                ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(cmd.ifm_box.start_coord, is_top_box=False)
+                ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(
+                    cmd.ifm_box.start_coord, shape=cmd.ps.ifm_shapes[0], is_top_box=False
+                )
                 if ifm_read is None:
                     return 0
             if cmd.is_last:
-                write_offset = cmd.ofm_tensor.address_offset_for_coordinate(cmd.ofm_box.end_coord, is_top_box=True)
+                write_offset = cmd.ofm_tensor.address_offset_for_coordinate(
+                    cmd.ofm_box.end_coord, shape=cmd.ps.ofm_shapes[0], is_top_box=True
+                )
                 if write_offset is None:
                     return 0
                 highest_ofm_write = max(write_offset, highest_ofm_write)
@@ -396,7 +398,9 @@ def calc_allowed_ofm_ifm_overlap_for_pass_list(strat, passes, block_configs):
                 min_overlap = min(min_overlap, can_overwrite)
 
             if cmd.is_first:
-                ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(cmd.ifm_box.end_coord, is_top_box=True)
+                ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(
+                    cmd.ifm_box.end_coord, shape=cmd.ps.ifm_shapes[0], is_top_box=True
+                )
 
     min_overlap = max(min_overlap, 0)
     return min_overlap