diff options
author | Charles Xu <charles.xu@arm.com> | 2020-05-04 11:32:05 +0200 |
---|---|---|
committer | Tim Hall <tim.hall@arm.com> | 2020-06-18 17:53:52 +0100 |
commit | 53d4752c8dbd9152238abd854ec23e87e5123881 (patch) | |
tree | f52f6d781d8814bee73f18d7738ac6c49ae41022 /ethosu/vela/operation.py | |
parent | 22df2ade766815fadb44addf46b5d78b81787b9d (diff) | |
download | ethos-u-vela-53d4752c8dbd9152238abd854ec23e87e5123881.tar.gz |
MLBEDSW-1649: Add size splits for Split op
The tensor is split into len(size_splits) along the dimension
axis with the sizes specified in the size_splits array.
Change-Id: I2ce98fa10e2e26f16cfd86a775aee94a308509ea
Signed-off-by: Charles Xu <charles.xu@arm.com>
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r-- | ethosu/vela/operation.py | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index 3c776dca..e28adef6 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -194,7 +194,7 @@ input and output tensors, as well as an attribute dictionary.""" return inputs, axis - split_ops = set(("Split", "StridedSlice", "Slice", "UnpackReshaped")) + split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped")) def is_split_op(self): return self.type in Operation.split_ops @@ -206,12 +206,6 @@ input and output tensors, as well as an attribute dictionary.""" offset_end = None axis = None if self.type == "Split": - # TODO: Extend split capabilities - # If num_or_size_splits is an integer, then value is split along dimension axis into num_split smaller - # tensors. This requires that num_split evenly divides value.shape[axis]. - # If num_or_size_splits is a 1-D Tensor (or list), we call it size_splits and value is split into - # len(size_splits) elements. The shape of the i-th element has the same size as the value except along - # dimension axis where the size is size_splits[i]. num_splits = self.attrs.get("num_splits") axis_tens = self.inputs[0] assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const" @@ -220,6 +214,19 @@ input and output tensors, as well as an attribute dictionary.""" outputs = self.outputs assert num_splits == len(outputs) + if self.type == "SplitV": + num_splits = self.attrs.get("num_splits") + input_tens = self.inputs[0] + size_tens = self.inputs[1] + assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const" + sizes = size_tens.values + axis_tens = self.inputs[2] + assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const" + axis = int(axis_tens.values) + outputs = self.outputs + assert num_splits == len(outputs) + assert sum(sizes) == input_tens.shape[axis] + elif self.type == "Slice": input_tens, begin_tens, size_tens = self.inputs outputs = self.outputs |