diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-01-29 11:51:31 +0100 |
---|---|---|
committer | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-01-29 16:05:03 +0100 |
commit | 2c2522dd44229a03d3d778cd239478fedc19ee57 (patch) | |
tree | 610bd611f9783f71cf79f4c2e8466789cacfd429 /ethosu/vela/operation.py | |
parent | 7bada4039d01836c995a12251034777055e1848a (diff) | |
download | ethos-u-vela-2c2522dd44229a03d3d778cd239478fedc19ee57.tar.gz |
MLBEDSW-3772 Fix FC with changed inp shape
When FC input is fixed by changing ifm_shape,
avoid_NHCWB16 must be set to ifm.
-Fixed issue with ResizeBilinear
-Changed to post order for concat ops in graph optimisation
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Ie0c6a86637c210c0833ae9b2f8e7c494c5d4f66e
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r-- | ethosu/vela/operation.py | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index 342efd9d..8d54d658 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -719,14 +719,16 @@ class Operation: # set all shapes to op, as 4D if self.type == Op.FullyConnected: - n_in_elems = weight_tensor.shape[-2] - elms = ifm_tensor.elements() - batch_size = elms // n_in_elems - assert batch_size * n_in_elems == elms - - self.ifm_shapes.append(Shape4D([batch_size, 1, 1, n_in_elems])) - self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape())) - elif self.type == Op.Softmax: + if len(self.ifm.shape) == 2: + self.ifm_shapes.append(Shape4D([self.ifm.shape[0], 1, 1, self.ifm.shape[1]])) + else: + # Special case, handled in graph optimization + self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape())) + if len(self.ofm.shape) == 2: + self.ofm_shapes.append(Shape4D([self.ofm.shape[0], 1, 1, self.ofm.shape[1]])) + else: + self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape())) + if self.type == Op.Softmax: self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape())) self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape())) elif self.type.is_split_op or self.type.is_concat_op(): |