aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/operation.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r--ethosu/vela/operation.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 5ed18621..db1c6f18 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -834,7 +834,16 @@ class Operation:
self.ifm_shapes = []
self.ofm_shapes = []
- ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = self.get_ifm_ifm2_weights_ofm()
+ ifm_tensor, ifm2_tensor, ofm_tensor = self.get_ifm_ifm2_ofm()
+
+ if self.type == Op.Reshape:
+ # Set ofm shape
+ if len(self.inputs) > 1 and self.inputs[1].values is not None:
+ ofm_tensor.shape = self.inputs[1].values.flatten().tolist()
+ ofm_elements = ofm_tensor.elements()
+ # Stretch dimension
+ if ofm_elements < 0:
+ ofm_tensor.shape[ofm_tensor.shape.index(-1)] = int(ifm_tensor.elements() / abs(ofm_elements))
# set all shapes to op, as 4D
if self.type == Op.FullyConnected:
@@ -847,7 +856,7 @@ class Operation:
self.ofm_shapes.append(Shape4D([self.ofm.shape[0], 1, 1, self.ofm.shape[1]]))
else:
self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
- if self.type == Op.Softmax:
+ elif self.type == Op.Softmax:
self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
elif self.type.is_split_op() or self.type.is_concat_op():