aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2021-06-30 09:07:16 +0200
committerPatrik Gustavsson <patrik.gustavsson@arm.com>2021-07-09 09:51:44 +0200
commit5e26eda0e0f359b6e22b1f1eeb9344cd15e0f093 (patch)
treedce92ab9d8a6ceb261c48353ff7077295efa21da
parent8f1f9aaa58175b17cd2e505bfcdb0e40c955ea72 (diff)
downloadethos-u-vela-5e26eda0e0f359b6e22b1f1eeb9344cd15e0f093.tar.gz
MLBEDSW-4840 Move setting of input indices to tflite reader
Mapping to internal input indexing has been added to tflite_reader.py and tosa_reader.py, and the reverse mapping has been added to tflite_writer.py. Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: I4d8596e747cfa7c4203884c4e785eb1977e2bcc1
-rw-r--r--ethosu/vela/operation.py153
-rw-r--r--ethosu/vela/reader_util.py27
-rw-r--r--ethosu/vela/test/test_tflite_reader.py26
-rw-r--r--ethosu/vela/tflite_mapping.py335
-rw-r--r--ethosu/vela/tflite_reader.py8
-rw-r--r--ethosu/vela/tflite_writer.py13
-rw-r--r--ethosu/vela/tosa_graph_optimiser.py4
-rw-r--r--ethosu/vela/tosa_mapping.py2
-rw-r--r--ethosu/vela/tosa_reader.py3
-rw-r--r--ethosu/vela/tosa_supported_operators.py2
10 files changed, 364 insertions, 209 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 0558e527..ffa4717d 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -100,16 +100,16 @@ class CustomType(Enum):
TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"])
-NO_INDICES = TensorIndices([], [], [])
-IFM_INDICES = TensorIndices([0], [], [])
-IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
-IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
-IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
-CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
-TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
-CONCAT_INDICES = TensorIndices([1, 2], [], [])
-SPLIT_IFM_INDICES = TensorIndices([1], [], [])
-BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
+NNG_NO_INDICES = TensorIndices([], [], [])
+NNG_IFM_INDICES = TensorIndices([0], [], [])
+NNG_IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
+NNG_IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
+NNG_IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
+NNG_CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
+NNG_TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
+NNG_CONCAT_INDICES = TensorIndices([1, 2], [], [])
+NNG_SPLIT_IFM_INDICES = TensorIndices([1], [], [])
+NNG_BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
# Static information related to operation codes
@@ -117,7 +117,7 @@ class OperatorInfo:
__slots__ = ("id", "block_type", "indices", "is_unary")
_id = 0
- def __init__(self, block_type=NpuBlockType.Default, indices=NO_INDICES, is_unary=False):
+ def __init__(self, block_type=NpuBlockType.Default, indices=NNG_NO_INDICES, is_unary=False):
OperatorInfo._id += 1
self.id = OperatorInfo._id
self.block_type = block_type
@@ -127,37 +127,38 @@ class OperatorInfo:
# Internally used operation codes
class Op(Enum):
- Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
- Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+ Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True)
+ Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
AddN = OperatorInfo()
Any = OperatorInfo()
ArgMax = OperatorInfo()
ArgMin = OperatorInfo()
- AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+ AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
BatchMatMul = OperatorInfo()
BatchToSpaceND = OperatorInfo()
- BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
- BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
- BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=BLOCK_LSTM_INDICES)
+ BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
+ BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
+ BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_BLOCK_LSTM_INDICES)
CLZ = OperatorInfo(
- block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True
+ block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True
) # NPU specific operation
Call = OperatorInfo()
Cast = OperatorInfo()
Ceil = OperatorInfo()
+ Clamp = OperatorInfo(indices=NNG_IFM_INDICES) # TOSA specific
Clip = OperatorInfo() # NPU specific fused activation function for clipping between activation.min/max
- Concat = OperatorInfo(indices=CONCAT_INDICES)
+ Concat = OperatorInfo(indices=NNG_CONCAT_INDICES)
ConcatEmbeddings = OperatorInfo()
- ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES)
- ConcatTFLite = OperatorInfo(indices=CONCAT_INDICES)
+ ConcatSliceWrite = OperatorInfo(indices=NNG_IFM_INDICES)
+ ConcatTFLite = OperatorInfo(indices=NNG_CONCAT_INDICES)
Const = OperatorInfo() # Constant tensor, only used in CPU subgraphs
- Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
- Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES)
+ Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_INDICES)
+ Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_CONV2D_BACKPROP_INDICES)
Conv2DBackpropInputSwitchedBias = OperatorInfo(
- block_type=NpuBlockType.ConvolutionMxN, indices=TRANSPOSE_CONV_INDICES
+ block_type=NpuBlockType.ConvolutionMxN, indices=NNG_TRANSPOSE_CONV_INDICES
)
- Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_BIAS_INDICES)
+ Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_BIAS_INDICES)
Cos = OperatorInfo()
Cumsum = OperatorInfo()
Custom = OperatorInfo() # Custom 3rd party operator, only used in CPU subgraphs
@@ -165,26 +166,28 @@ class Op(Enum):
Delegate = OperatorInfo()
Densify = OperatorInfo()
DepthToSpace = OperatorInfo()
- DepthwiseConv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionDepthWise, indices=IFM_WEIGHTS_BIAS_INDICES)
- Dequantize = OperatorInfo(indices=IFM_INDICES)
+ DepthwiseConv2DBias = OperatorInfo(
+ block_type=NpuBlockType.ConvolutionDepthWise, indices=NNG_IFM_WEIGHTS_BIAS_INDICES
+ )
+ Dequantize = OperatorInfo(indices=NNG_IFM_INDICES)
Div = OperatorInfo()
Elu = OperatorInfo()
EmbeddingLookup = OperatorInfo()
EmbeddingLookupSparse = OperatorInfo()
Equal = OperatorInfo()
Exp = OperatorInfo()
- ExpandDims = OperatorInfo(indices=IFM_INDICES)
+ ExpandDims = OperatorInfo(indices=NNG_IFM_INDICES)
FakeQuantWithMinMaxArgs = OperatorInfo()
Fill = OperatorInfo()
Floor = OperatorInfo()
FloorDiv = OperatorInfo()
FloorMod = OperatorInfo()
- FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_BIAS_INDICES)
+ FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_BIAS_INDICES)
GatherNd = OperatorInfo()
GatherV2 = OperatorInfo()
Greater = OperatorInfo()
GreaterEqual = OperatorInfo()
- HardSwish = OperatorInfo(indices=IFM_INDICES)
+ HardSwish = OperatorInfo(indices=NNG_IFM_INDICES)
HashtableLookup = OperatorInfo()
Identity = OperatorInfo()
If = OperatorInfo()
@@ -192,7 +195,7 @@ class Op(Enum):
L2Pool2D = OperatorInfo()
LRN = OperatorInfo()
LSHProjection = OperatorInfo()
- LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
+ LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True)
Less = OperatorInfo()
LessEqual = OperatorInfo()
Log = OperatorInfo()
@@ -200,92 +203,92 @@ class Op(Enum):
LogicalAnd = OperatorInfo()
LogicalNot = OperatorInfo()
LogicalOr = OperatorInfo()
- Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+ Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
LUT = OperatorInfo() # NPU specific, operator has LUT, only used in fused activation functions
- MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+ MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
MatrixDiag = OperatorInfo()
MatrixSetDiag = OperatorInfo()
Max = OperatorInfo()
- MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
- Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
- Mean = OperatorInfo(indices=IFM_INDICES)
+ MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
+ Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
+ Mean = OperatorInfo(indices=NNG_IFM_INDICES)
Min = OperatorInfo()
- Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+ Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
MirrorPad = OperatorInfo()
- Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+ Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
Neg = OperatorInfo()
NonMaxSuppressionV4 = OperatorInfo()
NonMaxSuppressionV5 = OperatorInfo()
NotEqual = OperatorInfo()
OneHot = OperatorInfo()
- Pack = OperatorInfo(indices=IFM_INDICES)
- PackReshaped = OperatorInfo(indices=IFM_INDICES)
- Pad = OperatorInfo(indices=IFM_INDICES)
+ Pack = OperatorInfo(indices=NNG_IFM_INDICES)
+ PackReshaped = OperatorInfo(indices=NNG_IFM_INDICES)
+ Pad = OperatorInfo(indices=NNG_IFM_INDICES)
PadV2 = OperatorInfo()
Placeholder = OperatorInfo() # Only used in CPU subgraphs
Pow = OperatorInfo()
Prelu = OperatorInfo()
Prod = OperatorInfo()
- Quantize = OperatorInfo(indices=IFM_INDICES)
- QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
- QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
- QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
- QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
- QuantizedReshape = OperatorInfo(indices=IFM_INDICES)
+ Quantize = OperatorInfo(indices=NNG_IFM_INDICES)
+ QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
+ QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_INDICES)
+ QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
+ QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
+ QuantizedReshape = OperatorInfo(indices=NNG_IFM_INDICES)
Range = OperatorInfo()
Rank = OperatorInfo()
- ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=IFM_INDICES)
- Relu = OperatorInfo(indices=IFM_INDICES)
- Relu6 = OperatorInfo(indices=IFM_INDICES)
- ReluN1To1 = OperatorInfo(indices=IFM_INDICES)
- ReluN = OperatorInfo(indices=IFM_INDICES) # TOSA specific
- Rescale = OperatorInfo(indices=IFM_INDICES) # TOSA specific
- RescaleAdd = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
- Reshape = OperatorInfo(indices=IFM_INDICES)
- ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+ ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=NNG_IFM_INDICES)
+ Relu = OperatorInfo(indices=NNG_IFM_INDICES)
+ Relu6 = OperatorInfo(indices=NNG_IFM_INDICES)
+ ReluN1To1 = OperatorInfo(indices=NNG_IFM_INDICES)
+ ReluN = OperatorInfo(indices=NNG_IFM_INDICES) # TOSA specific
+ Rescale = OperatorInfo(indices=NNG_IFM_INDICES) # TOSA specific
+ RescaleAdd = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
+ Reshape = OperatorInfo(indices=NNG_IFM_INDICES)
+ ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
ResizeNearestNeighbor = OperatorInfo()
ReverseSequence = OperatorInfo()
ReverseV2 = OperatorInfo()
- Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+ Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
Round = OperatorInfo()
Rsqrt = OperatorInfo()
- SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) # NPU specific operation
- SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) # NPU specific operation
+ SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES) # NPU specific operation
+ SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES) # NPU specific operation
ScatterNd = OperatorInfo()
SegmentSum = OperatorInfo()
Select = OperatorInfo()
SelectV2 = OperatorInfo()
Shape = OperatorInfo()
- Sigmoid = OperatorInfo(indices=IFM_INDICES)
+ Sigmoid = OperatorInfo(indices=NNG_IFM_INDICES)
SignBit = OperatorInfo()
Sin = OperatorInfo()
SkipGram = OperatorInfo()
- Slice = OperatorInfo(indices=IFM_INDICES)
- Softmax = OperatorInfo(indices=IFM_INDICES)
+ Slice = OperatorInfo(indices=NNG_IFM_INDICES)
+ Softmax = OperatorInfo(indices=NNG_IFM_INDICES)
SpaceToBatchND = OperatorInfo()
SpaceToDepth = OperatorInfo()
SparseToDense = OperatorInfo()
- Split = OperatorInfo(indices=SPLIT_IFM_INDICES)
- SplitSliceRead = OperatorInfo(indices=IFM_INDICES)
- SplitV = OperatorInfo(indices=IFM_INDICES)
+ Split = OperatorInfo(indices=NNG_SPLIT_IFM_INDICES)
+ SplitSliceRead = OperatorInfo(indices=NNG_IFM_INDICES)
+ SplitV = OperatorInfo(indices=NNG_IFM_INDICES)
Sqrt = OperatorInfo()
Square = OperatorInfo()
SquaredDifference = OperatorInfo()
- Squeeze = OperatorInfo(indices=IFM_INDICES)
- StridedSlice = OperatorInfo(indices=IFM_INDICES)
- Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+ Squeeze = OperatorInfo(indices=NNG_IFM_INDICES)
+ StridedSlice = OperatorInfo(indices=NNG_IFM_INDICES)
+ Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
SubgraphInput = OperatorInfo() # Only used in CPU subgraphs
Sum = OperatorInfo()
Svdf = OperatorInfo()
- Tanh = OperatorInfo(indices=IFM_INDICES)
+ Tanh = OperatorInfo(indices=NNG_IFM_INDICES)
Tile = OperatorInfo()
TopKV2 = OperatorInfo()
Transpose = OperatorInfo()
- UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
- UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+ UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
+ UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
Unique = OperatorInfo()
- Unpack = OperatorInfo(indices=IFM_INDICES)
- UnpackReshaped = OperatorInfo(indices=IFM_INDICES)
+ Unpack = OperatorInfo(indices=NNG_IFM_INDICES)
+ UnpackReshaped = OperatorInfo(indices=NNG_IFM_INDICES)
Where = OperatorInfo()
While = OperatorInfo()
ZerosLike = OperatorInfo()
@@ -323,7 +326,7 @@ class Op(Enum):
return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary
def is_relu_op(self):
- return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.ReluN, Op.Clip)
+ return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.ReluN, Op.Clip, Op.Clamp)
def is_activation_op(self):
return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT, Op.HardSwish)
@@ -408,7 +411,7 @@ def create_activation_function(op_type: Op, min=None, max=None) -> ActivationFun
act.max = 1.0
elif op_type == Op.HardSwish:
act.min = 0.0
- if op_type == Op.Clip:
+ if op_type == Op.Clamp:
assert min is not None and max is not None
act.min = min
act.max = max
diff --git a/ethosu/vela/reader_util.py b/ethosu/vela/reader_util.py
index 5b454b57..233286c8 100644
--- a/ethosu/vela/reader_util.py
+++ b/ethosu/vela/reader_util.py
@@ -58,3 +58,30 @@ def fixup_tensors(input_tensors, tensors):
if not tens.ops:
op = Operation(Op.Const, tens.name)
op.set_output_tensor(tens)
+
+
+def align_inputs_indices(from_indices, to_indices, inputs):
+    """Reorder `inputs` in place so that its layout follows `to_indices` instead of `from_indices`.
+
+    Both index sets are flattened in the order ifms, weights, biases. The flattened
+    lists are walked position by position; each mismatch is resolved by swapping in
+    the later position that holds the wanted index, and the same swap is applied to
+    `inputs`.
+
+    Arguments:
+        from_indices: TensorIndices describing the current layout
+        to_indices: TensorIndices describing the wanted layout; must hold the same
+            number of indices as `from_indices`
+        inputs: list of input tensors, modified in place
+
+    Returns:
+        The same `inputs` list object.
+
+    Raises:
+        AssertionError: if the two layouts differ in size, if an index beyond the end
+            of `inputs` is not a bias index in both layouts, or if the layouts could
+            not be aligned.
+    """
+    to_list = to_indices.ifms + to_indices.weights + to_indices.biases
+    from_list = from_indices.ifms + from_indices.weights + from_indices.biases
+
+    assert len(to_list) == len(from_list)
+    if to_list != from_list:
+        for idx, t_idx in enumerate(to_list):
+            if t_idx >= len(inputs):
+                # Biases are allowed to be left out
+                assert t_idx in from_indices.biases and t_idx in to_indices.biases
+                continue
+            if from_list[idx] != t_idx:
+                # Find the position that holds t_idx and swap it into place. The search
+                # is done by position: the values stored in from_list are tensor
+                # indices and must not be used to subscript from_list itself.
+                # NOTE(review): the swap subscripts `inputs` with flattened-list
+                # positions, not with the stored tensor indices - confirm this is the
+                # intended behaviour for layouts that are not plain pairwise swaps.
+                for jdx in range(idx + 1, len(from_list)):
+                    if from_list[jdx] == t_idx:
+                        inputs[idx], inputs[jdx] = inputs[jdx], inputs[idx]
+                        from_list[idx], from_list[jdx] = from_list[jdx], from_list[idx]
+                        break
+    assert from_list == to_list
+    return inputs
+
+
+def align_tensor_indices_to_nng(op_type, indices, inputs):
+    """Align `inputs` from the given `indices` layout to the layout registered on the matching Op."""
+    nng_indices = Op(op_type).info.indices
+    return align_inputs_indices(indices, nng_indices, inputs)
diff --git a/ethosu/vela/test/test_tflite_reader.py b/ethosu/vela/test/test_tflite_reader.py
index a69e8d37..664a58c6 100644
--- a/ethosu/vela/test/test_tflite_reader.py
+++ b/ethosu/vela/test/test_tflite_reader.py
@@ -23,6 +23,8 @@ import pytest
from ethosu.vela.operation import Op
from ethosu.vela.tflite.TensorType import TensorType
+from ethosu.vela.tflite_mapping import TFLITE_CONV2D_BACKPROP_INDICES
+from ethosu.vela.tflite_mapping import TFLITE_IFM_WEIGHTS_BIAS_INDICES
from ethosu.vela.tflite_reader import TFLiteSubgraph
@@ -43,23 +45,25 @@ class TestTFLiteSubgraph:
assert output == expected
parse_op_testdata = [
- # op_type, opt_serializer, inputs, output, expected
- (Op.FullyConnected, None, [0, 1, 2], 3, 3), # FC
- (Op.FullyConnected, None, [0, 1, -1], 3, 3), # FC disabled Bias
- (Op.FullyConnected, None, [0, 1], 3, 3), # FC no Bias
- (Op.Conv2D, None, [2, 1, 3], 0, 3), # Conv2D
- (Op.Conv2DBackpropInput, None, [0, 1, 2, 3], 4, 4), # TransposeConv
- (Op.Conv2DBackpropInput, None, [0, 1, 2], 4, 4), # TransposeConv no Bias
- pytest.param(Op.Conv2D, None, [0, -1, 1], 3, 3, marks=pytest.mark.xfail), # Conv2D no Weights
+ # op_type, opt_serializer, indices, inputs, output, expected
+ (Op.FullyConnected, None, TFLITE_IFM_WEIGHTS_BIAS_INDICES, [0, 1, 2], 3, 3), # FC
+ (Op.FullyConnected, None, TFLITE_IFM_WEIGHTS_BIAS_INDICES, [0, 1, -1], 3, 3), # FC disabled Bias
+ (Op.FullyConnected, None, TFLITE_IFM_WEIGHTS_BIAS_INDICES, [0, 1], 3, 3), # FC no Bias
+ (Op.Conv2DBias, None, TFLITE_IFM_WEIGHTS_BIAS_INDICES, [2, 1, 3], 0, 3), # Conv2D
+ (Op.Conv2DBackpropInput, None, TFLITE_CONV2D_BACKPROP_INDICES, [0, 1, 2, 3], 4, 4), # TransposeConv
+ (Op.Conv2DBackpropInput, None, TFLITE_CONV2D_BACKPROP_INDICES, [0, 1, 2], 4, 4), # TransposeConv no Bias
+ pytest.param(
+ Op.Conv2DBias, None, TFLITE_IFM_WEIGHTS_BIAS_INDICES, [0, -1, 1], 3, 3, marks=pytest.mark.xfail
+ ), # Conv2D no Weights
]
- @pytest.mark.parametrize("op_type, opt_serializer, inputs, output, expected", parse_op_testdata)
- def test_parse_operator(self, op_type, opt_serializer, inputs, output, expected):
+ @pytest.mark.parametrize("op_type, opt_serializer, indices, inputs, output, expected", parse_op_testdata)
+ def test_parse_operator(self, op_type, opt_serializer, indices, inputs, output, expected):
with patch.object(TFLiteSubgraph, "__init__", lambda self, graph, subraph: None):
# Mock a TFLiteSubGraph
sg = TFLiteSubgraph(None, None)
sg.graph = MagicMock()
- sg.graph.operator_codes = [(op_type, opt_serializer, "")]
+ sg.graph.operator_codes = [(op_type, opt_serializer, "", indices)]
# Mock a couple of tensors
sg.tensors = [MagicMock() for _ in range(5)]
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index b526ec58..23a1a2b7 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -25,6 +25,7 @@ from .data_type import DataType
from .operation import CustomType
from .operation import Op
from .operation import Padding as opPad
+from .operation import TensorIndices
from .tflite import AbsOptions
from .tflite import AddNOptions
from .tflite import AddOptions
@@ -489,50 +490,89 @@ reducer_opts = OptionsSerializer("ReducerOptions", ("keep_dims",))
is_int_vec = True
+TFLITE_NO_INDICES = TensorIndices([], [], [])
+TFLITE_IFM_INDICES = TensorIndices([0], [], [])
+TFLITE_IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
+TFLITE_IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
+TFLITE_IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
+TFLITE_CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
+TFLITE_TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
+TFLITE_CONCAT_INDICES = TensorIndices([1, 2], [], [])
+TFLITE_SPLIT_IFM_INDICES = TensorIndices([1], [], [])
+TFLITE_BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
+
builtin_operator_map = {
- BuiltinOperator.ADD: (Op.Add, OptionsSerializer("AddOptions", (fused_act, "pot_scale_int16"))),
- BuiltinOperator.AVERAGE_POOL_2D: (Op.AvgPool, pool2d_opts),
- BuiltinOperator.CONCATENATION: (Op.ConcatTFLite, OptionsSerializer("ConcatenationOptions", ("axis", fused_act))),
- BuiltinOperator.CONV_2D: (Op.Conv2DBias, conv2d_opts),
- BuiltinOperator.DEPTHWISE_CONV_2D: (Op.DepthwiseConv2DBias, depthwise_opts),
- BuiltinOperator.DEPTH_TO_SPACE: (Op.DepthToSpace, OptionsSerializer("DepthToSpaceOptions", ("block_size",))),
- BuiltinOperator.DEQUANTIZE: (Op.Dequantize, OptionsSerializer("DequantizeOptions")),
- BuiltinOperator.EMBEDDING_LOOKUP: (Op.EmbeddingLookup, None),
- BuiltinOperator.FLOOR: (Op.Floor, None),
+ BuiltinOperator.ADD: (
+ Op.Add,
+ OptionsSerializer("AddOptions", (fused_act, "pot_scale_int16")),
+ TFLITE_IFM_IFM2_INDICES,
+ ),
+ BuiltinOperator.AVERAGE_POOL_2D: (Op.AvgPool, pool2d_opts, TFLITE_IFM_INDICES),
+ BuiltinOperator.CONCATENATION: (
+ Op.ConcatTFLite,
+ OptionsSerializer("ConcatenationOptions", ("axis", fused_act)),
+ TFLITE_CONCAT_INDICES,
+ ),
+ BuiltinOperator.CONV_2D: (Op.Conv2DBias, conv2d_opts, TFLITE_IFM_WEIGHTS_BIAS_INDICES),
+ BuiltinOperator.DEPTHWISE_CONV_2D: (Op.DepthwiseConv2DBias, depthwise_opts, TFLITE_IFM_WEIGHTS_BIAS_INDICES),
+ BuiltinOperator.DEPTH_TO_SPACE: (
+ Op.DepthToSpace,
+ OptionsSerializer("DepthToSpaceOptions", ("block_size",)),
+ TFLITE_NO_INDICES,
+ ),
+ BuiltinOperator.DEQUANTIZE: (Op.Dequantize, OptionsSerializer("DequantizeOptions"), TFLITE_IFM_INDICES),
+ BuiltinOperator.EMBEDDING_LOOKUP: (Op.EmbeddingLookup, None, TFLITE_NO_INDICES),
+ BuiltinOperator.FLOOR: (Op.Floor, None, TFLITE_NO_INDICES),
BuiltinOperator.FULLY_CONNECTED: (
Op.FullyConnected,
OptionsSerializer(
"FullyConnectedOptions", (fused_act, "weights_format", "asymmetric_quantize_inputs", "keep_num_dims")
),
+ TFLITE_IFM_WEIGHTS_BIAS_INDICES,
),
- BuiltinOperator.HASHTABLE_LOOKUP: (Op.HashtableLookup, None),
- BuiltinOperator.L2_NORMALIZATION: (Op.L2Norm, OptionsSerializer("L2NormOptions", (fused_act,))),
- BuiltinOperator.L2_POOL_2D: (Op.L2Pool2D, pool2d_opts),
+ BuiltinOperator.HASHTABLE_LOOKUP: (Op.HashtableLookup, None, TFLITE_NO_INDICES),
+ BuiltinOperator.L2_NORMALIZATION: (Op.L2Norm, OptionsSerializer("L2NormOptions", (fused_act,)), TFLITE_NO_INDICES),
+ BuiltinOperator.L2_POOL_2D: (Op.L2Pool2D, pool2d_opts, TFLITE_NO_INDICES),
BuiltinOperator.LOCAL_RESPONSE_NORMALIZATION: (
Op.LRN,
OptionsSerializer("LocalResponseNormalizationOptions", ("radius", "bias", "alpha", "beta")),
+ TFLITE_NO_INDICES,
+ ),
+ BuiltinOperator.LOGISTIC: (Op.Sigmoid, None, TFLITE_IFM_INDICES),
+ BuiltinOperator.LSH_PROJECTION: (
+ Op.LSHProjection,
+ OptionsSerializer("LSHProjectionOptions", ("type",)),
+ TFLITE_NO_INDICES,
+ ),
+ BuiltinOperator.LSTM: (Op.Lstm, lstm_opts, TFLITE_IFM_WEIGHTS_INDICES),
+ BuiltinOperator.MAX_POOL_2D: (Op.MaxPool, pool2d_opts, TFLITE_IFM_INDICES),
+ BuiltinOperator.MUL: (Op.Mul, OptionsSerializer("MulOptions", (fused_act,)), TFLITE_IFM_IFM2_INDICES),
+ BuiltinOperator.RELU: (Op.Relu, None, TFLITE_IFM_INDICES),
+ BuiltinOperator.RELU_N1_TO_1: (Op.ReluN1To1, None, TFLITE_IFM_INDICES),
+ BuiltinOperator.RELU6: (Op.Relu6, None, TFLITE_IFM_INDICES),
+ BuiltinOperator.RESHAPE: (
+ Op.Reshape,
+ OptionsSerializer("ReshapeOptions", (("new_shape", is_int_vec),)),
+ TFLITE_IFM_INDICES,
),
- BuiltinOperator.LOGISTIC: (Op.Sigmoid, None),
- BuiltinOperator.LSH_PROJECTION: (Op.LSHProjection, OptionsSerializer("LSHProjectionOptions", ("type",))),
- BuiltinOperator.LSTM: (Op.Lstm, lstm_opts),
- BuiltinOperator.MAX_POOL_2D: (Op.MaxPool, pool2d_opts),
- BuiltinOperator.MUL: (Op.Mul, OptionsSerializer("MulOptions", (fused_act,))),
- BuiltinOperator.RELU: (Op.Relu, None),
- BuiltinOperator.RELU_N1_TO_1: (Op.ReluN1To1, None),
- BuiltinOperator.RELU6: (Op.Relu6, None),
- BuiltinOperator.RESHAPE: (Op.Reshape, OptionsSerializer("ReshapeOptions", (("new_shape", is_int_vec),))),
BuiltinOperator.RESIZE_BILINEAR: (
Op.ResizeBilinear,
OptionsSerializer("ResizeBilinearOptions", ("align_corners", "half_pixel_centers")),
+ TFLITE_IFM_INDICES,
+ ),
+ BuiltinOperator.RNN: (Op.Rnn, rnn_opts, TFLITE_IFM_WEIGHTS_INDICES),
+ BuiltinOperator.SOFTMAX: (Op.Softmax, OptionsSerializer("SoftmaxOptions", ("beta",)), TFLITE_IFM_INDICES),
+ BuiltinOperator.SPACE_TO_DEPTH: (
+ Op.SpaceToDepth,
+ OptionsSerializer("SpaceToDepthOptions", ("block_size",)),
+ TFLITE_NO_INDICES,
),
- BuiltinOperator.RNN: (Op.Rnn, rnn_opts),
- BuiltinOperator.SOFTMAX: (Op.Softmax, OptionsSerializer("SoftmaxOptions", ("beta",))),
- BuiltinOperator.SPACE_TO_DEPTH: (Op.SpaceToDepth, OptionsSerializer("SpaceToDepthOptions", ("block_size",))),
BuiltinOperator.SVDF: (
Op.Svdf,
OptionsSerializer("SVDFOptions", ("rank", fused_act, "asymmetric_quantize_inputs")),
+ TFLITE_NO_INDICES,
),
- BuiltinOperator.TANH: (Op.Tanh, None),
+ BuiltinOperator.TANH: (Op.Tanh, None, TFLITE_IFM_INDICES),
BuiltinOperator.CONCAT_EMBEDDINGS: (
Op.ConcatEmbeddings,
OptionsSerializer(
@@ -547,40 +587,76 @@ builtin_operator_map = {
"embedding_dim_per_channel_as_length",
),
),
+ TFLITE_NO_INDICES,
),
BuiltinOperator.SKIP_GRAM: (
Op.SkipGram,
OptionsSerializer("SkipGramOptions", ("ngram_size", "max_skip_size", "include_all_ngrams")),
+ TFLITE_NO_INDICES,
),
- BuiltinOperator.CALL: (Op.Call, OptionsSerializer("CallOptions", ("subgraph",))),
+ BuiltinOperator.CALL: (Op.Call, OptionsSerializer("CallOptions", ("subgraph",)), TFLITE_NO_INDICES),
BuiltinOperator.EMBEDDING_LOOKUP_SPARSE: (
Op.EmbeddingLookupSparse,
OptionsSerializer("EmbeddingLookupSparseOptions", ("combiner",)),
+ TFLITE_NO_INDICES,
+ ),
+ BuiltinOperator.PAD: (Op.Pad, OptionsSerializer("PadOptions"), TFLITE_IFM_INDICES),
+ BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_RNN: (
+ Op.UnidirectionalSequenceRnn,
+ seq_rnn_opts,
+ TFLITE_IFM_WEIGHTS_INDICES,
+ ),
+ BuiltinOperator.GATHER: (Op.GatherV2, OptionsSerializer("GatherOptions", ("axis",)), TFLITE_NO_INDICES),
+ BuiltinOperator.BATCH_TO_SPACE_ND: (
+ Op.BatchToSpaceND,
+ OptionsSerializer("BatchToSpaceNDOptions"),
+ TFLITE_NO_INDICES,
+ ),
+ BuiltinOperator.SPACE_TO_BATCH_ND: (
+ Op.SpaceToBatchND,
+ OptionsSerializer("SpaceToBatchNDOptions"),
+ TFLITE_NO_INDICES,
+ ),
+ BuiltinOperator.TRANSPOSE: (Op.Transpose, OptionsSerializer("TransposeOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.MEAN: (Op.Mean, reducer_opts, TFLITE_IFM_INDICES),
+ BuiltinOperator.SUB: (
+ Op.Sub,
+ OptionsSerializer("SubOptions", (fused_act, "pot_scale_int16",)),
+ TFLITE_IFM_IFM2_INDICES,
+ ),
+ BuiltinOperator.DIV: (Op.Div, OptionsSerializer("DivOptions", (fused_act,)), TFLITE_NO_INDICES),
+ BuiltinOperator.SQUEEZE: (
+ Op.Squeeze,
+ OptionsSerializer("SqueezeOptions", (("squeeze_dims", is_int_vec),)),
+ TFLITE_IFM_INDICES,
+ ),
+ BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_LSTM: (
+ Op.UnidirectionalSequenceLstm,
+ unidir_seq_lstm_opts,
+ TFLITE_IFM_WEIGHTS_INDICES,
),
- BuiltinOperator.PAD: (Op.Pad, OptionsSerializer("PadOptions")),
- BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_RNN: (Op.UnidirectionalSequenceRnn, seq_rnn_opts),
- BuiltinOperator.GATHER: (Op.GatherV2, OptionsSerializer("GatherOptions", ("axis",))),
- BuiltinOperator.BATCH_TO_SPACE_ND: (Op.BatchToSpaceND, OptionsSerializer("BatchToSpaceNDOptions")),
- BuiltinOperator.SPACE_TO_BATCH_ND: (Op.SpaceToBatchND, OptionsSerializer("SpaceToBatchNDOptions")),
- BuiltinOperator.TRANSPOSE: (Op.Transpose, OptionsSerializer("TransposeOptions")),
- BuiltinOperator.MEAN: (Op.Mean, reducer_opts),
- BuiltinOperator.SUB: (Op.Sub, OptionsSerializer("SubOptions", (fused_act, "pot_scale_int16",))),
- BuiltinOperator.DIV: (Op.Div, OptionsSerializer("DivOptions", (fused_act,))),
- BuiltinOperator.SQUEEZE: (Op.Squeeze, OptionsSerializer("SqueezeOptions", (("squeeze_dims", is_int_vec),))),
- BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_LSTM: (Op.UnidirectionalSequenceLstm, unidir_seq_lstm_opts),
BuiltinOperator.STRIDED_SLICE: (
Op.StridedSlice,
OptionsSerializer(
- "StridedSliceOptions", ("begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask", "shrink_axis_mask")
+ "StridedSliceOptions", ("begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask", "shrink_axis_mask"),
),
+ TFLITE_IFM_INDICES,
+ ),
+ BuiltinOperator.BIDIRECTIONAL_SEQUENCE_RNN: (
+ Op.BidirectionalSequenceRnn,
+ bidir_seq_rnn_opts,
+ TFLITE_IFM_WEIGHTS_INDICES,
+ ),
+ BuiltinOperator.EXP: (Op.Exp, OptionsSerializer("ExpOptions"), TFLITE_IFM_INDICES),
+ BuiltinOperator.TOPK_V2: (Op.TopKV2, OptionsSerializer("TopKV2Options"), TFLITE_NO_INDICES),
+ BuiltinOperator.SPLIT: (Op.Split, OptionsSerializer("SplitOptions", ("num_splits",)), TFLITE_SPLIT_IFM_INDICES),
+ BuiltinOperator.LOG_SOFTMAX: (Op.LogSoftmax, OptionsSerializer("LogSoftmaxOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.DELEGATE: (Op.Delegate, None, TFLITE_NO_INDICES),
+ BuiltinOperator.BIDIRECTIONAL_SEQUENCE_LSTM: (
+ Op.BidirectionalSequenceLstm,
+ bidir_seq_lstm_opts,
+ TFLITE_IFM_WEIGHTS_INDICES,
),
- BuiltinOperator.BIDIRECTIONAL_SEQUENCE_RNN: (Op.BidirectionalSequenceRnn, bidir_seq_rnn_opts),
- BuiltinOperator.EXP: (Op.Exp, OptionsSerializer("ExpOptions")),
- BuiltinOperator.TOPK_V2: (Op.TopKV2, OptionsSerializer("TopKV2Options")),
- BuiltinOperator.SPLIT: (Op.Split, OptionsSerializer("SplitOptions", ("num_splits",))),
- BuiltinOperator.LOG_SOFTMAX: (Op.LogSoftmax, OptionsSerializer("LogSoftmaxOptions")),
- BuiltinOperator.DELEGATE: (Op.Delegate, None),
- BuiltinOperator.BIDIRECTIONAL_SEQUENCE_LSTM: (Op.BidirectionalSequenceLstm, bidir_seq_lstm_opts),
BuiltinOperator.CAST: (
Op.Cast,
OptionsSerializer(
@@ -590,117 +666,152 @@ builtin_operator_map = {
("out_data_type", datatype_deserialize, datatype_serialize),
),
),
+ TFLITE_NO_INDICES,
),
- BuiltinOperator.PRELU: (Op.Prelu, None),
- BuiltinOperator.MAXIMUM: (Op.Maximum, OptionsSerializer("MaximumMinimumOptions")),
+ BuiltinOperator.PRELU: (Op.Prelu, None, TFLITE_NO_INDICES),
+ BuiltinOperator.MAXIMUM: (Op.Maximum, OptionsSerializer("MaximumMinimumOptions"), TFLITE_IFM_IFM2_INDICES),
BuiltinOperator.ARG_MAX: (
Op.ArgMax,
OptionsSerializer("ArgMaxOptions", (("output_type", datatype_deserialize, datatype_serialize),)),
+ TFLITE_NO_INDICES,
),
- BuiltinOperator.MINIMUM: (Op.Minimum, OptionsSerializer("MaximumMinimumOptions")),
- BuiltinOperator.LESS: (Op.Less, OptionsSerializer("LessOptions")),
- BuiltinOperator.NEG: (Op.Neg, OptionsSerializer("NegOptions")),
- BuiltinOperator.PADV2: (Op.PadV2, OptionsSerializer("PadV2Options")),
- BuiltinOperator.GREATER: (Op.Greater, OptionsSerializer("GreaterOptions")),
- BuiltinOperator.GREATER_EQUAL: (Op.GreaterEqual, OptionsSerializer("GreaterEqualOptions")),
- BuiltinOperator.LESS_EQUAL: (Op.LessEqual, OptionsSerializer("LessEqualOptions")),
- BuiltinOperator.SELECT: (Op.Select, OptionsSerializer("SelectOptions")),
- BuiltinOperator.SLICE: (Op.Slice, OptionsSerializer("SliceOptions")),
- BuiltinOperator.SIN: (Op.Sin, None),
+ BuiltinOperator.MINIMUM: (Op.Minimum, OptionsSerializer("MaximumMinimumOptions"), TFLITE_IFM_IFM2_INDICES),
+ BuiltinOperator.LESS: (Op.Less, OptionsSerializer("LessOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.NEG: (Op.Neg, OptionsSerializer("NegOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.PADV2: (Op.PadV2, OptionsSerializer("PadV2Options"), TFLITE_NO_INDICES),
+ BuiltinOperator.GREATER: (Op.Greater, OptionsSerializer("GreaterOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.GREATER_EQUAL: (Op.GreaterEqual, OptionsSerializer("GreaterEqualOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.LESS_EQUAL: (Op.LessEqual, OptionsSerializer("LessEqualOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.SELECT: (Op.Select, OptionsSerializer("SelectOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.SLICE: (Op.Slice, OptionsSerializer("SliceOptions"), TFLITE_IFM_INDICES),
+ BuiltinOperator.SIN: (Op.Sin, None, TFLITE_NO_INDICES),
BuiltinOperator.TRANSPOSE_CONV: (
Op.Conv2DBackpropInput,
OptionsSerializer("TransposeConvOptions", (padding, "stride_w", "stride_h")),
+ TFLITE_CONV2D_BACKPROP_INDICES,
),
BuiltinOperator.SPARSE_TO_DENSE: (
Op.SparseToDense,
OptionsSerializer("SparseToDenseOptions", ("validate_indices",)),
+ TFLITE_NO_INDICES,
),
- BuiltinOperator.TILE: (Op.Tile, OptionsSerializer("TileOptions")),
- BuiltinOperator.EXPAND_DIMS: (Op.ExpandDims, OptionsSerializer("ExpandDimsOptions")),
- BuiltinOperator.EQUAL: (Op.Equal, OptionsSerializer("EqualOptions")),
- BuiltinOperator.NOT_EQUAL: (Op.NotEqual, OptionsSerializer("NotEqualOptions")),
- BuiltinOperator.LOG: (Op.Log, None),
- BuiltinOperator.SUM: (Op.Sum, reducer_opts),
- BuiltinOperator.SQRT: (Op.Sqrt, None),
- BuiltinOperator.RSQRT: (Op.Rsqrt, None),
+ BuiltinOperator.TILE: (Op.Tile, OptionsSerializer("TileOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.EXPAND_DIMS: (Op.ExpandDims, OptionsSerializer("ExpandDimsOptions"), TFLITE_IFM_INDICES),
+ BuiltinOperator.EQUAL: (Op.Equal, OptionsSerializer("EqualOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.NOT_EQUAL: (Op.NotEqual, OptionsSerializer("NotEqualOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.LOG: (Op.Log, None, TFLITE_NO_INDICES),
+ BuiltinOperator.SUM: (Op.Sum, reducer_opts, TFLITE_NO_INDICES),
+ BuiltinOperator.SQRT: (Op.Sqrt, None, TFLITE_NO_INDICES),
+ BuiltinOperator.RSQRT: (Op.Rsqrt, None, TFLITE_NO_INDICES),
BuiltinOperator.SHAPE: (
Op.Shape,
OptionsSerializer("ShapeOptions", (("out_type", datatype_deserialize, datatype_serialize),)),
+ TFLITE_NO_INDICES,
),
- BuiltinOperator.POW: (Op.Pow, OptionsSerializer("PowOptions")),
+ BuiltinOperator.POW: (Op.Pow, OptionsSerializer("PowOptions"), TFLITE_NO_INDICES),
BuiltinOperator.ARG_MIN: (
Op.ArgMin,
OptionsSerializer("ArgMinOptions", (("output_type", datatype_deserialize, datatype_serialize),)),
+ TFLITE_NO_INDICES,
),
BuiltinOperator.FAKE_QUANT: (
Op.FakeQuantWithMinMaxArgs,
OptionsSerializer("FakeQuantOptions", ("min", "max", "num_bits", "narrow_range")),
+ TFLITE_NO_INDICES,
),
- BuiltinOperator.REDUCE_PROD: (Op.Prod, reducer_opts),
- BuiltinOperator.REDUCE_MAX: (Op.Max, reducer_opts),
- BuiltinOperator.PACK: (Op.Pack, OptionsSerializer("PackOptions", ("values_count", "axis"))),
- BuiltinOperator.LOGICAL_OR: (Op.LogicalOr, OptionsSerializer("LogicalOrOptions")),
- BuiltinOperator.ONE_HOT: (Op.OneHot, OptionsSerializer("OneHotOptions", ("axis",))),
- BuiltinOperator.LOGICAL_AND: (Op.LogicalAnd, OptionsSerializer("LogicalAndOptions")),
- BuiltinOperator.LOGICAL_NOT: (Op.LogicalNot, OptionsSerializer("LogicalNotOptions")),
- BuiltinOperator.UNPACK: (Op.Unpack, OptionsSerializer("UnpackOptions", ("num", "axis"))),
- BuiltinOperator.REDUCE_MIN: (Op.Min, reducer_opts),
- BuiltinOperator.FLOOR_DIV: (Op.FloorDiv, OptionsSerializer("FloorDivOptions")),
- BuiltinOperator.REDUCE_ANY: (Op.Any, reducer_opts),
- BuiltinOperator.SQUARE: (Op.Square, OptionsSerializer("SquareOptions")),
- BuiltinOperator.ZEROS_LIKE: (Op.ZerosLike, OptionsSerializer("ZerosLikeOptions")),
- BuiltinOperator.FILL: (Op.Fill, OptionsSerializer("FillOptions")),
- BuiltinOperator.FLOOR_MOD: (Op.FloorMod, OptionsSerializer("FloorModOptions")),
- BuiltinOperator.RANGE: (Op.Range, OptionsSerializer("RangeOptions")),
+ BuiltinOperator.REDUCE_PROD: (Op.Prod, reducer_opts, TFLITE_NO_INDICES),
+ BuiltinOperator.REDUCE_MAX: (Op.Max, reducer_opts, TFLITE_NO_INDICES),
+ BuiltinOperator.PACK: (Op.Pack, OptionsSerializer("PackOptions", ("values_count", "axis")), TFLITE_IFM_INDICES),
+ BuiltinOperator.LOGICAL_OR: (Op.LogicalOr, OptionsSerializer("LogicalOrOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.ONE_HOT: (Op.OneHot, OptionsSerializer("OneHotOptions", ("axis",)), TFLITE_NO_INDICES),
+ BuiltinOperator.LOGICAL_AND: (Op.LogicalAnd, OptionsSerializer("LogicalAndOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.LOGICAL_NOT: (Op.LogicalNot, OptionsSerializer("LogicalNotOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.UNPACK: (Op.Unpack, OptionsSerializer("UnpackOptions", ("num", "axis")), TFLITE_IFM_INDICES),
+ BuiltinOperator.REDUCE_MIN: (Op.Min, reducer_opts, TFLITE_NO_INDICES),
+ BuiltinOperator.FLOOR_DIV: (Op.FloorDiv, OptionsSerializer("FloorDivOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.REDUCE_ANY: (Op.Any, reducer_opts, TFLITE_NO_INDICES),
+ BuiltinOperator.SQUARE: (Op.Square, OptionsSerializer("SquareOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.ZEROS_LIKE: (Op.ZerosLike, OptionsSerializer("ZerosLikeOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.FILL: (Op.Fill, OptionsSerializer("FillOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.FLOOR_MOD: (Op.FloorMod, OptionsSerializer("FloorModOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.RANGE: (Op.Range, OptionsSerializer("RangeOptions"), TFLITE_NO_INDICES),
BuiltinOperator.RESIZE_NEAREST_NEIGHBOR: (
Op.ResizeNearestNeighbor,
OptionsSerializer("ResizeNearestNeighborOptions", ("align_corners", "half_pixel_centers")),
+ TFLITE_NO_INDICES,
+ ),
+ BuiltinOperator.LEAKY_RELU: (Op.LeakyRelu, OptionsSerializer("LeakyReluOptions", ("alpha",)), TFLITE_IFM_INDICES),
+ BuiltinOperator.SQUARED_DIFFERENCE: (
+ Op.SquaredDifference,
+ OptionsSerializer("SquaredDifferenceOptions"),
+ TFLITE_NO_INDICES,
),
- BuiltinOperator.LEAKY_RELU: (Op.LeakyRelu, OptionsSerializer("LeakyReluOptions", ("alpha",))),
- BuiltinOperator.SQUARED_DIFFERENCE: (Op.SquaredDifference, OptionsSerializer("SquaredDifferenceOptions")),
- BuiltinOperator.MIRROR_PAD: (Op.MirrorPad, OptionsSerializer("MirrorPadOptions", ("mode",))),
- BuiltinOperator.ABS: (Op.Abs, OptionsSerializer("AbsOptions")),
- BuiltinOperator.SPLIT_V: (Op.SplitV, OptionsSerializer("SplitVOptions", ("num_splits",))),
+ BuiltinOperator.MIRROR_PAD: (Op.MirrorPad, OptionsSerializer("MirrorPadOptions", ("mode",)), TFLITE_NO_INDICES),
+ BuiltinOperator.ABS: (Op.Abs, OptionsSerializer("AbsOptions"), TFLITE_IFM_INDICES),
+ BuiltinOperator.SPLIT_V: (Op.SplitV, OptionsSerializer("SplitVOptions", ("num_splits",)), TFLITE_IFM_INDICES),
BuiltinOperator.UNIQUE: (
Op.Unique,
OptionsSerializer("UniqueOptions", (("idx_out_type", datatype_deserialize, datatype_serialize),)),
+ TFLITE_NO_INDICES,
),
- BuiltinOperator.CEIL: (Op.Ceil, None),
- BuiltinOperator.REVERSE_V2: (Op.ReverseV2, OptionsSerializer("ReverseV2Options")),
- BuiltinOperator.ADD_N: (Op.AddN, OptionsSerializer("AddNOptions")),
- BuiltinOperator.GATHER_ND: (Op.GatherNd, OptionsSerializer("GatherNdOptions")),
- BuiltinOperator.COS: (Op.Cos, OptionsSerializer("CosOptions")),
- BuiltinOperator.WHERE: (Op.Where, OptionsSerializer("WhereOptions")),
- BuiltinOperator.RANK: (Op.Rank, OptionsSerializer("RankOptions")),
- BuiltinOperator.ELU: (Op.Elu, None),
+ BuiltinOperator.CEIL: (Op.Ceil, None, TFLITE_NO_INDICES),
+ BuiltinOperator.REVERSE_V2: (Op.ReverseV2, OptionsSerializer("ReverseV2Options"), TFLITE_NO_INDICES),
+ BuiltinOperator.ADD_N: (Op.AddN, OptionsSerializer("AddNOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.GATHER_ND: (Op.GatherNd, OptionsSerializer("GatherNdOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.COS: (Op.Cos, OptionsSerializer("CosOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.WHERE: (Op.Where, OptionsSerializer("WhereOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.RANK: (Op.Rank, OptionsSerializer("RankOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.ELU: (Op.Elu, None, TFLITE_NO_INDICES),
BuiltinOperator.REVERSE_SEQUENCE: (
Op.ReverseSequence,
OptionsSerializer("ReverseSequenceOptions", ("seq_dim", "batch_dim")),
+ TFLITE_NO_INDICES,
+ ),
+ BuiltinOperator.MATRIX_DIAG: (Op.MatrixDiag, OptionsSerializer("MatrixDiagOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.QUANTIZE: (Op.Quantize, OptionsSerializer("QuantizeOptions"), TFLITE_IFM_INDICES),
+ BuiltinOperator.MATRIX_SET_DIAG: (Op.MatrixSetDiag, OptionsSerializer("MatrixSetDiagOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.ROUND: (Op.Round, None, TFLITE_NO_INDICES),
+ BuiltinOperator.HARD_SWISH: (Op.HardSwish, OptionsSerializer("HardSwishOptions"), TFLITE_IFM_INDICES),
+ BuiltinOperator.IF: (
+ Op.If,
+ OptionsSerializer("IfOptions", ("then_subgraph_index", "else_subgraph_index")),
+ TFLITE_NO_INDICES,
),
- BuiltinOperator.MATRIX_DIAG: (Op.MatrixDiag, OptionsSerializer("MatrixDiagOptions")),
- BuiltinOperator.QUANTIZE: (Op.Quantize, OptionsSerializer("QuantizeOptions")),
- BuiltinOperator.MATRIX_SET_DIAG: (Op.MatrixSetDiag, OptionsSerializer("MatrixSetDiagOptions")),
- BuiltinOperator.ROUND: (Op.Round, None),
- BuiltinOperator.HARD_SWISH: (Op.HardSwish, OptionsSerializer("HardSwishOptions")),
- BuiltinOperator.IF: (Op.If, OptionsSerializer("IfOptions", ("then_subgraph_index", "else_subgraph_index"))),
BuiltinOperator.WHILE: (
Op.While,
OptionsSerializer("WhileOptions", ("cond_subgraph_index", "body_subgraph_index")),
+ TFLITE_NO_INDICES,
+ ),
+ BuiltinOperator.NON_MAX_SUPPRESSION_V4: (
+ Op.NonMaxSuppressionV4,
+ OptionsSerializer("NonMaxSuppressionV4Options"),
+ TFLITE_NO_INDICES,
+ ),
+ BuiltinOperator.NON_MAX_SUPPRESSION_V5: (
+ Op.NonMaxSuppressionV5,
+ OptionsSerializer("NonMaxSuppressionV5Options"),
+ TFLITE_NO_INDICES,
+ ),
+ BuiltinOperator.SCATTER_ND: (Op.ScatterNd, OptionsSerializer("ScatterNdOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.SELECT_V2: (Op.SelectV2, OptionsSerializer("SelectV2Options"), TFLITE_NO_INDICES),
+ BuiltinOperator.DENSIFY: (Op.Densify, OptionsSerializer("DensifyOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.SEGMENT_SUM: (Op.SegmentSum, OptionsSerializer("SegmentSumOptions"), TFLITE_NO_INDICES),
+ BuiltinOperator.BATCH_MATMUL: (
+ Op.BatchMatMul,
+ OptionsSerializer("BatchMatMulOptions", ("adj_x", "adj_y")),
+ TFLITE_NO_INDICES,
+ ),
+ BuiltinOperator.CUMSUM: (
+ Op.Cumsum,
+ OptionsSerializer("CumsumOptions", ("exclusive", "reverse")),
+ TFLITE_NO_INDICES,
),
- BuiltinOperator.NON_MAX_SUPPRESSION_V4: (Op.NonMaxSuppressionV4, OptionsSerializer("NonMaxSuppressionV4Options")),
- BuiltinOperator.NON_MAX_SUPPRESSION_V5: (Op.NonMaxSuppressionV5, OptionsSerializer("NonMaxSuppressionV5Options")),
- BuiltinOperator.SCATTER_ND: (Op.ScatterNd, OptionsSerializer("ScatterNdOptions")),
- BuiltinOperator.SELECT_V2: (Op.SelectV2, OptionsSerializer("SelectV2Options")),
- BuiltinOperator.DENSIFY: (Op.Densify, OptionsSerializer("DensifyOptions")),
- BuiltinOperator.SEGMENT_SUM: (Op.SegmentSum, OptionsSerializer("SegmentSumOptions")),
- BuiltinOperator.BATCH_MATMUL: (Op.BatchMatMul, OptionsSerializer("BatchMatMulOptions", ("adj_x", "adj_y"))),
- BuiltinOperator.CUMSUM: (Op.Cumsum, OptionsSerializer("CumsumOptions", ("exclusive", "reverse"))),
- BuiltinOperator.CUSTOM: (Op.Custom, CustomOptionsSerializer()),
+ BuiltinOperator.CUSTOM: (Op.Custom, CustomOptionsSerializer(), TFLITE_NO_INDICES),
}
-builtin_operator_inv_map = {v[0]: (k, v[1]) for k, v in builtin_operator_map.items()}
+builtin_operator_inv_map = {v[0]: (k, v[1], v[2]) for k, v in builtin_operator_map.items()}
-builtin_operator_inv_map[Op.CustomNpuOp] = (BuiltinOperator.CUSTOM, CustomOptionsSerializer())
+builtin_operator_inv_map[Op.CustomNpuOp] = (BuiltinOperator.CUSTOM, CustomOptionsSerializer(), TFLITE_NO_INDICES)
BUILTIN_OPERATOR_UNKNOWN = "UNKNOWN"
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 1a45a5ee..30bf32af 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -27,6 +27,7 @@ from .nn_graph import Subgraph
from .operation import create_activation_function
from .operation import Op
from .operation import Operation
+from .reader_util import align_tensor_indices_to_nng
from .reader_util import clone_and_reshape_tensor
from .reader_util import decode_str
from .reader_util import fixup_tensors
@@ -112,7 +113,7 @@ class TFLiteSubgraph:
return tens
def parse_operator(self, op_index, op_data):
- op_type, opt_serializer, custom_code = self.graph.operator_codes[op_data.OpcodeIndex()]
+ op_type, opt_serializer, custom_code, indices = self.graph.operator_codes[op_data.OpcodeIndex()]
inputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.InputsAsNumpy()]
outputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.OutputsAsNumpy()]
intermediates = []
@@ -122,6 +123,7 @@ class TFLiteSubgraph:
name = "unknown_op_name"
if len(outputs):
name = outputs[0].name
+ inputs = align_tensor_indices_to_nng(op_type, indices, inputs)
op = Operation(op_type, name)
op.op_index = op_index
op.inputs = inputs
@@ -263,11 +265,11 @@ class TFLiteGraph:
raise InputFileError(
self.name, f"The input file contains operator code '{c}' which is currently not supported"
)
- op_type, ser = builtin_operator_map[c]
+ op_type, ser, indices = builtin_operator_map[c]
custom_code = None
if c == BuiltinOperator.CUSTOM:
custom_code = decode_str(code.CustomCode())
- return op_type, ser, custom_code
+ return op_type, ser, custom_code, indices
def read_tflite(filename, batch_size, feed_dict, output_node_names, initialisation_nodes):
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 8cabb0ac..3701893e 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -24,6 +24,7 @@ from flatbuffers.builder import UOffsetTFlags
from .errors import VelaError
from .nn_graph import PassPlacement
from .operation import Op
+from .reader_util import align_inputs_indices
from .tensor import MemType
from .tensor import TensorPurpose
from .tflite import Buffer
@@ -38,7 +39,6 @@ from .tflite_mapping import builtin_operator_inv_map
from .tflite_mapping import BuiltinOperator
from .tflite_mapping import datatype_inv_map
-
# ugh, the python flatbuffer interface is missing a method to add in file identifier. patching it in here:
tflite_version = 3
@@ -90,6 +90,8 @@ class TFLiteSerialiser:
for ps in sg.passes:
for op in ps.ops:
if op.type not in self.ops_to_ignore:
+ # swap from nng input indexing to TensorFlow Lite input indexing
+ self.align_nng_inputs_to_tflite(op)
all_ops.append(op)
if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
# If values are None op has non-constant weights
@@ -104,6 +106,11 @@ class TFLiteSerialiser:
self.operator_codes = sorted(set((op.type, op.attrs.get("custom_code", "")) for op in all_ops))
self.operator_code_map = {}
+ def align_nng_inputs_to_tflite(self, op):
+ from_indices = op.type.info.indices
+ _, _, to_indices = builtin_operator_inv_map[op.type]
+ op.inputs = align_inputs_indices(from_indices, to_indices, op.inputs)
+
def write_byte_vector(self, v, alignment=1):
builder = self.builder
builder.StartVector(1, len(v), alignment)
@@ -170,13 +177,13 @@ class TFLiteSerialiser:
builder = self.builder
custom_code_offset = None
if op_type == Op.Custom:
- tf_code, opt_serializer = builtin_operator_inv_map[op_type]
+ tf_code, opt_serializer, _ = builtin_operator_inv_map[op_type]
custom_code_offset = builder.CreateString(custom_code)
else:
assert (
op_type in builtin_operator_inv_map
), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(op_type)
- tf_code, opt_serializer = builtin_operator_inv_map[op_type]
+ tf_code, opt_serializer, _ = builtin_operator_inv_map[op_type]
if op_type == Op.CustomNpuOp:
assert (
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 94e6f999..fe18ce35 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -60,7 +60,7 @@ def add_padding_fields(op, arch, nng):
def rewrite_activation(op, arch, nng):
- if not op.type.is_relu_op():
+ if op.type not in (Op.ReluN, Op.Clamp):
return op
ifm = op.ifm
@@ -82,7 +82,7 @@ def rewrite_activation(op, arch, nng):
if op.ofm.quantization.zero_point is None:
op.ofm.quantization.zero_point = zp
- if op.type == Op.Clip:
+ if op.type == Op.Clamp:
op.attrs["min"] = op.attrs["min_int"] - zp
op.attrs["max"] = op.attrs["max_int"] - zp
elif op.type == Op.ReluN:
diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py
index 82f61f7c..312ac92e 100644
--- a/ethosu/vela/tosa_mapping.py
+++ b/ethosu/vela/tosa_mapping.py
@@ -249,7 +249,7 @@ tosa_operator_map = {
# TODO TosaOp.MATMUL:
TosaOp.MAX_POOL2D: (Op.MaxPool, pool2d_attrs, None, TOSA_IFM_INDICES),
# TODO TosaOp.TRANSPOSE_CONV2D: (Op.Conv2DBackpropInput, transpose_conv2d_attrs, conv_quant_info)
- TosaOp.CLAMP: (Op.Clip, clamp_attrs, None, TOSA_IFM_INDICES),
+ TosaOp.CLAMP: (Op.Clamp, clamp_attrs, None, TOSA_IFM_INDICES),
TosaOp.RELUN: (Op.ReluN, relun_attrs, None, TOSA_IFM_INDICES),
# TODO TosaOp.SIGMOID
# TODO TosaOp.TANH
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index ac0b3969..e51ead1d 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -25,6 +25,7 @@ from .nn_graph import Graph
from .nn_graph import Subgraph
from .operation import Op
from .operation import Operation
+from .reader_util import align_tensor_indices_to_nng
from .reader_util import clone_and_reshape_tensor
from .reader_util import decode_str
from .reader_util import fixup_tensors
@@ -104,8 +105,8 @@ class TosaSubgraph:
name = "unknown_op_name"
if len(outputs):
name = outputs[0].name
+ inputs = align_tensor_indices_to_nng(op_type, indices, inputs)
op = Operation(op_type, name)
- op.type.info.indices = indices
op.op_index = op_index
op.inputs = inputs
op.outputs = outputs
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index c87d653a..51f80ebd 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -32,7 +32,7 @@ class TosaSupportedOperators:
mac_main_ops = convolution_like_ops
type_conversion_ops = set((Op.Rescale,))
- relu_ops = set((Op.Clip, Op.ReluN,))
+ relu_ops = set((Op.Clamp, Op.ReluN,))
activation_ops = relu_ops
npu_post_ops = activation_ops