diff options
Diffstat (limited to 'ethosu/vela/graph_optimiser_util.py')
-rw-r--r-- | ethosu/vela/graph_optimiser_util.py | 25 |
1 file changed, 12 insertions, 13 deletions
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py index 8095f082..dafd2849 100644 --- a/ethosu/vela/graph_optimiser_util.py +++ b/ethosu/vela/graph_optimiser_util.py @@ -26,11 +26,12 @@ from .errors import VelaError from .operation import Op from .operation_util import create_avgpool_nop from .shape4d import Shape4D -from .tensor import check_quantized_tens_scaling_equal memory_only_ops = ( Op.Reshape, + Op.QuantizedReshape, Op.Squeeze, + Op.ExpandDims, ) @@ -177,10 +178,11 @@ def set_ifm_ofm_op_shapes(op, arch, nng): return op -def bypass_reshape_and_squeeze_ops(op): - assert op.type in (Op.Reshape, Op.Squeeze) +def bypass_memory_only_ops(op): + assert op.type in memory_only_ops ofm = op.ofm ifm = op.ifm + # Check if ifm/ofm are network ifm/ofm ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list) @@ -235,13 +237,10 @@ def move_splitsliceread_to_consumer(op, cons_op): op.ifm.consumer_list.remove(op) -def check_reshapes(op, arch): - if op.run_on_npu and op.type == Op.Reshape: - ofm = op.ofm - - if check_quantized_tens_scaling_equal(op.ifm, ofm): - # Reshape should have been removed - raise VelaError(f"Reshape op {op} expected to have been removed, still remains") +def check_memory_only_removed(op, arch): + if op.run_on_npu and op.type in memory_only_ops: + # Memory only operators should have been removed + raise VelaError(f"Memory only {op.type} op {op} expected to have been removed, still remains") def record_optimised(op, arch): @@ -271,10 +270,10 @@ def insert_copy_op_after_tens(tens): def fix_sg_input_output(op, arch, nng): - if not op.run_on_npu or op.type not in (Op.Reshape, Op.Squeeze): + if not op.run_on_npu or op.type not in memory_only_ops: return op - # For the Reshape/Squeeze operators we want to remove, tensors are removed. + # For the memory only operators we want to remove, tensors are removed. 
# But in order to to do this, they cannot be outputs of the sg, # this need to be fixed prior to the removal. # Solution is to add a avgpool NOP, to maintain the original tensor. @@ -290,7 +289,7 @@ def fix_sg_input_output(op, arch, nng): ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list) if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed): - # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape/Squeeze + # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the memory only operator. insert_copy_op_after_tens(op.ifm) return op |