author    Patrik Gustavsson <patrik.gustavsson@arm.com>  2020-12-01 16:02:29 +0100
committer Patrik Gustavsson <patrik.gustavsson@arm.com>  2020-12-18 16:33:32 +0100
commit    2349d429d926e258e9a61d34c7fd97660ab9fb98 (patch)
tree      b5151d0f12428e47d64b1fb2ce4f2f8c19304a0d /ethosu/vela/operation.py
parent    528a56df829b65f7a2c61953650b123c461095f7 (diff)
download  ethos-u-vela-2349d429d926e258e9a61d34c7fd97660ab9fb98.tar.gz
MLBEDSW-3654 Add/use op ifm/ofm shapes
Add ifm/ofm shapes to op
Changed to rely on these shapes

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I571535a1dcadc2bdb04a3c727a8e1c49703b174d
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r--  ethosu/vela/operation.py  50
1 file changed, 45 insertions(+), 5 deletions(-)
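
The patch normalises every operator's input/output shapes to 4D up front, so that later passes can read op.ifm_shapes / op.ofm_shapes instead of re-deriving shapes from the tensors each time. The workhorse is numeric_util.full_shape; a minimal sketch of its behaviour, assuming the left-padding convention used in the diff below (missing leading dimensions filled with 1):

    def full_shape(dim, shape, fill):
        # Pad `shape` on the left with `fill` until it has `dim` dimensions,
        # e.g. full_shape(4, [8, 16], 1) -> [1, 1, 8, 16]
        return [fill] * (dim - len(shape)) + shape
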
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 30c32acc..be26a26b 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING
 
 from .errors import VelaError
 from .numeric_util import full_shape
+
 
 if TYPE_CHECKING:
     from .tensor import Tensor
@@ -129,7 +130,7 @@ class Op(Enum):
     Concat = OperatorInfo(indices=CONCAT_INDICES)
     ConcatEmbeddings = OperatorInfo()
     ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES)
-    ConcatTFLite = OperatorInfo()
+    ConcatTFLite = OperatorInfo(indices=CONCAT_INDICES)
     Const = OperatorInfo()  # Constant tensor, only used in CPU subgraphs
     Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
     Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES)
@@ -197,7 +198,7 @@ class Op(Enum):
     NonMaxSuppressionV5 = OperatorInfo()
     NotEqual = OperatorInfo()
     OneHot = OperatorInfo()
-    Pack = OperatorInfo()
+    Pack = OperatorInfo(indices=IFM_INDICES)
     PackReshaped = OperatorInfo(indices=IFM_INDICES)
     Pad = OperatorInfo()
     PadV2 = OperatorInfo()
@@ -260,7 +261,7 @@ class Op(Enum):
     UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
     UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
     Unique = OperatorInfo()
-    Unpack = OperatorInfo()
+    Unpack = OperatorInfo(indices=IFM_INDICES)
     UnpackReshaped = OperatorInfo(indices=IFM_INDICES)
     Where = OperatorInfo()
     While = OperatorInfo()
@@ -305,14 +306,17 @@ class Op(Enum):
         return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT)
 
     def is_split_op(self):
-        return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped)
+        return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack)
 
     def is_concat_op(self):
-        return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped)
+        return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack)
 
     def needs_bias(self):
         return bool(self.info.indices.biases)
 
+    def needs_shapes(self):
+        return bool(self.info.indices.ifms)
+
     @classmethod
     def op_set(cls, predicate):
         # Returns the set of all operator codes that fulfill the given predicate
@@ -400,6 +404,8 @@ class Operation:
         "forced_output_quantization",
         "activation_lut",
         "_kernel",
+        "ifm_shapes",
+        "ofm_shapes",
     )
 
     def __init__(self, op_type: Op, name: str):
@@ -421,6 +427,8 @@
         self.op_index = None  # input network operator index
         self.activation_lut = None
         self._kernel = None
+        self.ifm_shapes = []
+        self.ofm_shapes = []
 
     def clone(self, suffix="_clone"):
         res = Operation(self.type, self.name + suffix)
@@ -697,3 +705,35 @@
         lines += _print_tensors(self.outputs)
 
         raise VelaError("\n".join(lines))
+
+    def set_ifm_ofm_shapes(self):
+        ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = self.get_ifm_ifm2_weights_ofm()
+
+        # set all shapes on the op, as 4D
+        if self.type == Op.FullyConnected:
+            n_in_elems = weight_tensor.shape[-2]
+            elms = ifm_tensor.elements()
+            batch_size = elms // n_in_elems
+            assert batch_size * n_in_elems == elms
+
+            self.ifm_shapes.append([batch_size, 1, 1, n_in_elems])
+            self.ofm_shapes.append(ofm_tensor.get_full_shape())
+        elif self.type == Op.Softmax:
+            self.ifm_shapes.append(ifm_tensor.get_full_shape())
+            self.ofm_shapes.append(ofm_tensor.get_full_shape())
+        elif self.type.is_split_op() or self.type.is_concat_op():
+            for inp in self.inputs:
+                if inp is not None:
+                    self.ifm_shapes.append(full_shape(4, inp.shape, 1))
+                else:
+                    self.ifm_shapes.append(None)
+            for out in self.outputs:
+                if out is not None:
+                    self.ofm_shapes.append(full_shape(4, out.shape, 1))
+                else:
+                    self.ofm_shapes.append(None)
+        else:
+            self.ifm_shapes.append(full_shape(4, ifm_tensor.shape, 1))
+            if ifm2_tensor is not None:
+                self.ifm_shapes.append(full_shape(4, ifm2_tensor.shape, 1))
+            self.ofm_shapes.append(full_shape(4, ofm_tensor.shape, 1))
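
As a worked example of the FullyConnected branch: with an IFM of shape [2, 128] and a weight tensor of shape [128, 10], n_in_elems is 128, the IFM holds 256 elements, so batch_size is 2 and ifm_shapes records [2, 1, 1, 128]. A standalone sketch of that arithmetic (MockTensor is a hypothetical stand-in, not vela's real Tensor class):

    class MockTensor:
        # Hypothetical stand-in for ethosu.vela.tensor.Tensor:
        # just a shape and an element count.
        def __init__(self, shape):
            self.shape = shape

        def elements(self):
            n = 1
            for d in self.shape:
                n *= d
            return n

    ifm = MockTensor([2, 128])
    weights = MockTensor([128, 10])

    n_in_elems = weights.shape[-2]             # 128 inputs per output neuron
    batch_size = ifm.elements() // n_in_elems  # 256 // 128 = 2
    assert batch_size * n_in_elems == ifm.elements()
    print([batch_size, 1, 1, n_in_elems])      # [2, 1, 1, 128]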