aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/register_command_stream_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/register_command_stream_util.py')
-rw-r--r--ethosu/vela/register_command_stream_util.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/ethosu/vela/register_command_stream_util.py b/ethosu/vela/register_command_stream_util.py
index 74c4f90..8a6f94e 100644
--- a/ethosu/vela/register_command_stream_util.py
+++ b/ethosu/vela/register_command_stream_util.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -21,6 +21,7 @@ from typing import NamedTuple
from typing import Optional
from . import numeric_util
+from .api import NpuAccumulatorType
from .api import NpuActivationOp
from .api import NpuAddressRange
from .api import NpuBlockOperation
@@ -42,6 +43,7 @@ from .errors import ByteSizeError
from .operation import Kernel
from .operation import PointXYZ
from .tensor import TensorFormat
+from .tflite.TensorType import TensorType
from ethosu.vela.range_set import AccessDirection
from ethosu.vela.range_set import MemoryAccessSet
from ethosu.vela.range_set import MemoryRangeSet
@@ -74,6 +76,15 @@ def check_length(length, required_multiple):
check_size(length, required_multiple, "length")
+def to_npu_acc_type(accType: TensorType) -> NpuAccumulatorType:
+ if accType == TensorType.INT32:
+ return NpuAccumulatorType.Int32
+ elif accType == TensorType.INT64:
+ return NpuAccumulatorType.Int40
+ else:
+ return NpuAccumulatorType.Default
+
+
def to_npu_kernel(kernel: Kernel) -> NpuKernel:
"""Converts the given internally used kernel object to NpuKernel (of public API)"""
return NpuKernel(