aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/high_level_command_to_npu_op.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/high_level_command_to_npu_op.py')
-rw-r--r--ethosu/vela/high_level_command_to_npu_op.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 6c403c86..f7c91aa2 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -17,9 +17,11 @@
# Description:
# Conversion from high level command to NpuOperation
from enum import IntEnum
+from typing import cast
from typing import Dict
from typing import List
from typing import Optional
+from typing import Tuple
from .api import NpuActivation
from .api import NpuActivationOp
@@ -66,6 +68,7 @@ from .tensor import Tensor
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose
+from .weight_compressor import NpuWeightTensor
from .weight_compressor import WeightKey
@@ -294,17 +297,17 @@ def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_sh
def create_weights(
- weight_tensor: Tensor, weight_box: Box, scale_tensor: Tensor, arch: ArchitectureFeatures
-) -> List[NpuAddressRange]:
+ weight_tensor: NpuWeightTensor, weight_box: Box, scale_tensor: NpuWeightTensor, arch: ArchitectureFeatures
+) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]:
"""Returns address ranges for weights and scales"""
weights = []
biases = []
shared_region = get_region(weight_tensor.mem_type, arch)
- scale_region = scale_tensor and get_region(scale_tensor.mem_type, arch)
+ scale_region = get_region(scale_tensor.mem_type, arch) if scale_tensor else 0
w_tensor_src = weight_tensor
if weight_tensor.src_tensor:
- w_tensor_src = weight_tensor.src_tensor
+ w_tensor_src = cast(NpuWeightTensor, weight_tensor.src_tensor)
core_offset = 0
for core in range(0, arch.ncores):