From 53d4752c8dbd9152238abd854ec23e87e5123881 Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Mon, 4 May 2020 11:32:05 +0200 Subject: MLBEDSW-1649: Add size splits for Split op The tensor is split into len(size_splits) along the dimension axis with the sizes specified in the size_splits array. Change-Id: I2ce98fa10e2e26f16cfd86a775aee94a308509ea Signed-off-by: Charles Xu --- ethosu/vela/operation.py | 21 ++++++++++++++------- ethosu/vela/supported_operators.py | 2 +- 2 files changed, 15 insertions(+), 8 deletions(-) (limited to 'ethosu/vela') diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index 3c776dca..e28adef6 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -194,7 +194,7 @@ input and output tensors, as well as an attribute dictionary.""" return inputs, axis - split_ops = set(("Split", "StridedSlice", "Slice", "UnpackReshaped")) + split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped")) def is_split_op(self): return self.type in Operation.split_ops @@ -206,12 +206,6 @@ input and output tensors, as well as an attribute dictionary.""" offset_end = None axis = None if self.type == "Split": - # TODO: Extend split capabilities - # If num_or_size_splits is an integer, then value is split along dimension axis into num_split smaller - # tensors. This requires that num_split evenly divides value.shape[axis]. - # If num_or_size_splits is a 1-D Tensor (or list), we call it size_splits and value is split into - # len(size_splits) elements. The shape of the i-th element has the same size as the value except along - # dimension axis where the size is size_splits[i]. num_splits = self.attrs.get("num_splits") axis_tens = self.inputs[0] assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const" @@ -220,6 +214,19 @@ input and output tensors, as well as an attribute dictionary.""" outputs = self.outputs assert num_splits == len(outputs) + if self.type == "SplitV": + num_splits = self.attrs.get("num_splits") + input_tens = self.inputs[0] + size_tens = self.inputs[1] + assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const" + sizes = size_tens.values + axis_tens = self.inputs[2] + assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const" + axis = int(axis_tens.values) + outputs = self.outputs + assert num_splits == len(outputs) + assert sum(sizes) == input_tens.shape[axis] + elif self.type == "Slice": input_tens, begin_tens, size_tens = self.inputs outputs = self.outputs diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 70700e71..e5271450 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -73,7 +73,7 @@ class SupportedOperators: # bias add and batch norm | set(("QuantizedBiasAdd", "Requantize", "QuantizedBatchNorm", "BiasAdd", "FusedBatchNorm")) ) - self.split_ops = set(("Split", "StridedSlice", "Slice", "UnpackReshaped", "Unpack")) + self.split_ops = set(("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped", "Unpack")) self.concat_ops = set(("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped", "Pack")) self.memory_only_ops = ( set(("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims")) | self.concat_ops | self.split_ops -- cgit v1.2.1