diff options
Diffstat (limited to 'ethosu/vela/register_command_stream_generator.py')
-rw-r--r-- | ethosu/vela/register_command_stream_generator.py | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 09348811..4a9b0719 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -277,10 +277,10 @@ def has_prev_op_dependency(prev_cmd, cmd):
     if prev_cmd is None:
         return False
     if (prev_cmd.cmdtype == cmd.cmdtype == CommandType.NpuStripe) and (prev_cmd.ps != cmd.ps):
-        if prev_cmd.ofm_tensor.equivalence_id == cmd.ifm_tensor.equivalence_id:
+        if prev_cmd.ofm_tensor.equivalent(cmd.ifm_tensor):
             return True
         elif cmd.ifm2_tensor is not None:
-            return prev_cmd.ofm_tensor.equivalence_id == cmd.ifm2_tensor.equivalence_id
+            return prev_cmd.ofm_tensor.equivalent(cmd.ifm2_tensor)
     return False
 
 
@@ -560,12 +560,13 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
                     else:
                         emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, 1, 0)
 
-                    # For elementwise set the required SHRAM to be equal to the total size of SHRAM
-                    shram_required = arch.shram_total_banks
+                    # For elementwise set the required SHRAM to be equal to the total size of available SHRAM
+                    uses_lut = primary_op.activation_lut is not None
+                    shram_required = arch.available_shram_banks(uses_lut)
                     emit.cmd0_with_param(cmd0.NPU_SET_IFM_IB_END, shram_required)
 
                     # Acc buffers not needed so set AB_START to size of SHRAM
-                    emit.cmd0_with_param(cmd0.NPU_SET_AB_START, arch.shram_total_banks)
+                    emit.cmd0_with_param(cmd0.NPU_SET_AB_START, shram_required)
 
                     # Is not a unary operator
                     if cmd.ifm2_tensor is not None:
@@ -852,8 +853,8 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
                         faf_min = quantise_float32(clamp_sigmoid(ifm_min), ofm_quant.scale_f32, ofm_quant.zero_point)
                         faf_max = quantise_float32(clamp_sigmoid(ifm_max), ofm_quant.scale_f32, ofm_quant.zero_point)
                     elif faf == "LUT":
-                        lut_index = int(activation.LUT_START.value) + primary_op.attrs.get("lut_index", 0)
-                        assert lut_index <= activation.LUT_END.value, "LUT index out of range."
+                        lut_index = int(activation.LUT_START.value) + primary_op.attrs.get("lut_index", -1)
+                        assert activation.LUT_START.value <= lut_index <= activation.LUT_END.value, "LUT index out of range."
                         emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, lut_index)
                         faf_min = ofm_quant_qmin
                         faf_max = ofm_quant_qmax