aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2021-02-04 16:28:29 +0100
committerPatrik Gustavsson <patrik.gustavsson@arm.com>2021-02-08 14:24:41 +0100
commitda2b0030220e87788573a724979626aa92afd13e (patch)
tree66b0400f90e6854129005e08232f7db94c64fd14
parent455e20e5ed0d5ce141a921e67f0219e55044e6e1 (diff)
downloadethos-u-vela-da2b0030220e87788573a724979626aa92afd13e.tar.gz
MLBEDSW-3937 Fix check for NHCWB16 for FC
Fix check for NHCWB16 for modifying FC input. Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: Ie50c32ca079afadd0af9b7b909820794ceee373c
-rw-r--r--ethosu/vela/graph_optimiser.py6
-rw-r--r--ethosu/vela/operation.py8
2 files changed, 7 insertions, 7 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 1e3b131..5c1b90b 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -251,7 +251,6 @@ def fixup_conv2d_backprop(op, arch, nng):
if op.type == Op.Conv2DBackpropInput:
# flip the inputs
op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
- op.set_ifm_ofm_shapes()
op.type = Op.Conv2DBackpropInputSwitchedBias
op.ifm.resampling_mode = resampling_mode.TRANSPOSE
@@ -370,10 +369,9 @@ def rewrite_fully_connected_input(op, arch, nng):
batch_size = elms // n_in_elems
assert batch_size * n_in_elems == elms
- if op.ifm.shape != [batch_size, n_in_elems]:
- op.ifm.avoid_NHCWB16 = True
-
op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
+ if Shape4D(op.ifm.shape) != op.ifm_shapes[0]:
+ op.ifm.avoid_NHCWB16 = True
return op
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 09371b7..b297bed 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -739,7 +739,7 @@ class Operation:
if self.type == Op.Softmax:
self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
- elif self.type.is_split_op or self.type.is_concat_op():
+ elif self.type.is_split_op() or self.type.is_concat_op():
for inp in self.inputs:
if inp is not None:
self.ifm_shapes.append(Shape4D(full_shape(4, inp.shape, 1)))
@@ -751,7 +751,9 @@ class Operation:
else:
self.ofm_shapes.append(None)
else:
- self.ifm_shapes.append(Shape4D(full_shape(4, ifm_tensor.shape, 1)))
+ if ifm_tensor is not None:
+ self.ifm_shapes.append(Shape4D(full_shape(4, ifm_tensor.shape, 1)))
if ifm2_tensor is not None:
self.ifm_shapes.append(Shape4D(full_shape(4, ifm2_tensor.shape, 1)))
- self.ofm_shapes.append(Shape4D(full_shape(4, ofm_tensor.shape, 1)))
+ if ofm_tensor is not None:
+ self.ofm_shapes.append(Shape4D(full_shape(4, ofm_tensor.shape, 1)))