aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFredrik Svedberg <fredrik.svedberg@arm.com>2022-11-04 09:48:49 +0100
committerFredrik Svedberg <fredrik.svedberg@arm.com>2022-11-09 20:33:19 +0000
commitf3c7d557371b26835e3183064de58354f3a8b3cb (patch)
treeae6ab62e750f9b5c07ea4695380b4dfe06600999
parent9d51ec41855a8be21bd0708c882d121e5bb5afcc (diff)
downloadethos-u-vela-f3c7d557371b26835e3183064de58354f3a8b3cb.tar.gz
MLBEDSW-6881 SHAPE single op network is optimised to nothing3.6.0.rc1
Fixed by adding an operation to copy the statically optimised data to the subgraph output. Change-Id: Ica757e37d5460237973444ffd39c7d2850f319e3 Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
-rw-r--r--ethosu/vela/live_range.py6
-rw-r--r--ethosu/vela/operation.py4
-rw-r--r--ethosu/vela/register_command_stream_generator.py17
-rw-r--r--ethosu/vela/register_command_stream_util.py12
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py21
5 files changed, 45 insertions, 15 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index fbb48ecd..2829f398 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -240,11 +240,11 @@ def extract_live_ranges_from_cascaded_passes(
rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment)
rng.mark_usage(time_for_pass)
- op_subgraph = cps.passes[0].ops[0].attrs.get("subgraph", None)
- op_type = cps.passes[0].ops[0].type
+ op = cps.passes[0].ops[0] if cps.passes[0].ops else None
+ op_subgraph = op.attrs.get("subgraph", None) if op else None
if op_subgraph is not None and MemType.Permanent_CPU not in target_mem_type_set:
- if op_type == Op.CustomNpuOp:
+ if op.type == Op.CustomNpuOp:
# If the primary-op is an NpuOp that means this is where an Npu subgraph
# is called. Go into said subgraph and extract live ranges before continuing.
# Use default allocation alignment of 16 for Npu tensors
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index b4d0e48a..05c236a5 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -596,6 +596,10 @@ class Operation:
def original_type(self):
return self._original_type
+ @property
+ def type_changed(self):
+ return self.type != self.original_type
+
def get_kernel_size(self):
weights = self.weights
if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN):
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 99ac32d5..c9b57f22 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -49,7 +49,6 @@ from .api import NpuOperationType
from .api import NpuPadding
from .api import NpuPoolingOp
from .api import NpuPoolingOperation
-from .api import NpuQuantization
from .api import NpuResamplingMode
from .api import NpuRoundingMode
from .api import NpuShape3D
@@ -69,7 +68,6 @@ from .ethos_u55_regs.ethos_u55_regs import elementwise_mode
from .ethos_u55_regs.ethos_u55_regs import pooling_mode
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .ethos_u55_regs.ethos_u55_regs import rounding
-from .numeric_util import quantise_float32
from .numeric_util import round_away_zero
from .numeric_util import round_up_to_int
from .operation import ExplicitScaling
@@ -81,7 +79,9 @@ from .register_command_stream_util import get_dma_memory_accesses
from .register_command_stream_util import get_op_memory_accesses
from .register_command_stream_util import get_strides
from .register_command_stream_util import get_wait_dependency
+from .register_command_stream_util import get_zero_point
from .register_command_stream_util import has_ifm2
+from .register_command_stream_util import quantise
from .register_command_stream_util import shape3d_to_block
from .register_command_stream_util import to_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
@@ -298,13 +298,6 @@ def check_mem_limits(memory_accesses: MemoryAccessSet, mem_limits: Dict[int, int
)
-def quantise(value: float, quant: Optional[NpuQuantization]) -> int:
- """Quantizes the given value"""
- scale = 1 if quant is None or quant.scale_f32 is None else quant.scale_f32
- zp = 0 if quant is None else quant.zero_point
- return quantise_float32(value, scale, zp)
-
-
def generate_padding(emit: CommandStreamEmitter, padding: NpuPadding):
"""Generates IFM_PAD registers"""
emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_TOP, padding.top)
@@ -440,7 +433,7 @@ def generate_ifm(emit: CommandStreamEmitter, ifm: NpuFeatureMap):
)
emit.cmd0_with_param(cmd0.NPU_SET_IFM_DEPTH_M1, ifm.shape.depth - 1)
generate_strides(emit, ifm, cmd1.NPU_SET_IFM_STRIDE_C, cmd1.NPU_SET_IFM_STRIDE_Y, cmd1.NPU_SET_IFM_STRIDE_X)
- emit.cmd0_with_param(cmd0.NPU_SET_IFM_ZERO_POINT, int(ifm.quantization.zero_point))
+ emit.cmd0_with_param(cmd0.NPU_SET_IFM_ZERO_POINT, get_zero_point(ifm))
def generate_ifm2(emit: CommandStreamEmitter, ifm2: NpuFeatureMap, has_scalar: bool):
@@ -457,7 +450,7 @@ def generate_ifm2(emit: CommandStreamEmitter, ifm2: NpuFeatureMap, has_scalar: b
emit, [cmd0.NPU_SET_IFM2_HEIGHT0_M1, cmd0.NPU_SET_IFM2_HEIGHT1_M1, cmd0.NPU_SET_IFM2_WIDTH0_M1], ifm2.tiles
)
generate_strides(emit, ifm2, cmd1.NPU_SET_IFM2_STRIDE_C, cmd1.NPU_SET_IFM2_STRIDE_Y, cmd1.NPU_SET_IFM2_STRIDE_X)
- emit.cmd0_with_param(cmd0.NPU_SET_IFM2_ZERO_POINT, int(ifm2.quantization.zero_point))
+ emit.cmd0_with_param(cmd0.NPU_SET_IFM2_ZERO_POINT, get_zero_point(ifm2))
def generate_ofm(emit: CommandStreamEmitter, ofm: NpuFeatureMap):
@@ -476,7 +469,7 @@ def generate_ofm(emit: CommandStreamEmitter, ofm: NpuFeatureMap):
emit.cmd0_with_param(cmd0.NPU_SET_OFM_WIDTH_M1, ofm.shape.width - 1)
emit.cmd0_with_param(cmd0.NPU_SET_OFM_DEPTH_M1, ofm.shape.depth - 1)
generate_strides(emit, ofm, cmd1.NPU_SET_OFM_STRIDE_C, cmd1.NPU_SET_OFM_STRIDE_Y, cmd1.NPU_SET_OFM_STRIDE_X)
- emit.cmd0_with_param(cmd0.NPU_SET_OFM_ZERO_POINT, int(ofm.quantization.zero_point))
+ emit.cmd0_with_param(cmd0.NPU_SET_OFM_ZERO_POINT, get_zero_point(ofm))
def generate_kernel(emit: CommandStreamEmitter, kernel: NpuKernel, block_traversal: NpuBlockTraversal):
diff --git a/ethosu/vela/register_command_stream_util.py b/ethosu/vela/register_command_stream_util.py
index b2c84d7c..1b2cb47b 100644
--- a/ethosu/vela/register_command_stream_util.py
+++ b/ethosu/vela/register_command_stream_util.py
@@ -32,6 +32,7 @@ from .api import NpuLayout
from .api import NpuOperation
from .api import NpuOperationType
from .api import NpuPadding
+from .api import NpuQuantization
from .api import NpuShape3D
from .architecture_features import ArchitectureFeatures
from .architecture_features import Block
@@ -80,6 +81,17 @@ def shape3d_to_block(shape: NpuShape3D) -> Block:
return Block(shape.width, shape.height, shape.depth)
+def get_zero_point(fm: NpuFeatureMap):
+ return int(fm.quantization.zero_point if fm.quantization else 0)
+
+
+def quantise(value: float, quant: Optional[NpuQuantization]) -> int:
+ """Quantizes the given value"""
+ scale = 1 if quant is None or quant.scale_f32 is None else quant.scale_f32
+ zp = 0 if quant is None else quant.zero_point
+ return numeric_util.quantise_float32(value, scale, zp)
+
+
# -------------------------------------------------------------------
# ADDRESSING/STRIDES (helper functions)
# -------------------------------------------------------------------
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index b8e61f48..90b29327 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -48,6 +48,7 @@ from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
+from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import get_pad_values_from_input
from .scaling import quantise_scale
@@ -1801,6 +1802,7 @@ def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
# Convert this SHAPE op to const
op.type = Op.Const
+ DebugDatabase.add_optimised(op, op)
# Add size calculation to shape output tensors
ofm.values = np.array(ifm.shape)
@@ -1935,4 +1937,23 @@ def tflite_optimise_graph(nng, arch):
rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
sg.refresh_after_modification()
+ # Make sure that const optimisations on subgraph outputs are handled correctly
+ for sg in nng.subgraphs:
+ for ofm in sg.output_tensors:
+ if ofm.is_const and ofm.ops[0].type_changed:
+ # Subgraph output cannot be const - insert a memory copy
+ op = ofm.ops[0]
+ ofm_clone = ofm.clone()
+ ofm_clone.values = ofm.values
+ ofm.values = None
+ np_dtype = ofm.dtype.as_numpy_type()
+ zero = create_const_tensor("zero", [1], ofm.dtype, [0], np_dtype, quantization=ofm.quantization)
+ memcpy = create_add_nop(f"{ofm.name}_copy")
+ memcpy.add_input_tensor(ofm_clone)
+ memcpy.add_input_tensor(zero)
+ memcpy.set_output_tensor(ofm)
+ memcpy.set_ifm_ofm_shapes()
+ op.set_output_tensor(ofm_clone)
+ DebugDatabase.add_optimised(op, memcpy)
+
return nng