From 8f1f9aaa58175b17cd2e505bfcdb0e40c955ea72 Mon Sep 17 00:00:00 2001 From: Patrik Gustavsson Date: Mon, 28 Jun 2021 07:41:58 +0200 Subject: MLBEDSW-4838 Added basic TOSA support. Added basic TOSA support, enabling Vela to read and compile a .tosa file corresponding to CONV2D + Rescale + Clamp, and writing it to an optimized .tflite file. The optimized .tflite file, will in this case, hold a commandstream where the Rescale and Clamp has been fused into the CONV2D. The optimized tflite file is not output from Vela. -Added support to read .tosa file into Vela internal structure. - Added tosa_reader.py, tosa_mapper.py and helper files stored under tosa/ - Support for this limited to ~10 ops -Added reader_util.py for functions common for TOSA and TFLite -Added tosa_graph_optimiser.py -Added support to fuse Rescale into convolution -Modified handling for padding -Added support to fuse Clamp to previous op -Added graph_optimiser_util.py -Moved functions common for TOSA/TFLite graph optimization to this file. -Renamed graph_optimiser.py to tflite_graph_optmiser.py -Added separate tosa_supported_operators.py -Added supported_operator_util.py -For functions in common for TOSA/TFLite Signed-off-by: Patrik Gustavsson Change-Id: Ic3c540504ec8c5eb4771397fdc6882050ecf33ab --- ethosu/vela/tosa_mapping.py | 325 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 325 insertions(+) create mode 100644 ethosu/vela/tosa_mapping.py (limited to 'ethosu/vela/tosa_mapping.py') diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py new file mode 100644 index 00000000..82f61f7c --- /dev/null +++ b/ethosu/vela/tosa_mapping.py @@ -0,0 +1,325 @@ +# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# TOSA mapping functions used by reader.
# Contains a mapping from the various TOSA enums and options structs, generated by the FlatBuffer code
# generator, to Vela's internal format.
from .data_type import DataType
from .operation import Op
from .operation import TensorIndices
from .tosa import ArithmeticRightShiftAttribute  # noqa: F401
from .tosa import AxisAttribute  # noqa: F401
from .tosa import ClampAttribute  # noqa: F401
from .tosa import CondIfAttribute  # noqa: F401
from .tosa import Conv2dAttribute  # noqa: F401
from .tosa import ConvQuantInfo  # noqa: F401
from .tosa import MatMulQuantInfo  # noqa: F401
from .tosa import MulAttribute  # noqa: F401
from .tosa import PadQuantInfo  # noqa: F401
from .tosa import Pool2dAttribute  # noqa: F401
from .tosa import ReluNAttribute  # noqa: F401
from .tosa import RescaleAttribute  # noqa: F401
from .tosa import ReshapeAttribute  # noqa: F401
from .tosa import ResizeAttribute  # noqa: F401
from .tosa import SliceAttribute  # noqa: F401
from .tosa import TileAttribute  # noqa: F401
from .tosa import TransposeConv2dAttribute  # noqa: F401
from .tosa import UnaryQuantInfo  # noqa: F401
from .tosa import WhileLoopAttribute  # noqa: F401
from .tosa.DType import DType
from .tosa.Op import Op as TosaOp


# Lookup table from the TOSA DType enum values (generated FlatBuffer code) to Vela's DataType.
# Only the DType values listed here are mapped; any other value raises KeyError on lookup.
datatype_map = {
    DType.BOOL: DataType.bool,
    DType.UINT8: DataType.uint8,
    DType.INT4: DataType.int4,
    DType.INT8: DataType.int8,
    DType.INT16: DataType.int16,
    DType.INT32: DataType.int32,
    DType.INT48: DataType.int48,
    DType.FLOAT: DataType.float32,
}


