diff options
Diffstat (limited to 'ethosu/vela/register_command_stream_util.py')
-rw-r--r-- | ethosu/vela/register_command_stream_util.py | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/ethosu/vela/register_command_stream_util.py b/ethosu/vela/register_command_stream_util.py index 74c4f90e..8a6f94e4 100644 --- a/ethosu/vela/register_command_stream_util.py +++ b/ethosu/vela/register_command_stream_util.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com> # # SPDX-License-Identifier: Apache-2.0 # @@ -21,6 +21,7 @@ from typing import NamedTuple from typing import Optional from . import numeric_util +from .api import NpuAccumulatorType from .api import NpuActivationOp from .api import NpuAddressRange from .api import NpuBlockOperation @@ -42,6 +43,7 @@ from .errors import ByteSizeError from .operation import Kernel from .operation import PointXYZ from .tensor import TensorFormat +from .tflite.TensorType import TensorType from ethosu.vela.range_set import AccessDirection from ethosu.vela.range_set import MemoryAccessSet from ethosu.vela.range_set import MemoryRangeSet @@ -74,6 +76,15 @@ def check_length(length, required_multiple): check_size(length, required_multiple, "length") +def to_npu_acc_type(accType: TensorType) -> NpuAccumulatorType: + if accType == TensorType.INT32: + return NpuAccumulatorType.Int32 + elif accType == TensorType.INT64: + return NpuAccumulatorType.Int40 + else: + return NpuAccumulatorType.Default + + def to_npu_kernel(kernel: Kernel) -> NpuKernel: """Converts the given internally used kernel object to NpuKernel (of public API)""" return NpuKernel( |