diff options
Diffstat (limited to 'ethosu/vela/high_level_command_to_npu_op.py')
-rw-r--r-- | ethosu/vela/high_level_command_to_npu_op.py | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py index 6c403c86..f7c91aa2 100644 --- a/ethosu/vela/high_level_command_to_npu_op.py +++ b/ethosu/vela/high_level_command_to_npu_op.py @@ -17,9 +17,11 @@ # Description: # Conversion from high level command to NpuOperation from enum import IntEnum +from typing import cast from typing import Dict from typing import List from typing import Optional +from typing import Tuple from .api import NpuActivation from .api import NpuActivationOp @@ -66,6 +68,7 @@ from .tensor import Tensor from .tensor import TensorFormat from .tensor import TensorPurpose from .tensor import TensorSubPurpose +from .weight_compressor import NpuWeightTensor from .weight_compressor import WeightKey @@ -294,17 +297,17 @@ def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, op_sh def create_weights( - weight_tensor: Tensor, weight_box: Box, scale_tensor: Tensor, arch: ArchitectureFeatures -) -> List[NpuAddressRange]: + weight_tensor: NpuWeightTensor, weight_box: Box, scale_tensor: NpuWeightTensor, arch: ArchitectureFeatures +) -> Tuple[List[NpuAddressRange], List[NpuAddressRange]]: """Returns address ranges for weights and scales""" weights = [] biases = [] shared_region = get_region(weight_tensor.mem_type, arch) - scale_region = scale_tensor and get_region(scale_tensor.mem_type, arch) + scale_region = get_region(scale_tensor.mem_type, arch) if scale_tensor else 0 w_tensor_src = weight_tensor if weight_tensor.src_tensor: - w_tensor_src = weight_tensor.src_tensor + w_tensor_src = cast(NpuWeightTensor, weight_tensor.src_tensor) core_offset = 0 for core in range(0, arch.ncores): |