aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/graph_optimiser_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/graph_optimiser_util.py')
-rw-r--r--ethosu/vela/graph_optimiser_util.py25
1 file changed, 12 insertions(+), 13 deletions(-)
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 8095f082..dafd2849 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -26,11 +26,12 @@ from .errors import VelaError
from .operation import Op
from .operation_util import create_avgpool_nop
from .shape4d import Shape4D
-from .tensor import check_quantized_tens_scaling_equal
memory_only_ops = (
Op.Reshape,
+ Op.QuantizedReshape,
Op.Squeeze,
+ Op.ExpandDims,
)
@@ -177,10 +178,11 @@ def set_ifm_ofm_op_shapes(op, arch, nng):
return op
-def bypass_reshape_and_squeeze_ops(op):
- assert op.type in (Op.Reshape, Op.Squeeze)
+def bypass_memory_only_ops(op):
+ assert op.type in memory_only_ops
ofm = op.ofm
ifm = op.ifm
+
# Check if ifm/ofm are network ifm/ofm
ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
@@ -235,13 +237,10 @@ def move_splitsliceread_to_consumer(op, cons_op):
op.ifm.consumer_list.remove(op)
-def check_reshapes(op, arch):
- if op.run_on_npu and op.type == Op.Reshape:
- ofm = op.ofm
-
- if check_quantized_tens_scaling_equal(op.ifm, ofm):
- # Reshape should have been removed
- raise VelaError(f"Reshape op {op} expected to have been removed, still remains")
+def check_memory_only_removed(op, arch):
+ if op.run_on_npu and op.type in memory_only_ops:
+ # Memory only operators should have been removed
+ raise VelaError(f"Memory only {op.type} op {op} expected to have been removed, still remains")
def record_optimised(op, arch):
@@ -271,10 +270,10 @@ def insert_copy_op_after_tens(tens):
def fix_sg_input_output(op, arch, nng):
- if not op.run_on_npu or op.type not in (Op.Reshape, Op.Squeeze):
+ if not op.run_on_npu or op.type not in memory_only_ops:
return op
- # For the Reshape/Squeeze operators we want to remove, tensors are removed.
+ # For the memory only operators we want to remove, tensors are removed.
# But in order to do this, they cannot be outputs of the sg,
# this need to be fixed prior to the removal.
# Solution is to add a avgpool NOP, to maintain the original tensor.
@@ -290,7 +289,7 @@ def fix_sg_input_output(op, arch, nng):
ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
- # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape/Squeeze
+ # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the memory only operator.
insert_copy_op_after_tens(op.ifm)
return op