# TODO duplicate of
# tflite_mapping  (tail of the "TODO duplicate of" comment that begins on the previous line)
def underscore_to_camel_case(s):
    """Convert a snake_case name to CamelCase, e.g. "input_zp" -> "InputZp"."""
    return "".join(x.title() for x in s.split("_"))


# TODO duplicate of tflite_mapping
def identity(x):
    """Return x unchanged; used as the default, no-op member deserializer."""
    return x


def _normalise_members(members):
    """Return members as a tuple of member specs.

    A bare string is treated as a single member name instead of being iterated character by character,
    which is what happens when a one-element tuple is written without its trailing comma, e.g. ("axis")
    instead of ("axis",). None yields an empty tuple.
    """
    if members is None:
        return ()
    if isinstance(members, str):
        return (members,)
    return tuple(members)


class AttrSerializer:
    """Reads the attribute table of a serialized TOSA operator into a plain dict.

    name is the name of a module imported into this file that contains a class of the same name (the
    modules under .tosa imported at the top of the file follow that layout).

    members lists the attributes to read. Each entry is either a snake_case name (read as-is) or a
    (name, is_vector) pair, whose value is converted to a tuple. For every member the accessor method is
    the CamelCase form of the name, with "AsNumpy" appended when is_vector is true.
    """

    def __init__(self, name, members=None):
        self.name = name
        # The module and the class inside it share the same name, so look up the module among this file's
        # imports and then fetch the class from it.
        self.module = globals()[self.name]
        self.cls = getattr(self.module, self.name)
        # Entries are (underscore_name, camelcase_name, deserialize_fn, is_vector)
        self.members = []
        for mem in _normalise_members(members):
            deserialize = identity
            is_vector = False
            if isinstance(mem, tuple):
                if len(mem) != 2:
                    # Explicit error instead of an assert, so the check survives python -O
                    raise ValueError(
                        "{0}: member spec {1!r} must be a name or a (name, is_vector) pair".format(name, mem)
                    )
                mem, is_vector = mem
                deserialize = tuple
            underscore_mem = mem
            camelcase_mem = underscore_to_camel_case(mem)
            self.members.append((underscore_mem, camelcase_mem, deserialize, is_vector))

    def deserialize(self, op_data):
        """Return {member_name: value} for the attributes attached to op_data.

        Returns an empty dict when op_data.AttributeType() is falsy. A member whose value cannot be
        converted (TypeError) is reported with a warning on stdout and left out of the result.
        """
        attr_type = op_data.AttributeType()
        attr_table = op_data.Attribute()
        attrs = {}
        if attr_type:
            tosa_attrs = self.cls()
            tosa_attrs.Init(attr_table.Bytes, attr_table.Pos)
            for underscore_mem, camelcase_mem, deserialize, is_vector in self.members:
                accessor = (camelcase_mem + "AsNumpy") if is_vector else camelcase_mem
                value = getattr(tosa_attrs, accessor)()
                try:
                    attrs[underscore_mem] = deserialize(value)
                except TypeError:
                    # NOTE(review): presumably hit when a vector member is absent and the accessor returns
                    # a non-iterable placeholder - confirm against the generated code.
                    print("Warning: {0} could not read attribute '{1}'.".format(self.name, underscore_mem))

        return attrs


