aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/register_command_stream_generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/register_command_stream_generator.py')
-rw-r--r--ethosu/vela/register_command_stream_generator.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index fd32b655..3be2898c 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -21,6 +21,7 @@ import math
from collections import defaultdict
from enum import Enum
from enum import IntEnum
+from typing import cast
from typing import Dict
from typing import List
from typing import Optional
@@ -319,7 +320,7 @@ def generate_activation(emit: CommandStreamEmitter, activation: Optional[NpuActi
quantized_min = max(-128, quantized_min)
quantized_max = min(127, quantized_max)
else:
- activation_value = activation_op_map[act.op_type]
+ activation_value = cast(int, activation_op_map[act.op_type])
emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation_value)
emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION_MIN, quantized_min)
emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION_MAX, quantized_max)
@@ -584,7 +585,7 @@ def get_arch_block_config(
block_config,
arch,
block_type,
- npu_op.ofm.shape,
+ shape3d_to_block(npu_op.ofm.shape),
ifm_shape,
ifm2_shape,
uses_scalar,
@@ -741,6 +742,8 @@ def generate_scaling_for_elementwise(emit: CommandStreamEmitter, npu_op: NpuElem
ofm_scale, shift = scaling.elementwise_mul_scale(input_scale, input2_scale, output_scale)
emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
else: # Add/Sub
+ opa_scale: float
+ opb_scale: float
bitdepth = npu_op.ifm.data_type.size_in_bits()
use_advanced_scaling = False
if None in (input_scale, input2_scale, output_scale):
@@ -799,7 +802,7 @@ def generate_scaling_for_elementwise(emit: CommandStreamEmitter, npu_op: NpuElem
# -------------------------------------------------------------------
-def print_feature_map(fm: NpuFeatureMap, name: str):
+def print_feature_map(fm: Optional[NpuFeatureMap], name: str):
if fm is not None:
q = (
"no quantization"