From da2b0030220e87788573a724979626aa92afd13e Mon Sep 17 00:00:00 2001
From: Patrik Gustavsson
Date: Thu, 4 Feb 2021 16:28:29 +0100
Subject: MLBEDSW-3937 Fix check for NHCWB16 for FC

Fix the check for NHCWB16 when modifying the FC input.

Signed-off-by: Patrik Gustavsson
Change-Id: Ie50c32ca079afadd0af9b7b909820794ceee373c
---
 ethosu/vela/graph_optimiser.py | 6 ++----
 ethosu/vela/operation.py       | 8 +++++---
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 1e3b1314..5c1b90bd 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -251,7 +251,6 @@ def fixup_conv2d_backprop(op, arch, nng):
     if op.type == Op.Conv2DBackpropInput:
         # flip the inputs
         op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
-        op.set_ifm_ofm_shapes()
         op.type = Op.Conv2DBackpropInputSwitchedBias
         op.ifm.resampling_mode = resampling_mode.TRANSPOSE
 
@@ -370,10 +369,9 @@ def rewrite_fully_connected_input(op, arch, nng):
         batch_size = elms // n_in_elems
         assert batch_size * n_in_elems == elms
 
-        if op.ifm.shape != [batch_size, n_in_elems]:
-            op.ifm.avoid_NHCWB16 = True
-
         op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
+        if Shape4D(op.ifm.shape) != op.ifm_shapes[0]:
+            op.ifm.avoid_NHCWB16 = True
     return op
 
 
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 09371b7a..b297bed0 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -739,7 +739,7 @@ class Operation:
         if self.type == Op.Softmax:
             self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
             self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
-        elif self.type.is_split_op or self.type.is_concat_op():
+        elif self.type.is_split_op() or self.type.is_concat_op():
             for inp in self.inputs:
                 if inp is not None:
                     self.ifm_shapes.append(Shape4D(full_shape(4, inp.shape, 1)))
@@ -751,7 +751,9 @@ class Operation:
                 else:
                     self.ofm_shapes.append(None)
         else:
-            self.ifm_shapes.append(Shape4D(full_shape(4, ifm_tensor.shape, 1)))
+            if ifm_tensor is not None:
+                self.ifm_shapes.append(Shape4D(full_shape(4, ifm_tensor.shape, 1)))
             if ifm2_tensor is not None:
                 self.ifm_shapes.append(Shape4D(full_shape(4, ifm2_tensor.shape, 1)))
-            self.ofm_shapes.append(Shape4D(full_shape(4, ofm_tensor.shape, 1)))
+            if ofm_tensor is not None:
+                self.ofm_shapes.append(Shape4D(full_shape(4, ofm_tensor.shape, 1)))
-- 
cgit v1.2.1
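
Note on the operation.py hunk at line 739: writing `self.type.is_split_op` without parentheses does not call the method; it evaluates to a bound-method object, which is always truthy in Python, so the elif branch was entered regardless of the operator type. A minimal standalone sketch of the pitfall (the class and names are illustrative, not vela's):

    # Sketch of the truthiness pitfall fixed in operation.py (not vela code).
    class OpType:
        def is_split_op(self):
            return False

    op_type = OpType()

    if op_type.is_split_op:        # bug: bound-method object, always truthy
        print("branch taken even though is_split_op() returns False")

    if op_type.is_split_op():      # fix: call the method and use its result
        print("never reached")

Nothing fails at runtime because a method object is perfectly valid in a boolean context; the sketch simply prints the first message and skips the second, which is why this kind of bug slips through until the branch's behaviour is inspected.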