aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/register_command_stream_generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/register_command_stream_generator.py')
-rw-r--r--ethosu/vela/register_command_stream_generator.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 09348811..4a9b0719 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -277,10 +277,10 @@ def has_prev_op_dependency(prev_cmd, cmd):
if prev_cmd is None:
return False
if (prev_cmd.cmdtype == cmd.cmdtype == CommandType.NpuStripe) and (prev_cmd.ps != cmd.ps):
- if prev_cmd.ofm_tensor.equivalence_id == cmd.ifm_tensor.equivalence_id:
+ if prev_cmd.ofm_tensor.equivalent(cmd.ifm_tensor):
return True
elif cmd.ifm2_tensor is not None:
- return prev_cmd.ofm_tensor.equivalence_id == cmd.ifm2_tensor.equivalence_id
+ return prev_cmd.ofm_tensor.equivalent(cmd.ifm2_tensor)
return False
@@ -560,12 +560,13 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
else:
emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, 1, 0)
- # For elementwise set the required SHRAM to be equal to the total size of SHRAM
- shram_required = arch.shram_total_banks
+ # For elementwise set the required SHRAM to be equal to the total size of available SHRAM
+ uses_lut = primary_op.activation_lut is not None
+ shram_required = arch.available_shram_banks(uses_lut)
emit.cmd0_with_param(cmd0.NPU_SET_IFM_IB_END, shram_required)
# Acc buffers not needed so set AB_START to size of SHRAM
- emit.cmd0_with_param(cmd0.NPU_SET_AB_START, arch.shram_total_banks)
+ emit.cmd0_with_param(cmd0.NPU_SET_AB_START, shram_required)
# Is not a unary operator
if cmd.ifm2_tensor is not None:
@@ -852,8 +853,8 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
faf_min = quantise_float32(clamp_sigmoid(ifm_min), ofm_quant.scale_f32, ofm_quant.zero_point)
faf_max = quantise_float32(clamp_sigmoid(ifm_max), ofm_quant.scale_f32, ofm_quant.zero_point)
elif faf == "LUT":
- lut_index = int(activation.LUT_START.value) + primary_op.attrs.get("lut_index", 0)
- assert lut_index <= activation.LUT_END.value, "LUT index out of range."
+ lut_index = int(activation.LUT_START.value) + primary_op.attrs.get("lut_index", -1)
+ assert activation.LUT_START.value <= lut_index <= activation.LUT_END.value, "LUT index out of range."
emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, lut_index)
faf_min = ofm_quant_qmin
faf_max = ofm_quant_qmax