diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2020-12-01 16:02:29 +0100 |
---|---|---|
committer | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2020-12-18 16:33:32 +0100 |
commit | 2349d429d926e258e9a61d34c7fd97660ab9fb98 (patch) | |
tree | b5151d0f12428e47d64b1fb2ce4f2f8c19304a0d /ethosu/vela/operation.py | |
parent | 528a56df829b65f7a2c61953650b123c461095f7 (diff) | |
download | ethos-u-vela-2349d429d926e258e9a61d34c7fd97660ab9fb98.tar.gz |
MLBEDSW-3654 Add/use op ifm/ofm shapes
Add ifm/ofm shapes to op
Changed to rely on these shapes
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I571535a1dcadc2bdb04a3c727a8e1c49703b174d
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r-- | ethosu/vela/operation.py | 50 |
1 files changed, 45 insertions, 5 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index 30c32acc..be26a26b 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -27,6 +27,7 @@ from typing import TYPE_CHECKING from .errors import VelaError from .numeric_util import full_shape + if TYPE_CHECKING: from .tensor import Tensor @@ -129,7 +130,7 @@ class Op(Enum): Concat = OperatorInfo(indices=CONCAT_INDICES) ConcatEmbeddings = OperatorInfo() ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES) - ConcatTFLite = OperatorInfo() + ConcatTFLite = OperatorInfo(indices=CONCAT_INDICES) Const = OperatorInfo() # Constant tensor, only used in CPU subgraphs Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES) Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES) @@ -197,7 +198,7 @@ class Op(Enum): NonMaxSuppressionV5 = OperatorInfo() NotEqual = OperatorInfo() OneHot = OperatorInfo() - Pack = OperatorInfo() + Pack = OperatorInfo(indices=IFM_INDICES) PackReshaped = OperatorInfo(indices=IFM_INDICES) Pad = OperatorInfo() PadV2 = OperatorInfo() @@ -260,7 +261,7 @@ class Op(Enum): UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) Unique = OperatorInfo() - Unpack = OperatorInfo() + Unpack = OperatorInfo(indices=IFM_INDICES) UnpackReshaped = OperatorInfo(indices=IFM_INDICES) Where = OperatorInfo() While = OperatorInfo() @@ -305,14 +306,17 @@ class Op(Enum): return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT) def is_split_op(self): - return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped) + return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack) def is_concat_op(self): - return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped) + return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack) def needs_bias(self): return bool(self.info.indices.biases) + def needs_shapes(self): + return bool(self.info.indices.ifms) + @classmethod def op_set(cls, predicate): # Returns the set of all operator codes that fulfill the given predicate @@ -400,6 +404,8 @@ class Operation: "forced_output_quantization", "activation_lut", "_kernel", + "ifm_shapes", + "ofm_shapes", ) def __init__(self, op_type: Op, name: str): @@ -421,6 +427,8 @@ class Operation: self.op_index = None # input network operator index self.activation_lut = None self._kernel = None + self.ifm_shapes = [] + self.ofm_shapes = [] def clone(self, suffix="_clone"): res = Operation(self.type, self.name + suffix) @@ -697,3 +705,35 @@ class Operation: lines += _print_tensors(self.outputs) raise VelaError("\n".join(lines)) + + def set_ifm_ofm_shapes(self): + ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = self.get_ifm_ifm2_weights_ofm() + + # set all shapes to op, as 4D + if self.type == Op.FullyConnected: + n_in_elems = weight_tensor.shape[-2] + elms = ifm_tensor.elements() + batch_size = elms // n_in_elems + assert batch_size * n_in_elems == elms + + self.ifm_shapes.append([batch_size, 1, 1, n_in_elems]) + self.ofm_shapes.append(ofm_tensor.get_full_shape()) + elif self.type == Op.Softmax: + self.ifm_shapes.append(ifm_tensor.get_full_shape()) + self.ofm_shapes.append(ofm_tensor.get_full_shape()) + elif self.type.is_split_op or self.type.is_concat_op(): + for inp in self.inputs: + if inp is not None: + self.ifm_shapes.append(full_shape(4, inp.shape, 1)) + else: + self.ifm_shapes.append(None) + for out in self.outputs: + if out is not None: + self.ofm_shapes.append(full_shape(4, out.shape, 1)) + else: + self.ofm_shapes.append(None) + else: + self.ifm_shapes.append(full_shape(4, ifm_tensor.shape, 1)) + if ifm2_tensor is not None: + self.ifm_shapes.append(full_shape(4, ifm2_tensor.shape, 1)) + self.ofm_shapes.append(full_shape(4, ofm_tensor.shape, 1)) |