From 5e26eda0e0f359b6e22b1f1eeb9344cd15e0f093 Mon Sep 17 00:00:00 2001 From: Patrik Gustavsson Date: Wed, 30 Jun 2021 09:07:16 +0200 Subject: MLBEDSW-4840 Move setting of input indices to tflite reader Mapping to internal input indexing has been added to tflite_reader.py and tosa_reader.py. And the other way around in tflite_writer.py. Signed-off-by: Patrik Gustavsson Change-Id: I4d8596e747cfa7c4203884c4e785eb1977e2bcc1 --- ethosu/vela/operation.py | 153 ++++++++------- ethosu/vela/reader_util.py | 27 +++ ethosu/vela/test/test_tflite_reader.py | 26 +-- ethosu/vela/tflite_mapping.py | 335 +++++++++++++++++++++----------- ethosu/vela/tflite_reader.py | 8 +- ethosu/vela/tflite_writer.py | 13 +- ethosu/vela/tosa_graph_optimiser.py | 4 +- ethosu/vela/tosa_mapping.py | 2 +- ethosu/vela/tosa_reader.py | 3 +- ethosu/vela/tosa_supported_operators.py | 2 +- 10 files changed, 364 insertions(+), 209 deletions(-) (limited to 'ethosu/vela') diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index 0558e527..ffa4717d 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -100,16 +100,16 @@ class CustomType(Enum): TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"]) -NO_INDICES = TensorIndices([], [], []) -IFM_INDICES = TensorIndices([0], [], []) -IFM_WEIGHTS_INDICES = TensorIndices([0], [1], []) -IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2]) -IFM_IFM2_INDICES = TensorIndices([0, 1], [], []) -CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3]) -TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3]) -CONCAT_INDICES = TensorIndices([1, 2], [], []) -SPLIT_IFM_INDICES = TensorIndices([1], [], []) -BLOCK_LSTM_INDICES = TensorIndices([3], [4], []) +NNG_NO_INDICES = TensorIndices([], [], []) +NNG_IFM_INDICES = TensorIndices([0], [], []) +NNG_IFM_WEIGHTS_INDICES = TensorIndices([0], [1], []) +NNG_IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2]) +NNG_IFM_IFM2_INDICES = TensorIndices([0, 1], [], []) +NNG_CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3]) +NNG_TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3]) +NNG_CONCAT_INDICES = TensorIndices([1, 2], [], []) +NNG_SPLIT_IFM_INDICES = TensorIndices([1], [], []) +NNG_BLOCK_LSTM_INDICES = TensorIndices([3], [4], []) # Static information related to operation codes @@ -117,7 +117,7 @@ class OperatorInfo: __slots__ = ("id", "block_type", "indices", "is_unary") _id = 0 - def __init__(self, block_type=NpuBlockType.Default, indices=NO_INDICES, is_unary=False): + def __init__(self, block_type=NpuBlockType.Default, indices=NNG_NO_INDICES, is_unary=False): OperatorInfo._id += 1 self.id = OperatorInfo._id self.block_type = block_type @@ -127,37 +127,38 @@ class OperatorInfo: # Internally used operation codes class Op(Enum): - Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True) - Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) + Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True) + Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES) AddN = OperatorInfo() Any = OperatorInfo() ArgMax = OperatorInfo() ArgMin = OperatorInfo() - AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES) + AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES) BatchMatMul = OperatorInfo() BatchToSpaceND = OperatorInfo() - BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) - BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) - BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=BLOCK_LSTM_INDICES) + BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES) + BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES) + BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_BLOCK_LSTM_INDICES) CLZ = OperatorInfo( - block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True + block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True ) # NPU specific operation Call = OperatorInfo() Cast = OperatorInfo() Ceil = OperatorInfo() + Clamp = OperatorInfo(indices=NNG_IFM_INDICES) # TOSA specific Clip = OperatorInfo() # NPU specific fused activation function for clipping between activation.min/max - Concat = OperatorInfo(indices=CONCAT_INDICES) + Concat = OperatorInfo(indices=NNG_CONCAT_INDICES) ConcatEmbeddings = OperatorInfo() - ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES) - ConcatTFLite = OperatorInfo(indices=CONCAT_INDICES) + ConcatSliceWrite = OperatorInfo(indices=NNG_IFM_INDICES) + ConcatTFLite = OperatorInfo(indices=NNG_CONCAT_INDICES) Const = OperatorInfo() # Constant tensor, only used in CPU subgraphs - Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES) - Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES) + Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_INDICES) + Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_CONV2D_BACKPROP_INDICES) Conv2DBackpropInputSwitchedBias = OperatorInfo( - block_type=NpuBlockType.ConvolutionMxN, indices=TRANSPOSE_CONV_INDICES + block_type=NpuBlockType.ConvolutionMxN, indices=NNG_TRANSPOSE_CONV_INDICES ) - Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_BIAS_INDICES) + Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_BIAS_INDICES) Cos = OperatorInfo() Cumsum = OperatorInfo() Custom = OperatorInfo() # Custom 3rd party operator, only used in CPU subgraphs @@ -165,26 +166,28 @@ class Op(Enum): Delegate = OperatorInfo() Densify = OperatorInfo() DepthToSpace = OperatorInfo() - DepthwiseConv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionDepthWise, indices=IFM_WEIGHTS_BIAS_INDICES) - Dequantize = OperatorInfo(indices=IFM_INDICES) + DepthwiseConv2DBias = OperatorInfo( + block_type=NpuBlockType.ConvolutionDepthWise, indices=NNG_IFM_WEIGHTS_BIAS_INDICES + ) + Dequantize = OperatorInfo(indices=NNG_IFM_INDICES) Div = OperatorInfo() Elu = OperatorInfo() EmbeddingLookup = OperatorInfo() EmbeddingLookupSparse = OperatorInfo() Equal = OperatorInfo() Exp = OperatorInfo() - ExpandDims = OperatorInfo(indices=IFM_INDICES) + ExpandDims = OperatorInfo(indices=NNG_IFM_INDICES) FakeQuantWithMinMaxArgs = OperatorInfo() Fill = OperatorInfo() Floor = OperatorInfo() FloorDiv = OperatorInfo() FloorMod = OperatorInfo() - FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_BIAS_INDICES) + FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_BIAS_INDICES) GatherNd = OperatorInfo() GatherV2 = OperatorInfo() Greater = OperatorInfo() GreaterEqual = OperatorInfo() - HardSwish = OperatorInfo(indices=IFM_INDICES) + HardSwish = OperatorInfo(indices=NNG_IFM_INDICES) HashtableLookup = OperatorInfo() Identity = OperatorInfo() If = OperatorInfo() @@ -192,7 +195,7 @@ class Op(Enum): L2Pool2D = OperatorInfo() LRN = OperatorInfo() LSHProjection = OperatorInfo() - LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True) + LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_INDICES, is_unary=True) Less = OperatorInfo() LessEqual = OperatorInfo() Log = OperatorInfo() @@ -200,92 +203,92 @@ class Op(Enum): LogicalAnd = OperatorInfo() LogicalNot = OperatorInfo() LogicalOr = OperatorInfo() - Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) + Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES) LUT = OperatorInfo() # NPU specific, operator has LUT, only used in fused activation functions - MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) + MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES) MatrixDiag = OperatorInfo() MatrixSetDiag = OperatorInfo() Max = OperatorInfo() - MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES) - Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) - Mean = OperatorInfo(indices=IFM_INDICES) + MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES) + Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES) + Mean = OperatorInfo(indices=NNG_IFM_INDICES) Min = OperatorInfo() - Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) + Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES) MirrorPad = OperatorInfo() - Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) + Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES) Neg = OperatorInfo() NonMaxSuppressionV4 = OperatorInfo() NonMaxSuppressionV5 = OperatorInfo() NotEqual = OperatorInfo() OneHot = OperatorInfo() - Pack = OperatorInfo(indices=IFM_INDICES) - PackReshaped = OperatorInfo(indices=IFM_INDICES) - Pad = OperatorInfo(indices=IFM_INDICES) + Pack = OperatorInfo(indices=NNG_IFM_INDICES) + PackReshaped = OperatorInfo(indices=NNG_IFM_INDICES) + Pad = OperatorInfo(indices=NNG_IFM_INDICES) PadV2 = OperatorInfo() Placeholder = OperatorInfo() # Only used in CPU subgraphs Pow = OperatorInfo() Prelu = OperatorInfo() Prod = OperatorInfo() - Quantize = OperatorInfo(indices=IFM_INDICES) - QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES) - QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES) - QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) - QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES) - QuantizedReshape = OperatorInfo(indices=IFM_INDICES) + Quantize = OperatorInfo(indices=NNG_IFM_INDICES) + QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES) + QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=NNG_IFM_WEIGHTS_INDICES) + QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES) + QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES) + QuantizedReshape = OperatorInfo(indices=NNG_IFM_INDICES) Range = OperatorInfo() Rank = OperatorInfo() - ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=IFM_INDICES) - Relu = OperatorInfo(indices=IFM_INDICES) - Relu6 = OperatorInfo(indices=IFM_INDICES) - ReluN1To1 = OperatorInfo(indices=IFM_INDICES) - ReluN = OperatorInfo(indices=IFM_INDICES) # TOSA specific - Rescale = OperatorInfo(indices=IFM_INDICES) # TOSA specific - RescaleAdd = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) - Reshape = OperatorInfo(indices=IFM_INDICES) - ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES) + ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=NNG_IFM_INDICES) + Relu = OperatorInfo(indices=NNG_IFM_INDICES) + Relu6 = OperatorInfo(indices=NNG_IFM_INDICES) + ReluN1To1 = OperatorInfo(indices=NNG_IFM_INDICES) + ReluN = OperatorInfo(indices=NNG_IFM_INDICES) # TOSA specific + Rescale = OperatorInfo(indices=NNG_IFM_INDICES) # TOSA specific + RescaleAdd = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES) + Reshape = OperatorInfo(indices=NNG_IFM_INDICES) + ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES) ResizeNearestNeighbor = OperatorInfo() ReverseSequence = OperatorInfo() ReverseV2 = OperatorInfo() - Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) + Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES) Round = OperatorInfo() Rsqrt = OperatorInfo() - SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) # NPU specific operation - SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) # NPU specific operation + SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES) # NPU specific operation + SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES) # NPU specific operation ScatterNd = OperatorInfo() SegmentSum = OperatorInfo() Select = OperatorInfo() SelectV2 = OperatorInfo() Shape = OperatorInfo() - Sigmoid = OperatorInfo(indices=IFM_INDICES) + Sigmoid = OperatorInfo(indices=NNG_IFM_INDICES) SignBit = OperatorInfo() Sin = OperatorInfo() SkipGram = OperatorInfo() - Slice = OperatorInfo(indices=IFM_INDICES) - Softmax = OperatorInfo(indices=IFM_INDICES) + Slice = OperatorInfo(indices=NNG_IFM_INDICES) + Softmax = OperatorInfo(indices=NNG_IFM_INDICES) SpaceToBatchND = OperatorInfo() SpaceToDepth = OperatorInfo() SparseToDense = OperatorInfo() - Split = OperatorInfo(indices=SPLIT_IFM_INDICES) - SplitSliceRead = OperatorInfo(indices=IFM_INDICES) - SplitV = OperatorInfo(indices=IFM_INDICES) + Split = OperatorInfo(indices=NNG_SPLIT_IFM_INDICES) + SplitSliceRead = OperatorInfo(indices=NNG_IFM_INDICES) + SplitV = OperatorInfo(indices=NNG_IFM_INDICES) Sqrt = OperatorInfo() Square = OperatorInfo() SquaredDifference = OperatorInfo() - Squeeze = OperatorInfo(indices=IFM_INDICES) - StridedSlice = OperatorInfo(indices=IFM_INDICES) - Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) + Squeeze = OperatorInfo(indices=NNG_IFM_INDICES) + StridedSlice = OperatorInfo(indices=NNG_IFM_INDICES) + Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES) SubgraphInput = OperatorInfo() # Only used in CPU subgraphs Sum = OperatorInfo() Svdf = OperatorInfo() - Tanh = OperatorInfo(indices=IFM_INDICES) + Tanh = OperatorInfo(indices=NNG_IFM_INDICES) Tile = OperatorInfo() TopKV2 = OperatorInfo() Transpose = OperatorInfo() - UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) - UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) + UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES) + UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES) Unique = OperatorInfo() - Unpack = OperatorInfo(indices=IFM_INDICES) - UnpackReshaped = OperatorInfo(indices=IFM_INDICES) + Unpack = OperatorInfo(indices=NNG_IFM_INDICES) + UnpackReshaped = OperatorInfo(indices=NNG_IFM_INDICES) Where = OperatorInfo() While = OperatorInfo() ZerosLike = OperatorInfo() @@ -323,7 +326,7 @@ class Op(Enum): return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary def is_relu_op(self): - return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.ReluN, Op.Clip) + return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.ReluN, Op.Clip, Op.Clamp) def is_activation_op(self): return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT, Op.HardSwish) @@ -408,7 +411,7 @@ def create_activation_function(op_type: Op, min=None, max=None) -> ActivationFun act.max = 1.0 elif op_type == Op.HardSwish: act.min = 0.0 - if op_type == Op.Clip: + if op_type == Op.Clamp: assert min is not None and max is not None act.min = min act.max = max diff --git a/ethosu/vela/reader_util.py b/ethosu/vela/reader_util.py index 5b454b57..233286c8 100644 --- a/ethosu/vela/reader_util.py +++ b/ethosu/vela/reader_util.py @@ -58,3 +58,30 @@ def fixup_tensors(input_tensors, tensors): if not tens.ops: op = Operation(Op.Const, tens.name) op.set_output_tensor(tens) + + +def align_inputs_indices(from_indices, to_indices, inputs): + to_list = to_indices.ifms + to_indices.weights + to_indices.biases + from_list = from_indices.ifms + from_indices.weights + from_indices.biases + + assert len(to_list) == len(from_list) + if to_list != from_list: + for idx, t_idx in enumerate(to_list): + if t_idx >= len(inputs): + # Biases are allowed to be left out + assert t_idx in from_indices.biases and t_idx in to_indices.biases + continue + if to_list[idx] != from_list[idx]: + # find t_idx in from list and swap. + for jdx in from_list[idx:]: + if from_list[jdx] == t_idx: + inputs[idx], inputs[jdx] = inputs[jdx], inputs[idx] + from_list[idx], from_list[jdx] = from_list[jdx], from_list[idx] + break + assert from_list == to_list + return inputs + + +def align_tensor_indices_to_nng(op_type, indices, inputs): + nng_op = Op(op_type) + return align_inputs_indices(indices, nng_op.info.indices, inputs) diff --git a/ethosu/vela/test/test_tflite_reader.py b/ethosu/vela/test/test_tflite_reader.py index a69e8d37..664a58c6 100644 --- a/ethosu/vela/test/test_tflite_reader.py +++ b/ethosu/vela/test/test_tflite_reader.py @@ -23,6 +23,8 @@ import pytest from ethosu.vela.operation import Op from ethosu.vela.tflite.TensorType import TensorType +from ethosu.vela.tflite_mapping import TFLITE_CONV2D_BACKPROP_INDICES +from ethosu.vela.tflite_mapping import TFLITE_IFM_WEIGHTS_BIAS_INDICES from ethosu.vela.tflite_reader import TFLiteSubgraph @@ -43,23 +45,25 @@ class TestTFLiteSubgraph: assert output == expected parse_op_testdata = [ - # op_type, opt_serializer, inputs, output, expected - (Op.FullyConnected, None, [0, 1, 2], 3, 3), # FC - (Op.FullyConnected, None, [0, 1, -1], 3, 3), # FC disabled Bias - (Op.FullyConnected, None, [0, 1], 3, 3), # FC no Bias - (Op.Conv2D, None, [2, 1, 3], 0, 3), # Conv2D - (Op.Conv2DBackpropInput, None, [0, 1, 2, 3], 4, 4), # TransposeConv - (Op.Conv2DBackpropInput, None, [0, 1, 2], 4, 4), # TransposeConv no Bias - pytest.param(Op.Conv2D, None, [0, -1, 1], 3, 3, marks=pytest.mark.xfail), # Conv2D no Weights + # op_type, opt_serializer, indices, inputs, output, expected + (Op.FullyConnected, None, TFLITE_IFM_WEIGHTS_BIAS_INDICES, [0, 1, 2], 3, 3), # FC + (Op.FullyConnected, None, TFLITE_IFM_WEIGHTS_BIAS_INDICES, [0, 1, -1], 3, 3), # FC disabled Bias + (Op.FullyConnected, None, TFLITE_IFM_WEIGHTS_BIAS_INDICES, [0, 1], 3, 3), # FC no Bias + (Op.Conv2DBias, None, TFLITE_IFM_WEIGHTS_BIAS_INDICES, [2, 1, 3], 0, 3), # Conv2D + (Op.Conv2DBackpropInput, None, TFLITE_CONV2D_BACKPROP_INDICES, [0, 1, 2, 3], 4, 4), # TransposeConv + (Op.Conv2DBackpropInput, None, TFLITE_CONV2D_BACKPROP_INDICES, [0, 1, 2], 4, 4), # TransposeConv no Bias + pytest.param( + Op.Conv2DBias, None, TFLITE_IFM_WEIGHTS_BIAS_INDICES, [0, -1, 1], 3, 3, marks=pytest.mark.xfail + ), # Conv2D no Weights ] - @pytest.mark.parametrize("op_type, opt_serializer, inputs, output, expected", parse_op_testdata) - def test_parse_operator(self, op_type, opt_serializer, inputs, output, expected): + @pytest.mark.parametrize("op_type, opt_serializer, indices, inputs, output, expected", parse_op_testdata) + def test_parse_operator(self, op_type, opt_serializer, indices, inputs, output, expected): with patch.object(TFLiteSubgraph, "__init__", lambda self, graph, subraph: None): # Mock a TFLiteSubGraph sg = TFLiteSubgraph(None, None) sg.graph = MagicMock() - sg.graph.operator_codes = [(op_type, opt_serializer, "")] + sg.graph.operator_codes = [(op_type, opt_serializer, "", indices)] # Mock a couple of tensors sg.tensors = [MagicMock() for _ in range(5)] diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py index b526ec58..23a1a2b7 100644 --- a/ethosu/vela/tflite_mapping.py +++ b/ethosu/vela/tflite_mapping.py @@ -25,6 +25,7 @@ from .data_type import DataType from .operation import CustomType from .operation import Op from .operation import Padding as opPad +from .operation import TensorIndices from .tflite import AbsOptions from .tflite import AddNOptions from .tflite import AddOptions @@ -489,50 +490,89 @@ reducer_opts = OptionsSerializer("ReducerOptions", ("keep_dims",)) is_int_vec = True +TFLITE_NO_INDICES = TensorIndices([], [], []) +TFLITE_IFM_INDICES = TensorIndices([0], [], []) +TFLITE_IFM_WEIGHTS_INDICES = TensorIndices([0], [1], []) +TFLITE_IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2]) +TFLITE_IFM_IFM2_INDICES = TensorIndices([0, 1], [], []) +TFLITE_CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3]) +TFLITE_TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3]) +TFLITE_CONCAT_INDICES = TensorIndices([1, 2], [], []) +TFLITE_SPLIT_IFM_INDICES = TensorIndices([1], [], []) +TFLITE_BLOCK_LSTM_INDICES = TensorIndices([3], [4], []) + builtin_operator_map = { - BuiltinOperator.ADD: (Op.Add, OptionsSerializer("AddOptions", (fused_act, "pot_scale_int16"))), - BuiltinOperator.AVERAGE_POOL_2D: (Op.AvgPool, pool2d_opts), - BuiltinOperator.CONCATENATION: (Op.ConcatTFLite, OptionsSerializer("ConcatenationOptions", ("axis", fused_act))), - BuiltinOperator.CONV_2D: (Op.Conv2DBias, conv2d_opts), - BuiltinOperator.DEPTHWISE_CONV_2D: (Op.DepthwiseConv2DBias, depthwise_opts), - BuiltinOperator.DEPTH_TO_SPACE: (Op.DepthToSpace, OptionsSerializer("DepthToSpaceOptions", ("block_size",))), - BuiltinOperator.DEQUANTIZE: (Op.Dequantize, OptionsSerializer("DequantizeOptions")), - BuiltinOperator.EMBEDDING_LOOKUP: (Op.EmbeddingLookup, None), - BuiltinOperator.FLOOR: (Op.Floor, None), + BuiltinOperator.ADD: ( + Op.Add, + OptionsSerializer("AddOptions", (fused_act, "pot_scale_int16")), + TFLITE_IFM_IFM2_INDICES, + ), + BuiltinOperator.AVERAGE_POOL_2D: (Op.AvgPool, pool2d_opts, TFLITE_IFM_INDICES), + BuiltinOperator.CONCATENATION: ( + Op.ConcatTFLite, + OptionsSerializer("ConcatenationOptions", ("axis", fused_act)), + TFLITE_CONCAT_INDICES, + ), + BuiltinOperator.CONV_2D: (Op.Conv2DBias, conv2d_opts, TFLITE_IFM_WEIGHTS_BIAS_INDICES), + BuiltinOperator.DEPTHWISE_CONV_2D: (Op.DepthwiseConv2DBias, depthwise_opts, TFLITE_IFM_WEIGHTS_BIAS_INDICES), + BuiltinOperator.DEPTH_TO_SPACE: ( + Op.DepthToSpace, + OptionsSerializer("DepthToSpaceOptions", ("block_size",)), + TFLITE_NO_INDICES, + ), + BuiltinOperator.DEQUANTIZE: (Op.Dequantize, OptionsSerializer("DequantizeOptions"), TFLITE_IFM_INDICES), + BuiltinOperator.EMBEDDING_LOOKUP: (Op.EmbeddingLookup, None, TFLITE_NO_INDICES), + BuiltinOperator.FLOOR: (Op.Floor, None, TFLITE_NO_INDICES), BuiltinOperator.FULLY_CONNECTED: ( Op.FullyConnected, OptionsSerializer( "FullyConnectedOptions", (fused_act, "weights_format", "asymmetric_quantize_inputs", "keep_num_dims") ), + TFLITE_IFM_WEIGHTS_BIAS_INDICES, ), - BuiltinOperator.HASHTABLE_LOOKUP: (Op.HashtableLookup, None), - BuiltinOperator.L2_NORMALIZATION: (Op.L2Norm, OptionsSerializer("L2NormOptions", (fused_act,))), - BuiltinOperator.L2_POOL_2D: (Op.L2Pool2D, pool2d_opts), + BuiltinOperator.HASHTABLE_LOOKUP: (Op.HashtableLookup, None, TFLITE_NO_INDICES), + BuiltinOperator.L2_NORMALIZATION: (Op.L2Norm, OptionsSerializer("L2NormOptions", (fused_act,)), TFLITE_NO_INDICES), + BuiltinOperator.L2_POOL_2D: (Op.L2Pool2D, pool2d_opts, TFLITE_NO_INDICES), BuiltinOperator.LOCAL_RESPONSE_NORMALIZATION: ( Op.LRN, OptionsSerializer("LocalResponseNormalizationOptions", ("radius", "bias", "alpha", "beta")), + TFLITE_NO_INDICES, + ), + BuiltinOperator.LOGISTIC: (Op.Sigmoid, None, TFLITE_IFM_INDICES), + BuiltinOperator.LSH_PROJECTION: ( + Op.LSHProjection, + OptionsSerializer("LSHProjectionOptions", ("type",)), + TFLITE_NO_INDICES, + ), + BuiltinOperator.LSTM: (Op.Lstm, lstm_opts, TFLITE_IFM_WEIGHTS_INDICES), + BuiltinOperator.MAX_POOL_2D: (Op.MaxPool, pool2d_opts, TFLITE_IFM_INDICES), + BuiltinOperator.MUL: (Op.Mul, OptionsSerializer("MulOptions", (fused_act,)), TFLITE_IFM_IFM2_INDICES), + BuiltinOperator.RELU: (Op.Relu, None, TFLITE_IFM_INDICES), + BuiltinOperator.RELU_N1_TO_1: (Op.ReluN1To1, None, TFLITE_IFM_INDICES), + BuiltinOperator.RELU6: (Op.Relu6, None, TFLITE_IFM_INDICES), + BuiltinOperator.RESHAPE: ( + Op.Reshape, + OptionsSerializer("ReshapeOptions", (("new_shape", is_int_vec),)), + TFLITE_IFM_INDICES, ), - BuiltinOperator.LOGISTIC: (Op.Sigmoid, None), - BuiltinOperator.LSH_PROJECTION: (Op.LSHProjection, OptionsSerializer("LSHProjectionOptions", ("type",))), - BuiltinOperator.LSTM: (Op.Lstm, lstm_opts), - BuiltinOperator.MAX_POOL_2D: (Op.MaxPool, pool2d_opts), - BuiltinOperator.MUL: (Op.Mul, OptionsSerializer("MulOptions", (fused_act,))), - BuiltinOperator.RELU: (Op.Relu, None), - BuiltinOperator.RELU_N1_TO_1: (Op.ReluN1To1, None), - BuiltinOperator.RELU6: (Op.Relu6, None), - BuiltinOperator.RESHAPE: (Op.Reshape, OptionsSerializer("ReshapeOptions", (("new_shape", is_int_vec),))), BuiltinOperator.RESIZE_BILINEAR: ( Op.ResizeBilinear, OptionsSerializer("ResizeBilinearOptions", ("align_corners", "half_pixel_centers")), + TFLITE_IFM_INDICES, + ), + BuiltinOperator.RNN: (Op.Rnn, rnn_opts, TFLITE_IFM_WEIGHTS_INDICES), + BuiltinOperator.SOFTMAX: (Op.Softmax, OptionsSerializer("SoftmaxOptions", ("beta",)), TFLITE_IFM_INDICES), + BuiltinOperator.SPACE_TO_DEPTH: ( + Op.SpaceToDepth, + OptionsSerializer("SpaceToDepthOptions", ("block_size",)), + TFLITE_NO_INDICES, ), - BuiltinOperator.RNN: (Op.Rnn, rnn_opts), - BuiltinOperator.SOFTMAX: (Op.Softmax, OptionsSerializer("SoftmaxOptions", ("beta",))), - BuiltinOperator.SPACE_TO_DEPTH: (Op.SpaceToDepth, OptionsSerializer("SpaceToDepthOptions", ("block_size",))), BuiltinOperator.SVDF: ( Op.Svdf, OptionsSerializer("SVDFOptions", ("rank", fused_act, "asymmetric_quantize_inputs")), + TFLITE_NO_INDICES, ), - BuiltinOperator.TANH: (Op.Tanh, None), + BuiltinOperator.TANH: (Op.Tanh, None, TFLITE_IFM_INDICES), BuiltinOperator.CONCAT_EMBEDDINGS: ( Op.ConcatEmbeddings, OptionsSerializer( @@ -547,40 +587,76 @@ builtin_operator_map = { "embedding_dim_per_channel_as_length", ), ), + TFLITE_NO_INDICES, ), BuiltinOperator.SKIP_GRAM: ( Op.SkipGram, OptionsSerializer("SkipGramOptions", ("ngram_size", "max_skip_size", "include_all_ngrams")), + TFLITE_NO_INDICES, ), - BuiltinOperator.CALL: (Op.Call, OptionsSerializer("CallOptions", ("subgraph",))), + BuiltinOperator.CALL: (Op.Call, OptionsSerializer("CallOptions", ("subgraph",)), TFLITE_NO_INDICES), BuiltinOperator.EMBEDDING_LOOKUP_SPARSE: ( Op.EmbeddingLookupSparse, OptionsSerializer("EmbeddingLookupSparseOptions", ("combiner",)), + TFLITE_NO_INDICES, + ), + BuiltinOperator.PAD: (Op.Pad, OptionsSerializer("PadOptions"), TFLITE_IFM_INDICES), + BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_RNN: ( + Op.UnidirectionalSequenceRnn, + seq_rnn_opts, + TFLITE_IFM_WEIGHTS_INDICES, + ), + BuiltinOperator.GATHER: (Op.GatherV2, OptionsSerializer("GatherOptions", ("axis",)), TFLITE_NO_INDICES), + BuiltinOperator.BATCH_TO_SPACE_ND: ( + Op.BatchToSpaceND, + OptionsSerializer("BatchToSpaceNDOptions"), + TFLITE_NO_INDICES, + ), + BuiltinOperator.SPACE_TO_BATCH_ND: ( + Op.SpaceToBatchND, + OptionsSerializer("SpaceToBatchNDOptions"), + TFLITE_NO_INDICES, + ), + BuiltinOperator.TRANSPOSE: (Op.Transpose, OptionsSerializer("TransposeOptions"), TFLITE_NO_INDICES), + BuiltinOperator.MEAN: (Op.Mean, reducer_opts, TFLITE_IFM_INDICES), + BuiltinOperator.SUB: ( + Op.Sub, + OptionsSerializer("SubOptions", (fused_act, "pot_scale_int16",)), + TFLITE_IFM_IFM2_INDICES, + ), + BuiltinOperator.DIV: (Op.Div, OptionsSerializer("DivOptions", (fused_act,)), TFLITE_NO_INDICES), + BuiltinOperator.SQUEEZE: ( + Op.Squeeze, + OptionsSerializer("SqueezeOptions", (("squeeze_dims", is_int_vec),)), + TFLITE_IFM_INDICES, + ), + BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_LSTM: ( + Op.UnidirectionalSequenceLstm, + unidir_seq_lstm_opts, + TFLITE_IFM_WEIGHTS_INDICES, ), - BuiltinOperator.PAD: (Op.Pad, OptionsSerializer("PadOptions")), - BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_RNN: (Op.UnidirectionalSequenceRnn, seq_rnn_opts), - BuiltinOperator.GATHER: (Op.GatherV2, OptionsSerializer("GatherOptions", ("axis",))), - BuiltinOperator.BATCH_TO_SPACE_ND: (Op.BatchToSpaceND, OptionsSerializer("BatchToSpaceNDOptions")), - BuiltinOperator.SPACE_TO_BATCH_ND: (Op.SpaceToBatchND, OptionsSerializer("SpaceToBatchNDOptions")), - BuiltinOperator.TRANSPOSE: (Op.Transpose, OptionsSerializer("TransposeOptions")), - BuiltinOperator.MEAN: (Op.Mean, reducer_opts), - BuiltinOperator.SUB: (Op.Sub, OptionsSerializer("SubOptions", (fused_act, "pot_scale_int16",))), - BuiltinOperator.DIV: (Op.Div, OptionsSerializer("DivOptions", (fused_act,))), - BuiltinOperator.SQUEEZE: (Op.Squeeze, OptionsSerializer("SqueezeOptions", (("squeeze_dims", is_int_vec),))), - BuiltinOperator.UNIDIRECTIONAL_SEQUENCE_LSTM: (Op.UnidirectionalSequenceLstm, unidir_seq_lstm_opts), BuiltinOperator.STRIDED_SLICE: ( Op.StridedSlice, OptionsSerializer( - "StridedSliceOptions", ("begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask", "shrink_axis_mask") + "StridedSliceOptions", ("begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask", "shrink_axis_mask"), ), + TFLITE_IFM_INDICES, + ), + BuiltinOperator.BIDIRECTIONAL_SEQUENCE_RNN: ( + Op.BidirectionalSequenceRnn, + bidir_seq_rnn_opts, + TFLITE_IFM_WEIGHTS_INDICES, + ), + BuiltinOperator.EXP: (Op.Exp, OptionsSerializer("ExpOptions"), TFLITE_IFM_INDICES), + BuiltinOperator.TOPK_V2: (Op.TopKV2, OptionsSerializer("TopKV2Options"), TFLITE_NO_INDICES), + BuiltinOperator.SPLIT: (Op.Split, OptionsSerializer("SplitOptions", ("num_splits",)), TFLITE_SPLIT_IFM_INDICES), + BuiltinOperator.LOG_SOFTMAX: (Op.LogSoftmax, OptionsSerializer("LogSoftmaxOptions"), TFLITE_NO_INDICES), + BuiltinOperator.DELEGATE: (Op.Delegate, None, TFLITE_NO_INDICES), + BuiltinOperator.BIDIRECTIONAL_SEQUENCE_LSTM: ( + Op.BidirectionalSequenceLstm, + bidir_seq_lstm_opts, + TFLITE_IFM_WEIGHTS_INDICES, ), - BuiltinOperator.BIDIRECTIONAL_SEQUENCE_RNN: (Op.BidirectionalSequenceRnn, bidir_seq_rnn_opts), - BuiltinOperator.EXP: (Op.Exp, OptionsSerializer("ExpOptions")), - BuiltinOperator.TOPK_V2: (Op.TopKV2, OptionsSerializer("TopKV2Options")), - BuiltinOperator.SPLIT: (Op.Split, OptionsSerializer("SplitOptions", ("num_splits",))), - BuiltinOperator.LOG_SOFTMAX: (Op.LogSoftmax, OptionsSerializer("LogSoftmaxOptions")), - BuiltinOperator.DELEGATE: (Op.Delegate, None), - BuiltinOperator.BIDIRECTIONAL_SEQUENCE_LSTM: (Op.BidirectionalSequenceLstm, bidir_seq_lstm_opts), BuiltinOperator.CAST: ( Op.Cast, OptionsSerializer( @@ -590,117 +666,152 @@ builtin_operator_map = { ("out_data_type", datatype_deserialize, datatype_serialize), ), ), + TFLITE_NO_INDICES, ), - BuiltinOperator.PRELU: (Op.Prelu, None), - BuiltinOperator.MAXIMUM: (Op.Maximum, OptionsSerializer("MaximumMinimumOptions")), + BuiltinOperator.PRELU: (Op.Prelu, None, TFLITE_NO_INDICES), + BuiltinOperator.MAXIMUM: (Op.Maximum, OptionsSerializer("MaximumMinimumOptions"), TFLITE_IFM_IFM2_INDICES), BuiltinOperator.ARG_MAX: ( Op.ArgMax, OptionsSerializer("ArgMaxOptions", (("output_type", datatype_deserialize, datatype_serialize),)), + TFLITE_NO_INDICES, ), - BuiltinOperator.MINIMUM: (Op.Minimum, OptionsSerializer("MaximumMinimumOptions")), - BuiltinOperator.LESS: (Op.Less, OptionsSerializer("LessOptions")), - BuiltinOperator.NEG: (Op.Neg, OptionsSerializer("NegOptions")), - BuiltinOperator.PADV2: (Op.PadV2, OptionsSerializer("PadV2Options")), - BuiltinOperator.GREATER: (Op.Greater, OptionsSerializer("GreaterOptions")), - BuiltinOperator.GREATER_EQUAL: (Op.GreaterEqual, OptionsSerializer("GreaterEqualOptions")), - BuiltinOperator.LESS_EQUAL: (Op.LessEqual, OptionsSerializer("LessEqualOptions")), - BuiltinOperator.SELECT: (Op.Select, OptionsSerializer("SelectOptions")), - BuiltinOperator.SLICE: (Op.Slice, OptionsSerializer("SliceOptions")), - BuiltinOperator.SIN: (Op.Sin, None), + BuiltinOperator.MINIMUM: (Op.Minimum, OptionsSerializer("MaximumMinimumOptions"), TFLITE_IFM_IFM2_INDICES), + BuiltinOperator.LESS: (Op.Less, OptionsSerializer("LessOptions"), TFLITE_NO_INDICES), + BuiltinOperator.NEG: (Op.Neg, OptionsSerializer("NegOptions"), TFLITE_NO_INDICES), + BuiltinOperator.PADV2: (Op.PadV2, OptionsSerializer("PadV2Options"), TFLITE_NO_INDICES), + BuiltinOperator.GREATER: (Op.Greater, OptionsSerializer("GreaterOptions"), TFLITE_NO_INDICES), + BuiltinOperator.GREATER_EQUAL: (Op.GreaterEqual, OptionsSerializer("GreaterEqualOptions"), TFLITE_NO_INDICES), + BuiltinOperator.LESS_EQUAL: (Op.LessEqual, OptionsSerializer("LessEqualOptions"), TFLITE_NO_INDICES), + BuiltinOperator.SELECT: (Op.Select, OptionsSerializer("SelectOptions"), TFLITE_NO_INDICES), + BuiltinOperator.SLICE: (Op.Slice, OptionsSerializer("SliceOptions"), TFLITE_IFM_INDICES), + BuiltinOperator.SIN: (Op.Sin, None, TFLITE_NO_INDICES), BuiltinOperator.TRANSPOSE_CONV: ( Op.Conv2DBackpropInput, OptionsSerializer("TransposeConvOptions", (padding, "stride_w", "stride_h")), + TFLITE_CONV2D_BACKPROP_INDICES, ), BuiltinOperator.SPARSE_TO_DENSE: ( Op.SparseToDense, OptionsSerializer("SparseToDenseOptions", ("validate_indices",)), + TFLITE_NO_INDICES, ), - BuiltinOperator.TILE: (Op.Tile, OptionsSerializer("TileOptions")), - BuiltinOperator.EXPAND_DIMS: (Op.ExpandDims, OptionsSerializer("ExpandDimsOptions")), - BuiltinOperator.EQUAL: (Op.Equal, OptionsSerializer("EqualOptions")), - BuiltinOperator.NOT_EQUAL: (Op.NotEqual, OptionsSerializer("NotEqualOptions")), - BuiltinOperator.LOG: (Op.Log, None), - BuiltinOperator.SUM: (Op.Sum, reducer_opts), - BuiltinOperator.SQRT: (Op.Sqrt, None), - BuiltinOperator.RSQRT: (Op.Rsqrt, None), + BuiltinOperator.TILE: (Op.Tile, OptionsSerializer("TileOptions"), TFLITE_NO_INDICES), + BuiltinOperator.EXPAND_DIMS: (Op.ExpandDims, OptionsSerializer("ExpandDimsOptions"), TFLITE_IFM_INDICES), + BuiltinOperator.EQUAL: (Op.Equal, OptionsSerializer("EqualOptions"), TFLITE_NO_INDICES), + BuiltinOperator.NOT_EQUAL: (Op.NotEqual, OptionsSerializer("NotEqualOptions"), TFLITE_NO_INDICES), + BuiltinOperator.LOG: (Op.Log, None, TFLITE_NO_INDICES), + BuiltinOperator.SUM: (Op.Sum, reducer_opts, TFLITE_NO_INDICES), + BuiltinOperator.SQRT: (Op.Sqrt, None, TFLITE_NO_INDICES), + BuiltinOperator.RSQRT: (Op.Rsqrt, None, TFLITE_NO_INDICES), BuiltinOperator.SHAPE: ( Op.Shape, OptionsSerializer("ShapeOptions", (("out_type", datatype_deserialize, datatype_serialize),)), + TFLITE_NO_INDICES, ), - BuiltinOperator.POW: (Op.Pow, OptionsSerializer("PowOptions")), + BuiltinOperator.POW: (Op.Pow, OptionsSerializer("PowOptions"), TFLITE_NO_INDICES), BuiltinOperator.ARG_MIN: ( Op.ArgMin, OptionsSerializer("ArgMinOptions", (("output_type", datatype_deserialize, datatype_serialize),)), + TFLITE_NO_INDICES, ), BuiltinOperator.FAKE_QUANT: ( Op.FakeQuantWithMinMaxArgs, OptionsSerializer("FakeQuantOptions", ("min", "max", "num_bits", "narrow_range")), + TFLITE_NO_INDICES, ), - BuiltinOperator.REDUCE_PROD: (Op.Prod, reducer_opts), - BuiltinOperator.REDUCE_MAX: (Op.Max, reducer_opts), - BuiltinOperator.PACK: (Op.Pack, OptionsSerializer("PackOptions", ("values_count", "axis"))), - BuiltinOperator.LOGICAL_OR: (Op.LogicalOr, OptionsSerializer("LogicalOrOptions")), - BuiltinOperator.ONE_HOT: (Op.OneHot, OptionsSerializer("OneHotOptions", ("axis",))), - BuiltinOperator.LOGICAL_AND: (Op.LogicalAnd, OptionsSerializer("LogicalAndOptions")), - BuiltinOperator.LOGICAL_NOT: (Op.LogicalNot, OptionsSerializer("LogicalNotOptions")), - BuiltinOperator.UNPACK: (Op.Unpack, OptionsSerializer("UnpackOptions", ("num", "axis"))), - BuiltinOperator.REDUCE_MIN: (Op.Min, reducer_opts), - BuiltinOperator.FLOOR_DIV: (Op.FloorDiv, OptionsSerializer("FloorDivOptions")), - BuiltinOperator.REDUCE_ANY: (Op.Any, reducer_opts), - BuiltinOperator.SQUARE: (Op.Square, OptionsSerializer("SquareOptions")), - BuiltinOperator.ZEROS_LIKE: (Op.ZerosLike, OptionsSerializer("ZerosLikeOptions")), - BuiltinOperator.FILL: (Op.Fill, OptionsSerializer("FillOptions")), - BuiltinOperator.FLOOR_MOD: (Op.FloorMod, OptionsSerializer("FloorModOptions")), - BuiltinOperator.RANGE: (Op.Range, OptionsSerializer("RangeOptions")), + BuiltinOperator.REDUCE_PROD: (Op.Prod, reducer_opts, TFLITE_NO_INDICES), + BuiltinOperator.REDUCE_MAX: (Op.Max, reducer_opts, TFLITE_NO_INDICES), + BuiltinOperator.PACK: (Op.Pack, OptionsSerializer("PackOptions", ("values_count", "axis")), TFLITE_IFM_INDICES), + BuiltinOperator.LOGICAL_OR: (Op.LogicalOr, OptionsSerializer("LogicalOrOptions"), TFLITE_NO_INDICES), + BuiltinOperator.ONE_HOT: (Op.OneHot, OptionsSerializer("OneHotOptions", ("axis",)), TFLITE_NO_INDICES), + BuiltinOperator.LOGICAL_AND: (Op.LogicalAnd, OptionsSerializer("LogicalAndOptions"), TFLITE_NO_INDICES), + BuiltinOperator.LOGICAL_NOT: (Op.LogicalNot, OptionsSerializer("LogicalNotOptions"), TFLITE_NO_INDICES), + BuiltinOperator.UNPACK: (Op.Unpack, OptionsSerializer("UnpackOptions", ("num", "axis")), TFLITE_IFM_INDICES), + BuiltinOperator.REDUCE_MIN: (Op.Min, reducer_opts, TFLITE_NO_INDICES), + BuiltinOperator.FLOOR_DIV: (Op.FloorDiv, OptionsSerializer("FloorDivOptions"), TFLITE_NO_INDICES), + BuiltinOperator.REDUCE_ANY: (Op.Any, reducer_opts, TFLITE_NO_INDICES), + BuiltinOperator.SQUARE: (Op.Square, OptionsSerializer("SquareOptions"), TFLITE_NO_INDICES), + BuiltinOperator.ZEROS_LIKE: (Op.ZerosLike, OptionsSerializer("ZerosLikeOptions"), TFLITE_NO_INDICES), + BuiltinOperator.FILL: (Op.Fill, OptionsSerializer("FillOptions"), TFLITE_NO_INDICES), + BuiltinOperator.FLOOR_MOD: (Op.FloorMod, OptionsSerializer("FloorModOptions"), TFLITE_NO_INDICES), + BuiltinOperator.RANGE: (Op.Range, OptionsSerializer("RangeOptions"), TFLITE_NO_INDICES), BuiltinOperator.RESIZE_NEAREST_NEIGHBOR: ( Op.ResizeNearestNeighbor, OptionsSerializer("ResizeNearestNeighborOptions", ("align_corners", "half_pixel_centers")), + TFLITE_NO_INDICES, + ), + BuiltinOperator.LEAKY_RELU: (Op.LeakyRelu, OptionsSerializer("LeakyReluOptions", ("alpha",)), TFLITE_IFM_INDICES), + BuiltinOperator.SQUARED_DIFFERENCE: ( + Op.SquaredDifference, + OptionsSerializer("SquaredDifferenceOptions"), + TFLITE_NO_INDICES, ), - BuiltinOperator.LEAKY_RELU: (Op.LeakyRelu, OptionsSerializer("LeakyReluOptions", ("alpha",))), - BuiltinOperator.SQUARED_DIFFERENCE: (Op.SquaredDifference, OptionsSerializer("SquaredDifferenceOptions")), - BuiltinOperator.MIRROR_PAD: (Op.MirrorPad, OptionsSerializer("MirrorPadOptions", ("mode",))), - BuiltinOperator.ABS: (Op.Abs, OptionsSerializer("AbsOptions")), - BuiltinOperator.SPLIT_V: (Op.SplitV, OptionsSerializer("SplitVOptions", ("num_splits",))), + BuiltinOperator.MIRROR_PAD: (Op.MirrorPad, OptionsSerializer("MirrorPadOptions", ("mode",)), TFLITE_NO_INDICES), + BuiltinOperator.ABS: (Op.Abs, OptionsSerializer("AbsOptions"), TFLITE_IFM_INDICES), + BuiltinOperator.SPLIT_V: (Op.SplitV, OptionsSerializer("SplitVOptions", ("num_splits",)), TFLITE_IFM_INDICES), BuiltinOperator.UNIQUE: ( Op.Unique, OptionsSerializer("UniqueOptions", (("idx_out_type", datatype_deserialize, datatype_serialize),)), + TFLITE_NO_INDICES, ), - BuiltinOperator.CEIL: (Op.Ceil, None), - BuiltinOperator.REVERSE_V2: (Op.ReverseV2, OptionsSerializer("ReverseV2Options")), - BuiltinOperator.ADD_N: (Op.AddN, OptionsSerializer("AddNOptions")), - BuiltinOperator.GATHER_ND: (Op.GatherNd, OptionsSerializer("GatherNdOptions")), - BuiltinOperator.COS: (Op.Cos, OptionsSerializer("CosOptions")), - BuiltinOperator.WHERE: (Op.Where, OptionsSerializer("WhereOptions")), - BuiltinOperator.RANK: (Op.Rank, OptionsSerializer("RankOptions")), - BuiltinOperator.ELU: (Op.Elu, None), + BuiltinOperator.CEIL: (Op.Ceil, None, TFLITE_NO_INDICES), + BuiltinOperator.REVERSE_V2: (Op.ReverseV2, OptionsSerializer("ReverseV2Options"), TFLITE_NO_INDICES), + BuiltinOperator.ADD_N: (Op.AddN, OptionsSerializer("AddNOptions"), TFLITE_NO_INDICES), + BuiltinOperator.GATHER_ND: (Op.GatherNd, OptionsSerializer("GatherNdOptions"), TFLITE_NO_INDICES), + BuiltinOperator.COS: (Op.Cos, OptionsSerializer("CosOptions"), TFLITE_NO_INDICES), + BuiltinOperator.WHERE: (Op.Where, OptionsSerializer("WhereOptions"), TFLITE_NO_INDICES), + BuiltinOperator.RANK: (Op.Rank, OptionsSerializer("RankOptions"), TFLITE_NO_INDICES), + BuiltinOperator.ELU: (Op.Elu, None, TFLITE_NO_INDICES), BuiltinOperator.REVERSE_SEQUENCE: ( Op.ReverseSequence, OptionsSerializer("ReverseSequenceOptions", ("seq_dim", "batch_dim")), + TFLITE_NO_INDICES, + ), + BuiltinOperator.MATRIX_DIAG: (Op.MatrixDiag, OptionsSerializer("MatrixDiagOptions"), TFLITE_NO_INDICES), + BuiltinOperator.QUANTIZE: (Op.Quantize, OptionsSerializer("QuantizeOptions"), TFLITE_IFM_INDICES), + BuiltinOperator.MATRIX_SET_DIAG: (Op.MatrixSetDiag, OptionsSerializer("MatrixSetDiagOptions"), TFLITE_NO_INDICES), + BuiltinOperator.ROUND: (Op.Round, None, TFLITE_NO_INDICES), + BuiltinOperator.HARD_SWISH: (Op.HardSwish, OptionsSerializer("HardSwishOptions"), TFLITE_IFM_INDICES), + BuiltinOperator.IF: ( + Op.If, + OptionsSerializer("IfOptions", ("then_subgraph_index", "else_subgraph_index")), + TFLITE_NO_INDICES, ), - BuiltinOperator.MATRIX_DIAG: (Op.MatrixDiag, OptionsSerializer("MatrixDiagOptions")), - BuiltinOperator.QUANTIZE: (Op.Quantize, OptionsSerializer("QuantizeOptions")), - BuiltinOperator.MATRIX_SET_DIAG: (Op.MatrixSetDiag, OptionsSerializer("MatrixSetDiagOptions")), - BuiltinOperator.ROUND: (Op.Round, None), - BuiltinOperator.HARD_SWISH: (Op.HardSwish, OptionsSerializer("HardSwishOptions")), - BuiltinOperator.IF: (Op.If, OptionsSerializer("IfOptions", ("then_subgraph_index", "else_subgraph_index"))), BuiltinOperator.WHILE: ( Op.While, OptionsSerializer("WhileOptions", ("cond_subgraph_index", "body_subgraph_index")), + TFLITE_NO_INDICES, + ), + BuiltinOperator.NON_MAX_SUPPRESSION_V4: ( + Op.NonMaxSuppressionV4, + OptionsSerializer("NonMaxSuppressionV4Options"), + TFLITE_NO_INDICES, + ), + BuiltinOperator.NON_MAX_SUPPRESSION_V5: ( + Op.NonMaxSuppressionV5, + OptionsSerializer("NonMaxSuppressionV5Options"), + TFLITE_NO_INDICES, + ), + BuiltinOperator.SCATTER_ND: (Op.ScatterNd, OptionsSerializer("ScatterNdOptions"), TFLITE_NO_INDICES), + BuiltinOperator.SELECT_V2: (Op.SelectV2, OptionsSerializer("SelectV2Options"), TFLITE_NO_INDICES), + BuiltinOperator.DENSIFY: (Op.Densify, OptionsSerializer("DensifyOptions"), TFLITE_NO_INDICES), + BuiltinOperator.SEGMENT_SUM: (Op.SegmentSum, OptionsSerializer("SegmentSumOptions"), TFLITE_NO_INDICES), + BuiltinOperator.BATCH_MATMUL: ( + Op.BatchMatMul, + OptionsSerializer("BatchMatMulOptions", ("adj_x", "adj_y")), + TFLITE_NO_INDICES, + ), + BuiltinOperator.CUMSUM: ( + Op.Cumsum, + OptionsSerializer("CumsumOptions", ("exclusive", "reverse")), + TFLITE_NO_INDICES, ), - BuiltinOperator.NON_MAX_SUPPRESSION_V4: (Op.NonMaxSuppressionV4, OptionsSerializer("NonMaxSuppressionV4Options")), - BuiltinOperator.NON_MAX_SUPPRESSION_V5: (Op.NonMaxSuppressionV5, OptionsSerializer("NonMaxSuppressionV5Options")), - BuiltinOperator.SCATTER_ND: (Op.ScatterNd, OptionsSerializer("ScatterNdOptions")), - BuiltinOperator.SELECT_V2: (Op.SelectV2, OptionsSerializer("SelectV2Options")), - BuiltinOperator.DENSIFY: (Op.Densify, OptionsSerializer("DensifyOptions")), - BuiltinOperator.SEGMENT_SUM: (Op.SegmentSum, OptionsSerializer("SegmentSumOptions")), - BuiltinOperator.BATCH_MATMUL: (Op.BatchMatMul, OptionsSerializer("BatchMatMulOptions", ("adj_x", "adj_y"))), - BuiltinOperator.CUMSUM: (Op.Cumsum, OptionsSerializer("CumsumOptions", ("exclusive", "reverse"))), - BuiltinOperator.CUSTOM: (Op.Custom, CustomOptionsSerializer()), + BuiltinOperator.CUSTOM: (Op.Custom, CustomOptionsSerializer(), TFLITE_NO_INDICES), } -builtin_operator_inv_map = {v[0]: (k, v[1]) for k, v in builtin_operator_map.items()} +builtin_operator_inv_map = {v[0]: (k, v[1], v[2]) for k, v in builtin_operator_map.items()} -builtin_operator_inv_map[Op.CustomNpuOp] = (BuiltinOperator.CUSTOM, CustomOptionsSerializer()) +builtin_operator_inv_map[Op.CustomNpuOp] = (BuiltinOperator.CUSTOM, CustomOptionsSerializer(), TFLITE_NO_INDICES) BUILTIN_OPERATOR_UNKNOWN = "UNKNOWN" diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index 1a45a5ee..30bf32af 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -27,6 +27,7 @@ from .nn_graph import Subgraph from .operation import create_activation_function from .operation import Op from .operation import Operation +from .reader_util import align_tensor_indices_to_nng from .reader_util import clone_and_reshape_tensor from .reader_util import decode_str from .reader_util import fixup_tensors @@ -112,7 +113,7 @@ class TFLiteSubgraph: return tens def parse_operator(self, op_index, op_data): - op_type, opt_serializer, custom_code = self.graph.operator_codes[op_data.OpcodeIndex()] + op_type, opt_serializer, custom_code, indices = self.graph.operator_codes[op_data.OpcodeIndex()] inputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.InputsAsNumpy()] outputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.OutputsAsNumpy()] intermediates = [] @@ -122,6 +123,7 @@ class TFLiteSubgraph: name = "unknown_op_name" if len(outputs): name = outputs[0].name + inputs = align_tensor_indices_to_nng(op_type, indices, inputs) op = Operation(op_type, name) op.op_index = op_index op.inputs = inputs @@ -263,11 +265,11 @@ class TFLiteGraph: raise InputFileError( self.name, f"The input file contains operator code '{c}' which is currently not supported" ) - op_type, ser = builtin_operator_map[c] + op_type, ser, indices = builtin_operator_map[c] custom_code = None if c == BuiltinOperator.CUSTOM: custom_code = decode_str(code.CustomCode()) - return op_type, ser, custom_code + return op_type, ser, custom_code, indices def read_tflite(filename, batch_size, feed_dict, output_node_names, initialisation_nodes): diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index 8cabb0ac..3701893e 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -24,6 +24,7 @@ from flatbuffers.builder import UOffsetTFlags from .errors import VelaError from .nn_graph import PassPlacement from .operation import Op +from .reader_util import align_inputs_indices from .tensor import MemType from .tensor import TensorPurpose from .tflite import Buffer @@ -38,7 +39,6 @@ from .tflite_mapping import builtin_operator_inv_map from .tflite_mapping import BuiltinOperator from .tflite_mapping import datatype_inv_map - # ugh, the python flatbuffer interface is missing a method to add in file identifier. patching it in here: tflite_version = 3 @@ -90,6 +90,8 @@ class TFLiteSerialiser: for ps in sg.passes: for op in ps.ops: if op.type not in self.ops_to_ignore: + # swap from nng input indexing to TensorFlow Lite input indexing + self.align_nng_inputs_to_tflite(op) all_ops.append(op) if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op(): # If values are None op has non-constant weights @@ -104,6 +106,11 @@ class TFLiteSerialiser: self.operator_codes = sorted(set((op.type, op.attrs.get("custom_code", "")) for op in all_ops)) self.operator_code_map = {} + def align_nng_inputs_to_tflite(self, op): + from_indices = op.type.info.indices + _, _, to_indices = builtin_operator_inv_map[op.type] + op.inputs = align_inputs_indices(from_indices, to_indices, op.inputs) + def write_byte_vector(self, v, alignment=1): builder = self.builder builder.StartVector(1, len(v), alignment) @@ -170,13 +177,13 @@ class TFLiteSerialiser: builder = self.builder custom_code_offset = None if op_type == Op.Custom: - tf_code, opt_serializer = builtin_operator_inv_map[op_type] + tf_code, opt_serializer, _ = builtin_operator_inv_map[op_type] custom_code_offset = builder.CreateString(custom_code) else: assert ( op_type in builtin_operator_inv_map ), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(op_type) - tf_code, opt_serializer = builtin_operator_inv_map[op_type] + tf_code, opt_serializer, _ = builtin_operator_inv_map[op_type] if op_type == Op.CustomNpuOp: assert ( diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index 94e6f999..fe18ce35 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -60,7 +60,7 @@ def add_padding_fields(op, arch, nng): def rewrite_activation(op, arch, nng): - if not op.type.is_relu_op(): + if op.type not in (Op.ReluN, Op.Clamp): return op ifm = op.ifm @@ -82,7 +82,7 @@ def rewrite_activation(op, arch, nng): if op.ofm.quantization.zero_point is None: op.ofm.quantization.zero_point = zp - if op.type == Op.Clip: + if op.type == Op.Clamp: op.attrs["min"] = op.attrs["min_int"] - zp op.attrs["max"] = op.attrs["max_int"] - zp elif op.type == Op.ReluN: diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py index 82f61f7c..312ac92e 100644 --- a/ethosu/vela/tosa_mapping.py +++ b/ethosu/vela/tosa_mapping.py @@ -249,7 +249,7 @@ tosa_operator_map = { # TODO TosaOp.MATMUL: TosaOp.MAX_POOL2D: (Op.MaxPool, pool2d_attrs, None, TOSA_IFM_INDICES), # TODO TosaOp.TRANSPOSE_CONV2D: (Op.Conv2DBackpropInput, transpose_conv2d_attrs, conv_quant_info) - TosaOp.CLAMP: (Op.Clip, clamp_attrs, None, TOSA_IFM_INDICES), + TosaOp.CLAMP: (Op.Clamp, clamp_attrs, None, TOSA_IFM_INDICES), TosaOp.RELUN: (Op.ReluN, relun_attrs, None, TOSA_IFM_INDICES), # TODO TosaOp.SIGMOID # TODO TosaOp.TANH diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py index ac0b3969..e51ead1d 100644 --- a/ethosu/vela/tosa_reader.py +++ b/ethosu/vela/tosa_reader.py @@ -25,6 +25,7 @@ from .nn_graph import Graph from .nn_graph import Subgraph from .operation import Op from .operation import Operation +from .reader_util import align_tensor_indices_to_nng from .reader_util import clone_and_reshape_tensor from .reader_util import decode_str from .reader_util import fixup_tensors @@ -104,8 +105,8 @@ class TosaSubgraph: name = "unknown_op_name" if len(outputs): name = outputs[0].name + inputs = align_tensor_indices_to_nng(op_type, indices, inputs) op = Operation(op_type, name) - op.type.info.indices = indices op.op_index = op_index op.inputs = inputs op.outputs = outputs diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py index c87d653a..51f80ebd 100644 --- a/ethosu/vela/tosa_supported_operators.py +++ b/ethosu/vela/tosa_supported_operators.py @@ -32,7 +32,7 @@ class TosaSupportedOperators: mac_main_ops = convolution_like_ops type_conversion_ops = set((Op.Rescale,)) - relu_ops = set((Op.Clip, Op.ReluN,)) + relu_ops = set((Op.Clamp, Op.ReluN,)) activation_ops = relu_ops npu_post_ops = activation_ops -- cgit v1.2.1