aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/register_command_stream_generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/register_command_stream_generator.py')
-rw-r--r--ethosu/vela/register_command_stream_generator.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 9d9a1e6..ec01d3e 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -31,6 +31,7 @@ import numpy as np
from . import scaling
from .api import NpuAccelerator
+from .api import NpuAccumulatorType
from .api import NpuActivation
from .api import NpuActivationOp
from .api import NpuAddressRange
@@ -270,6 +271,11 @@ acc_format_map = {
SHRAMElements.Acc40: acc_format.INT_40BIT.value,
}
+npu_acc_format_map = {
+ NpuAccumulatorType.Int32: acc_format.INT_32BIT.value,
+ NpuAccumulatorType.Int40: acc_format.INT_40BIT.value,
+}
+
resampling_mode_map = {
NpuResamplingMode.NONE: resampling_mode.NONE,
NpuResamplingMode.NEAREST: resampling_mode.NEAREST,
@@ -574,7 +580,10 @@ def generate_shram_registers(
emit.cmd0_with_param(cmd0.NPU_SET_AB_START, arch_block_config.layout.ab_start)
if has_ifm2(npu_op):
emit.cmd0_with_param(cmd0.NPU_SET_IFM2_IB_START, arch_block_config.layout.ib_start2)
- emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, acc_format_map[arch_block_config.acc_type])
+ if npu_op.accumulator_type != NpuAccumulatorType.Default:
+ emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, npu_acc_format_map[npu_op.accumulator_type])
+ else:
+ emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, acc_format_map[arch_block_config.acc_type])
def get_block_config_for_npu_op(