diff options
Diffstat (limited to 'ethosu/vela/register_command_stream_generator.py')
-rw-r--r-- | ethosu/vela/register_command_stream_generator.py | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py index 5fa71aa7..a2b2f4d0 100644 --- a/ethosu/vela/register_command_stream_generator.py +++ b/ethosu/vela/register_command_stream_generator.py @@ -374,10 +374,6 @@ def generate_register_command_stream(nng, sg, arch, verbose=False): emit.cmd_wait(cmd0.NPU_OP_DMA_WAIT, param, absolute_dep[CommandType.DMA][0]) prev_cmd = None # Clear any dependency - # Start by issuing REGION commands since they remain the same - emit.cmd0_with_param(cmd0.NPU_SET_IFM_REGION, BasePointerIndex.Scratch) - emit.cmd0_with_param(cmd0.NPU_SET_IFM2_REGION, BasePointerIndex.Scratch) - emit.cmd0_with_param(cmd0.NPU_SET_OFM_REGION, BasePointerIndex.Scratch) for cmd in cmd_stream: if cmd.cmdtype == CommandType.DMA: start_coord = cmd.box.start_coord @@ -730,10 +726,11 @@ def generate_register_command_stream(nng, sg, arch, verbose=False): else: emit.cmd0_with_param(cmd0.NPU_SET_IFM_DEPTH_M1, out_shape[-1] - 1) - for tens, box, ptr_ops, stride_ops, zero_point_op in ( + for tens, box, region_op, ptr_ops, stride_ops, zero_point_op in ( ( cmd.ifm_tensor, cmd.ifm_box, + cmd0.NPU_SET_IFM_REGION, (cmd1.NPU_SET_IFM_BASE0, cmd1.NPU_SET_IFM_BASE1, cmd1.NPU_SET_IFM_BASE2, cmd1.NPU_SET_IFM_BASE3), (cmd1.NPU_SET_IFM_STRIDE_C, cmd1.NPU_SET_IFM_STRIDE_Y, cmd1.NPU_SET_IFM_STRIDE_X), cmd0.NPU_SET_IFM_ZERO_POINT, @@ -741,6 +738,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False): ( cmd.ifm2_tensor, cmd.ifm2_box, + cmd0.NPU_SET_IFM2_REGION, ( cmd1.NPU_SET_IFM2_BASE0, cmd1.NPU_SET_IFM2_BASE1, @@ -753,6 +751,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False): ( cmd.ofm_tensor, cmd.ofm_box, + cmd0.NPU_SET_OFM_REGION, (cmd1.NPU_SET_OFM_BASE0, cmd1.NPU_SET_OFM_BASE1, cmd1.NPU_SET_OFM_BASE2, cmd1.NPU_SET_OFM_BASE3), (cmd1.NPU_SET_OFM_STRIDE_C, cmd1.NPU_SET_OFM_STRIDE_Y, cmd1.NPU_SET_OFM_STRIDE_X), cmd0.NPU_SET_OFM_ZERO_POINT, @@ -807,6 +806,11 @@ def generate_register_command_stream(nng, sg, arch, verbose=False): else: assert False + if tens.mem_area == MemArea.Sram: + emit.cmd0_with_param(region_op, BasePointerIndex.Scratch) + else: + emit.cmd0_with_param(region_op, BasePointerIndex.ReadOnly) + for idx, addr in enumerate(addresses): if addr is None: addresses[idx] = 0 |