aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Xu <charles.xu@arm.com>2020-05-04 11:32:05 +0200
committerTim Hall <tim.hall@arm.com>2020-06-18 17:53:52 +0100
commit53d4752c8dbd9152238abd854ec23e87e5123881 (patch)
treef52f6d781d8814bee73f18d7738ac6c49ae41022
parent22df2ade766815fadb44addf46b5d78b81787b9d (diff)
downloadethos-u-vela-53d4752c8dbd9152238abd854ec23e87e5123881.tar.gz
MLBEDSW-1649: Add size splits for Split op
The tensor is split into len(size_splits) along the dimension axis with the sizes specified in the size_splits array. Change-Id: I2ce98fa10e2e26f16cfd86a775aee94a308509ea Signed-off-by: Charles Xu <charles.xu@arm.com>
-rw-r--r--ethosu/vela/operation.py21
-rw-r--r--ethosu/vela/supported_operators.py2
2 files changed, 15 insertions, 8 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 3c776dca..e28adef6 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -194,7 +194,7 @@ input and output tensors, as well as an attribute dictionary."""
return inputs, axis
- split_ops = set(("Split", "StridedSlice", "Slice", "UnpackReshaped"))
+ split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped"))
def is_split_op(self):
return self.type in Operation.split_ops
@@ -206,12 +206,6 @@ input and output tensors, as well as an attribute dictionary."""
offset_end = None
axis = None
if self.type == "Split":
- # TODO: Extend split capabilities
- # If num_or_size_splits is an integer, then value is split along dimension axis into num_split smaller
- # tensors. This requires that num_split evenly divides value.shape[axis].
- # If num_or_size_splits is a 1-D Tensor (or list), we call it size_splits and value is split into
- # len(size_splits) elements. The shape of the i-th element has the same size as the value except along
- # dimension axis where the size is size_splits[i].
num_splits = self.attrs.get("num_splits")
axis_tens = self.inputs[0]
assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
@@ -220,6 +214,19 @@ input and output tensors, as well as an attribute dictionary."""
outputs = self.outputs
assert num_splits == len(outputs)
+ if self.type == "SplitV":
+ num_splits = self.attrs.get("num_splits")
+ input_tens = self.inputs[0]
+ size_tens = self.inputs[1]
+ assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const"
+ sizes = size_tens.values
+ axis_tens = self.inputs[2]
+ assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
+ axis = int(axis_tens.values)
+ outputs = self.outputs
+ assert num_splits == len(outputs)
+ assert sum(sizes) == input_tens.shape[axis]
+
elif self.type == "Slice":
input_tens, begin_tens, size_tens = self.inputs
outputs = self.outputs
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 70700e71..e5271450 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -73,7 +73,7 @@ class SupportedOperators:
# bias add and batch norm
| set(("QuantizedBiasAdd", "Requantize", "QuantizedBatchNorm", "BiasAdd", "FusedBatchNorm"))
)
- self.split_ops = set(("Split", "StridedSlice", "Slice", "UnpackReshaped", "Unpack"))
+ self.split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped", "Unpack"))
self.concat_ops = set(("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped", "Pack"))
self.memory_only_ops = (
set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims")) | self.concat_ops | self.split_ops