diff options
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r-- | ethosu/vela/operation.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index 5ed18621..db1c6f18 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -834,7 +834,16 @@ class Operation: self.ifm_shapes = [] self.ofm_shapes = [] - ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = self.get_ifm_ifm2_weights_ofm() + ifm_tensor, ifm2_tensor, ofm_tensor = self.get_ifm_ifm2_ofm() + + if self.type == Op.Reshape: + # Set ofm shape + if len(self.inputs) > 1 and self.inputs[1].values is not None: + ofm_tensor.shape = self.inputs[1].values.flatten().tolist() + ofm_elements = ofm_tensor.elements() + # Stretch dimension + if ofm_elements < 0: + ofm_tensor.shape[ofm_tensor.shape.index(-1)] = int(ifm_tensor.elements() / abs(ofm_elements)) # set all shapes to op, as 4D if self.type == Op.FullyConnected: @@ -847,7 +856,7 @@ class Operation: self.ofm_shapes.append(Shape4D([self.ofm.shape[0], 1, 1, self.ofm.shape[1]])) else: self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape())) - if self.type == Op.Softmax: + elif self.type == Op.Softmax: self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape())) self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape())) elif self.type.is_split_op() or self.type.is_concat_op(): |