class QuantSerializer:
    """Reads the quantization info table of a serialized TOSA operator into a plain dict.

    name is the name of a module imported into this file that contains a class of the same name.
    members is a sequence of snake_case member names; each is read through the accessor method named
    after its CamelCase form.
    """

    def __init__(self, name, members=None):
        self.name = name
        self.module = globals()[self.name]
        self.cls = getattr(self.module, self.name)
        # Entries are (underscore_name, camelcase_name, deserialize_fn)
        self.members = []
        for mem in _normalise_members(members):
            deserialize = identity
            underscore_mem = mem
            camelcase_mem = underscore_to_camel_case(mem)
            self.members.append((underscore_mem, camelcase_mem, deserialize))

    def deserialize(self, op_data):
        """Return {member_name: value} for the quantization info attached to op_data.

        Returns an empty dict when op_data.QuantInfoType() is falsy. A member whose value cannot be
        converted (TypeError) is reported with a warning on stdout and left out of the result.
        """
        quant_info_type = op_data.QuantInfoType()
        quant_info = op_data.QuantInfo()
        quant = {}
        if quant_info_type:
            tosa_quant = self.cls()
            tosa_quant.Init(quant_info.Bytes, quant_info.Pos)
            for underscore_mem, camelcase_mem, deserialize in self.members:
                value = getattr(tosa_quant, camelcase_mem)()
                try:
                    quant[underscore_mem] = deserialize(value)
                except TypeError:
                    print("Warning: {0} could not read quant info '{1}'.".format(self.name, underscore_mem))

        return quant


# Marks a member as a vector in the (name, is_vector) member specs below.
is_vec = True
pool2d_attrs = AttrSerializer("Pool2dAttribute", (("padding", is_vec), ("kernel", is_vec), ("stride", is_vec)))
conv2d_attrs = AttrSerializer("Conv2dAttribute", (("padding", is_vec), ("stride", is_vec), ("dilation", is_vec)))
transpose_conv2d_attrs = AttrSerializer(
    "TransposeConv2dAttribute", (("outpad", is_vec), ("stride", is_vec), ("dilation", is_vec), ("out_shape", is_vec))
)
# Single-member specs need the trailing comma: ("max_int") is just the string "max_int".
relun_attrs = AttrSerializer("ReluNAttribute", ("max_int",))
axis_attrs = AttrSerializer("AxisAttribute", ("axis",))
reshape_attrs = AttrSerializer("ReshapeAttribute", (("shape", is_vec),))
slice_attrs = AttrSerializer("SliceAttribute", (("begin", is_vec), ("size", is_vec)))
tile_attrs = AttrSerializer("TileAttribute", (("multiplies", is_vec),))
resize_attrs = AttrSerializer(
    "ResizeAttribute", (("output_size", is_vec), ("stride", is_vec), ("offset", is_vec), "shift")
)
clamp_attrs = AttrSerializer("ClampAttribute", ("min_int", "max_int"))
rescale_attrs = AttrSerializer(
    "RescaleAttribute",
    ("input_zp", "output_zp", ("multiplier", is_vec), ("shift", is_vec), "scale32", "double_round", "per_channel"),
)
mul_attrs = AttrSerializer("MulAttribute", ("shift",))
ars_attrs = AttrSerializer("ArithmeticRightShiftAttribute", ("round",))
condif_attrs = AttrSerializer("CondIfAttribute", ("then_branch", "else_branch"))  # TODO these are references
while_attrs = AttrSerializer("WhileLoopAttribute", ("cond_branch", "body_branch"))  # TODO these are references

unary_quant_info = QuantSerializer("UnaryQuantInfo", ("input_zp", "output_zp"))
conv_quant_info = QuantSerializer("ConvQuantInfo", ("input_zp", "weight_zp"))
matmul_quant_info = QuantSerializer("MatMulQuantInfo", ("a_zp", "b_zp"))
pad_quant_info = QuantSerializer("PadQuantInfo", ("input_zp",))

