aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/weight_compressor.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/weight_compressor.py')
-rw-r--r--ethosu/vela/weight_compressor.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 68817035..22fe512e 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -17,6 +17,8 @@
# Compresses and pads the weigths. It also calculates the scales and packs with the biases.
from collections import namedtuple
from collections import OrderedDict
+from typing import Dict
+from typing import Optional
from typing import Tuple
import numpy as np
@@ -75,7 +77,7 @@ class NpuWeightTensor(Tensor):
class CompressedWeightCache:
"""Global tensor weight compression cache"""
- cache = {}
+ cache: Dict[WeightCompressionConfig, Tensor] = {}
@staticmethod
def get_tensor_with_same_compression(wcc):
@@ -279,7 +281,7 @@ def _prepare_scale_and_bias(arch, tens, rescale_for_faf, explicit_scaling):
def encode_weight_and_scale_tensor(
arch, op, weight_tens, scale_tens, kernel, block_config, depth_offsets, rescale_for_faf=False
-) -> (NpuWeightTensor, NpuWeightTensor):
+) -> Tuple[Optional[NpuWeightTensor], Optional[NpuWeightTensor]]:
npu_block_type = op.type.npu_block_type
ifm_scale = scale_tens and scale_tens.consumer_list[0].get_input_quantization().scale_f32