diff options
Diffstat (limited to 'ethosu/vela/register_command_stream_generator.py')
-rw-r--r-- | ethosu/vela/register_command_stream_generator.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py index 9d9a1e63..ec01d3ed 100644 --- a/ethosu/vela/register_command_stream_generator.py +++ b/ethosu/vela/register_command_stream_generator.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com> # # SPDX-License-Identifier: Apache-2.0 # @@ -31,6 +31,7 @@ import numpy as np from . import scaling from .api import NpuAccelerator +from .api import NpuAccumulatorType from .api import NpuActivation from .api import NpuActivationOp from .api import NpuAddressRange @@ -270,6 +271,11 @@ acc_format_map = { SHRAMElements.Acc40: acc_format.INT_40BIT.value, } +npu_acc_format_map = { + NpuAccumulatorType.Int32: acc_format.INT_32BIT.value, + NpuAccumulatorType.Int40: acc_format.INT_40BIT.value, +} + resampling_mode_map = { NpuResamplingMode.NONE: resampling_mode.NONE, NpuResamplingMode.NEAREST: resampling_mode.NEAREST, @@ -574,7 +580,10 @@ def generate_shram_registers( emit.cmd0_with_param(cmd0.NPU_SET_AB_START, arch_block_config.layout.ab_start) if has_ifm2(npu_op): emit.cmd0_with_param(cmd0.NPU_SET_IFM2_IB_START, arch_block_config.layout.ib_start2) - emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, acc_format_map[arch_block_config.acc_type]) + if npu_op.accumulator_type != NpuAccumulatorType.Default: + emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, npu_acc_format_map[npu_op.accumulator_type]) + else: + emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, acc_format_map[arch_block_config.acc_type]) def get_block_config_for_npu_op( |