# TOSA operators that have no entry in tosa_operator_map below.
unsupported_tosa_operators = {
    TosaOp.UNKNOWN,
    TosaOp.ARGMAX,
    TosaOp.CONV3D,
    TosaOp.MATMUL,
    TosaOp.TRANSPOSE_CONV2D,
    TosaOp.SIGMOID,
    TosaOp.TANH,
    TosaOp.BITWISE_AND,
    TosaOp.BITWISE_OR,
    TosaOp.BITWISE_XOR,
    TosaOp.DIV,
    TosaOp.LOGICAL_AND,
    TosaOp.LOGICAL_LEFT_SHIFT,
    TosaOp.LOGICAL_RIGHT_SHIFT,
    TosaOp.LOGICAL_OR,
    TosaOp.LOGICAL_XOR,
    TosaOp.MAXIMUM,
    TosaOp.MINIMUM,
    TosaOp.MUL,
    TosaOp.POW,
    TosaOp.TABLE,
    TosaOp.ABS,
    TosaOp.BITWISE_NOT,
    TosaOp.CEIL,
    TosaOp.CLZ,
    TosaOp.EXP,
    TosaOp.FLOOR,
    TosaOp.LOG,
    TosaOp.LOGICAL_NOT,
    TosaOp.NEGATE,
    TosaOp.RECIPROCAL,
    TosaOp.RSQRT,
    TosaOp.SELECT,
    TosaOp.EQUAL,
    TosaOp.GREATER,
    TosaOp.GREATER_EQUAL,
    TosaOp.REDUCE_ANY,
    TosaOp.REDUCE_ALL,
    TosaOp.REDUCE_MAX,
    TosaOp.REDUCE_MIN,
    TosaOp.REDUCE_PRODUCT,
    TosaOp.REDUCE_SUM,
    TosaOp.CONCAT,
    TosaOp.PAD,
    TosaOp.RESHAPE,
    TosaOp.REVERSE,
    TosaOp.SLICE,
    TosaOp.TILE,
    TosaOp.TRANSPOSE,
    TosaOp.GATHER,
    TosaOp.SCATTER,
    TosaOp.RESIZE,
    TosaOp.CAST,
    TosaOp.IDENTITY,
    TosaOp.CUSTOM,
    TosaOp.COND_IF,
    TosaOp.WHILE_LOOP,
}


# Positions of an operator's inputs, grouped by role.
# NOTE(review): going by the constant names, the three lists are presumably (ifm, weight, bias) input
# indices - confirm against TensorIndices in operation.py.
TOSA_NO_INDICES = TensorIndices([], [], [])
TOSA_IFM_INDICES = TensorIndices([0], [], [])
# TOSA_IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
TOSA_IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
TOSA_IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
# TOSA_CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
# TOSA_TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
# TOSA_CONCAT_INDICES = TensorIndices([1, 2], [], [])
# TOSA_SPLIT_IFM_INDICES = TensorIndices([1], [], [])
# TOSA_BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])


