diff options
Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r-- | ethosu/vela/graph_optimiser.py | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index ca8b89fc..8a393a2e 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -24,10 +24,10 @@ from . import rewrite_graph from .data_type import DataType from .errors import UnsupportedFeatureError from .ethos_u55_regs.ethos_u55_regs import resampling_mode +from .numeric_util import full_shape from .operation import NpuBlockType from .operation import Operation from .tensor import Tensor -from .numeric_util import full_shape passthrough_nodes = set(("Identity",)) @@ -448,17 +448,27 @@ def fixup_act_reorder(op, arch): op.type = "Identity" return op + def fixup_elementwise_with_scalars(op, arch): if op.type in binary_elementwise_op: - ifm_tensor, ifm2_tensor, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm() + ifm_tensor, ifm2_tensor, _, _ = op.get_ifm_ifm2_weights_ofm() if ifm2_tensor.shape != [] and ifm_tensor.shape != []: diff = len(ifm_tensor.shape) - len(ifm2_tensor.shape) if diff > 0: ifm2_tensor.shape = full_shape(len(ifm_tensor.shape), ifm2_tensor.shape, 1) elif diff < 0: ifm_tensor.shape = full_shape(len(ifm2_tensor.shape), ifm_tensor.shape, 1) + elif ifm_tensor.shape == [] and ifm_tensor.quant_values is None: + # IFM is marked as a scalar, but is a result of an operation; change it to a shape of size 1 + ifm_tensor.shape = len(ifm2_tensor.shape) * [1] + ifm_tensor.storage_shape = ifm_tensor.shape + elif ifm2_tensor.shape == [] and ifm2_tensor.quant_values is None: + # IFM2 is marked as a scalar, but is a result of an operation; change it to a shape of size 1 + ifm2_tensor.shape = len(ifm_tensor.shape) * [1] + ifm2_tensor.storage_shape = ifm2_tensor.shape return op + # Set input/output tensor equivalence to the same id for memory operations def set_tensor_equivalence(op, arch): if op.type == "Reshape": |