aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 518b6db0..e0c7fd2c 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -41,6 +41,7 @@ from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .numeric_util import clamp_sigmoid
+from .numeric_util import full_shape
from .numeric_util import round_away_zero
from .operation import create_activation_function
from .operation import ExplicitScaling
@@ -524,7 +525,7 @@ def convert_argmax_to_depthwise_conv_and_max_pool(op, arch, nng):
op.name = "depthwise_conv_SHL_7"
op.type = Op.DepthwiseConv2DBias
op.attrs.update(dw_op_attrs)
- n, h, w, c = ifm.shape
+ n, h, w, c = full_shape(4, ifm.shape, 1)
shape = [1, 1, 1, c]
kernel = np.dstack([2**7] * c)
op.inputs = []