# Maps each handled TOSA operator to a 4-tuple:
#   (Vela Op, attribute serializer or None, quant info serializer or None, TensorIndices)
# Operators still marked TODO are the ones listed in unsupported_tosa_operators.
tosa_operator_map = {
    # TosaOp.UNKNOWN: (),
    # TODO TosaOp.ARGMAX: (Op.ArgMax, axis_attrs, None),
    TosaOp.AVG_POOL2D: (Op.AvgPool, pool2d_attrs, unary_quant_info, TOSA_IFM_INDICES),
    TosaOp.CONV2D: (Op.Conv2DBias, conv2d_attrs, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
    # TODO TosaOp.CONV3D:
    TosaOp.DEPTHWISE_CONV2D: (Op.DepthwiseConv2DBias, conv2d_attrs, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
    TosaOp.FULLY_CONNECTED: (Op.FullyConnected, None, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
    # TODO TosaOp.MATMUL:
    TosaOp.MAX_POOL2D: (Op.MaxPool, pool2d_attrs, None, TOSA_IFM_INDICES),
    # TODO TosaOp.TRANSPOSE_CONV2D: (Op.Conv2DBackpropInput, transpose_conv2d_attrs, conv_quant_info)
    TosaOp.CLAMP: (Op.Clip, clamp_attrs, None, TOSA_IFM_INDICES),
    TosaOp.RELUN: (Op.ReluN, relun_attrs, None, TOSA_IFM_INDICES),
    # TODO TosaOp.SIGMOID
    # TODO TosaOp.TANH
    TosaOp.ADD: (Op.Add, None, None, TOSA_IFM_IFM2_INDICES),
    TosaOp.ARITHMETIC_RIGHT_SHIFT: (Op.SHR, ars_attrs, None, TOSA_IFM_IFM2_INDICES),
    # TODO TosaOp.BITWISE_AND
    # TODO TosaOp.BITWISE_OR
    # TODO TosaOp.BITWISE_XOR
    # TODO TosaOp.DIV
    # TODO TosaOp.LOGICAL_AND
    # TODO TosaOp.LOGICAL_LEFT_SHIFT
    # TODO TosaOp.LOGICAL_RIGHT_SHIFT
    # TODO TosaOp.LOGICAL_OR
    # TODO TosaOp.LOGICAL_XOR
    # TODO TosaOp.MAXIMUM
    # TODO TosaOp.MINIMUM
    # TODO TosaOp.MUL
    # TODO TosaOp.POW
    TosaOp.SUB: (Op.Sub, None, None, TOSA_IFM_IFM2_INDICES),
    # TODO TosaOp.TABLE
    # TODO TosaOp.ABS
    # TODO TosaOp.BITWISE_NOT
    # TODO TosaOp.CEIL
    # TODO TosaOp.CLZ
    # TODO TosaOp.EXP
    # TODO TosaOp.FLOOR
    # TODO TosaOp.LOG
    # TODO TosaOp.LOGICAL_NOT
    # TODO TosaOp.NEGATE
    # TODO TosaOp.RECIPROCAL
    # TODO TosaOp.RSQRT
    # TODO TosaOp.SELECT
    # TODO TosaOp.EQUAL
    # TODO TosaOp.GREATER
    # TODO TosaOp.GREATER_EQUAL
    # TODO TosaOp.REDUCE_ANY
    # TODO TosaOp.REDUCE_ALL
    # TODO TosaOp.REDUCE_MAX
    # TODO TosaOp.REDUCE_MIN
    # TODO TosaOp.REDUCE_PRODUCT
    # TODO TosaOp.REDUCE_SUM
    # TODO TosaOp.CONCAT
    # TODO TosaOp.PAD
    # TODO TosaOp.RESHAPE
    # TODO TosaOp.REVERSE
    # TODO TosaOp.SLICE
    # TODO TosaOp.TILE
    # TODO TosaOp.TRANSPOSE
    # TODO TosaOp.GATHER
    # TODO TosaOp.SCATTER
    # TODO TosaOp.RESIZE
    # TODO TosaOp.CAST
    TosaOp.RESCALE: (Op.Rescale, rescale_attrs, None, TOSA_IFM_INDICES),
    TosaOp.CONST: (Op.Const, None, None, TOSA_NO_INDICES),
    # TODO TosaOp.IDENTITY
    # TODO TosaOp.CUSTOM
    # TODO TosaOp.COND_IF
    # TODO TosaOp.WHILE_LOOP
}

# Reverse lookup: Vela Op -> (TosaOp value, attribute serializer or None)
tosa_operator_inv_map = {v[0]: (k, v[1]) for k, v in tosa_operator_map.items()}


def tosa_type_name(builtin):
    # Returns the name of the first attribute of the TosaOp class whose value equals builtin.
    # next() is called without a default, so StopIteration is raised when nothing matches.
    return next(k for k, v in vars(TosaOp).items() if v == builtin)


# TODO will return UNKNOWN for the ones that have not yet been defined in tosa_operator_map
def optype_to_tosa_op_type(op_type):
    # Note the two branches return different kinds of value: the TOSA operator's name (a string) when
    # op_type is in tosa_operator_inv_map, but the TosaOp.UNKNOWN value itself (not its name) otherwise.
    if op_type in tosa_operator_inv_map:
        return tosa_type_name(tosa_operator_inv_map[op_type][0])
    else:
        return TosaOp.UNKNOWN


# -- cgit v1.2.1  (patch trailer emitted by cgit; not part of the module)