aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/graph_optimiser_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/graph_optimiser_util.py')
-rw-r--r-- ethosu/vela/graph_optimiser_util.py 25
1 files changed, 24 insertions, 1 deletions
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index e1341d82..82790364 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -27,6 +27,7 @@ from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .errors import VelaError
from .operation import Op
+from .operation_util import create_avgpool_nop
from .shape4d import Shape4D
from .tensor import create_const_tensor
from .tensor import QuantizationParameters
@@ -101,6 +102,10 @@ def check_format_restrictions(tens: Tensor, arch):
):
return
+ # Writing to the buffer of a variable tensor needs to be linear format
+ if tens.ops[0].memory_function == Op.VariableTensorWrite:
+ return
+
# Check if any of the producers/consumers is run on CPU
if not all(cons.run_on_npu for cons in tens.consumer_list):
return
@@ -222,7 +227,8 @@ def move_splitsliceread_to_consumer(op, cons_op):
cons_op.ifm_shapes[1] = op.ifm_shapes[0]
op.ofm.consumer_list.remove(cons_op)
op.ofm.ops = []
- op.ifm.consumer_list.remove(op)
+ if op in op.ifm.consumer_list:
+ op.ifm.consumer_list.remove(op)
def check_memory_only_removed(op, arch):
@@ -357,3 +363,20 @@ def convert_to_lut(op, lut_values, lut_name):
op.set_ifm_ofm_shapes()
DebugDatabase.add_optimised(op, op)
return op
+
+
+def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
+    """Build an average-pool op that writes one input of a concat into the concat's OFM.
+
+    The op is obtained from create_avgpool_nop(name). It takes `ifm` as its
+    only input and the OFM of `concat_op` as its only output. `write_offset`
+    and `ifm_shape` are stored as the op's write offset and write shape, and
+    the op's memory function is set to Op.ConcatSliceWrite.
+
+    Side effects: the new op is appended to the `ops` list of the concat OFM
+    and is registered with DebugDatabase as an optimised form of `concat_op`.
+
+    Returns the newly created op.
+    """
+    concat_ofm = concat_op.ofm
+    nop = create_avgpool_nop(name)
+    nop.inputs, nop.outputs = [ifm], [concat_ofm]
+
+    # Where, and how much, this op writes inside the shared concat OFM.
+    nop.write_offset = write_offset
+    nop.write_shape = ifm_shape
+    concat_ofm.ops.append(nop)
+    # The IFM shape is the slice being copied; the OFM shape is taken from the concat op.
+    nop.ifm_shapes.append(ifm_shape)
+    nop.ofm_shapes.append(concat_op.ofm_shapes[0])
+    nop.memory_function = Op.ConcatSliceWrite
+    DebugDatabase.add_optimised(concat_op, nop)
+    return nop