From 2446e59a9083f36f85beb88fdec6379d90b85cad Mon Sep 17 00:00:00 2001 From: Patrik Gustavsson Date: Thu, 11 Feb 2021 08:36:12 +0100 Subject: MLBEDSW-3774 Fix avoid cascading for spilling Fix avoid cascading for spilling. Signed-off-by: Patrik Gustavsson Change-Id: If86189bd1566eaa14387dfc2c02e3324ea6c184e --- ethosu/vela/graph_optimiser.py | 1 + ethosu/vela/high_level_command_stream_generator.py | 1 - ethosu/vela/scheduler.py | 4 ++-- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index 50368b86..f5006c6f 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -109,6 +109,7 @@ def rewrite_concat_ops(op, arch): DebugDatabase.add_optimised(op, avgpool_op) avgpool_op.ifm_shapes.append(op.ifm_shapes[idx]) avgpool_op.ofm_shapes.append(op.ofm_shapes[0]) + avgpool_op.memory_function = Op.ConcatSliceWrite assert ofm.shape[axis] == offset # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py index c2027e0f..1ce7e7e3 100644 --- a/ethosu/vela/high_level_command_stream_generator.py +++ b/ethosu/vela/high_level_command_stream_generator.py @@ -91,7 +91,6 @@ def generate_high_level_command_stream_for_pass(strat, passes, block_configs, id ofm_start[concat_axis] = concat_start ofm_end[concat_axis] = concat_end concat_offset = concat_start - ps.primary_op.memory_function = Op.ConcatSliceWrite elif op.type.is_relu_op() or op.type in (Op.Tanh, Op.Sigmoid): ps.primary_op.activation = create_activation_function(op.type) diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py index 90b89421..9251623c 100644 --- a/ethosu/vela/scheduler.py +++ b/ethosu/vela/scheduler.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 # @@ -641,7 +641,7 @@ class DynamicProgrammingScheduler: def avoid_for_cascading(self, pred_candidate): for op in pred_candidate.ops: if ( - op.type == Op.ConcatSliceWrite + op.memory_function == Op.ConcatSliceWrite and self.arch.feature_map_storage_mem_area != self.arch.fast_storage_mem_area ): # For SRAM spilling, concat op is avoided as predecessor -- cgit v1.2.1