diff options
Diffstat (limited to 'ethosu/vela/graph_optimiser_util.py')
-rw-r--r-- | ethosu/vela/graph_optimiser_util.py | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py index e1341d82..82790364 100644 --- a/ethosu/vela/graph_optimiser_util.py +++ b/ethosu/vela/graph_optimiser_util.py @@ -27,6 +27,7 @@ from .debug_database import DebugDatabase from .errors import UnsupportedFeatureError from .errors import VelaError from .operation import Op +from .operation_util import create_avgpool_nop from .shape4d import Shape4D from .tensor import create_const_tensor from .tensor import QuantizationParameters @@ -101,6 +102,10 @@ def check_format_restrictions(tens: Tensor, arch): ): return + # Writing to the buffer of a variable tensor needs to be linear format + if tens.ops[0].memory_function == Op.VariableTensorWrite: + return + # Check if any of the producers/consumers is run on CPU if not all(cons.run_on_npu for cons in tens.consumer_list): return @@ -222,7 +227,8 @@ def move_splitsliceread_to_consumer(op, cons_op): cons_op.ifm_shapes[1] = op.ifm_shapes[0] op.ofm.consumer_list.remove(cons_op) op.ofm.ops = [] - op.ifm.consumer_list.remove(op) + if op in op.ifm.consumer_list: + op.ifm.consumer_list.remove(op) def check_memory_only_removed(op, arch): @@ -357,3 +363,20 @@ def convert_to_lut(op, lut_values, lut_name): op.set_ifm_ofm_shapes() DebugDatabase.add_optimised(op, op) return op + + +def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D): + """Creates an average pool for the given concat op/input feature map""" + ofm = concat_op.ofm + avgpool_op = create_avgpool_nop(name) + avgpool_op.inputs = [ifm] + avgpool_op.outputs = [ofm] + + avgpool_op.write_offset = write_offset + avgpool_op.write_shape = ifm_shape + ofm.ops.append(avgpool_op) + avgpool_op.ifm_shapes.append(ifm_shape) + avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0]) + avgpool_op.memory_function = Op.ConcatSliceWrite + DebugDatabase.add_optimised(concat_op, avgpool_op) + return avgpool_op |