# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# TOSA mapping functions used by reader.
# Contains a mapping from the various TOSA enums and options structs, generated by the FlatBuffer code
# generator, to Vela's internal format.
import numpy as np

from .data_type import DataType
from .operation import Op
from .operation import TensorIndices
from .tosa import ArithmeticRightShiftAttribute  # noqa: F401
from .tosa import AxisAttribute  # noqa: F401
from .tosa import ClampAttribute  # noqa: F401
from .tosa import CondIfAttribute  # noqa: F401
from .tosa import ConvAttribute  # noqa: F401
from .tosa import FullyConnectedAttribute  # noqa: F401
from .tosa import MulAttribute  # noqa: F401
from .tosa import PoolAttribute  # noqa: F401
from .tosa import RescaleAttribute  # noqa: F401
from .tosa import ReshapeAttribute  # noqa: F401
from .tosa import ResizeAttribute  # noqa: F401
from .tosa import SliceAttribute  # noqa: F401
from .tosa import TileAttribute  # noqa: F401
from .tosa import TransposeAttribute  # noqa: F401
from .tosa import TransposeConvAttribute  # noqa: F401
from .tosa import WhileLoopAttribute  # noqa: F401
from .tosa.DType import DType
from .tosa.Op import Op as TosaOp

# TOSA data type enum -> Vela internal data type
datatype_map = {
    DType.BOOL: DataType.bool,
    DType.UINT8: DataType.uint8,
    DType.INT4: DataType.int4,
    DType.INT8: DataType.int8,
    DType.INT16: DataType.int16,
    DType.INT32: DataType.int32,
    DType.INT48: DataType.int48,
    DType.FP32: DataType.float32,
}

# TOSA data type enum -> numpy scalar type.
# Note: there are no entries for DType.INT4 and DType.INT48 here.
datatype_map_numpy = {
    DType.BOOL: bool,
    DType.UINT8: np.uint8,
    DType.INT8: np.int8,
    DType.INT16: np.int16,
    DType.INT32: np.int32,
    DType.FP32: np.float32,
}


# TODO duplicate of tflite_mapping
def underscore_to_camel_case(s):
    """Convert an underscore separated name to CamelCase, e.g. "input_zp" -> "InputZp"."""
    return "".join(x.title() for x in s.split("_"))


# TODO duplicate of tflite_mapping
def identity(x):
    """Return the argument unchanged; used as the deserializer for scalar attribute members."""
    return x


class AttrSerializer:
    """Reads the members of one generated TOSA attribute table into a plain dict.

    name is the name of the attribute class, which must also be the name of a module imported into this
    file that contains a class of the same name.

    members is an iterable of member specifications. Each one is either the underscore separated member
    name as a string, or a (name, is_vector) pair. Members given as a pair are converted with tuple();
    when is_vector is true they are also read through the generated "<Name>AsNumpy" accessor.
    """

    def __init__(self, name, members=None):
        self.name = name
        # The generated module and the class it contains share the same name
        self.module = globals()[self.name]
        self.cls = getattr(self.module, self.name)
        # List of (underscore name, CamelCase accessor name, deserialize function, is_vector)
        self.members = []
        if members is not None:
            for mem in members:
                deserialize = identity
                is_vector = False
                if isinstance(mem, tuple):
                    if len(mem) != 2:
                        # Raise rather than assert so that the check is not stripped by "python -O"
                        raise ValueError(
                            "{0}: invalid member specification {1!r}, expected a name or a"
                            " (name, is_vector) pair".format(name, mem)
                        )
                    mem, is_vector = mem
                    deserialize = tuple

                underscore_mem = mem
                camelcase_mem = underscore_to_camel_case(mem)
                self.members.append((underscore_mem, camelcase_mem, deserialize, is_vector))

    def deserialize(self, op_data):
        """Return a dict of underscore member name -> value read from the attribute of op_data.

        An empty dict is returned when op_data.AttributeType() is falsy. A member whose value makes the
        deserialize function raise TypeError is reported with a printed warning and left out of the dict.
        """
        attr_type = op_data.AttributeType()
        attr_table = op_data.Attribute()
        attrs = {}
        if not attr_type:
            return attrs

        tosa_attrs = self.cls()
        tosa_attrs.Init(attr_table.Bytes, attr_table.Pos)
        for underscore_mem, camelcase_mem, deserialize, is_vector in self.members:
            accessor = (camelcase_mem + "AsNumpy") if is_vector else camelcase_mem
            value = getattr(tosa_attrs, accessor)()
            try:
                attrs[underscore_mem] = deserialize(value)
            except TypeError:
                # Best effort: presumably the accessor returned something tuple() cannot iterate
                print("Warning: {0} could not read attribute '{1}'.".format(self.name, underscore_mem))

        return attrs


# Marks a member as a vector in the (name, is_vector) specifications below
is_vec = True

pool_attrs = AttrSerializer(
    "PoolAttribute",
    (
        ("pad", is_vec),
        ("kernel", is_vec),
        ("stride", is_vec),
        "input_zp",
        "output_zp",
    ),
)
conv_attrs = AttrSerializer(
    "ConvAttribute",
    (
        ("pad", is_vec),
        ("stride", is_vec),
        ("dilation", is_vec),
        "input_zp",
        "weight_zp",
    ),
)
fc_attrs = AttrSerializer("FullyConnectedAttribute", ("input_zp", "weight_zp"))
transpose_conv_attrs = AttrSerializer(
    "TransposeConvAttribute",
    (
        ("outpad", is_vec),
        ("stride", is_vec),
        ("dilation", is_vec),
        ("out_shape", is_vec),
    ),
)
transpose_attrs = AttrSerializer("TransposeAttribute", (("perms", is_vec),))
axis_attrs = AttrSerializer("AxisAttribute", ("axis",))
reshape_attrs = AttrSerializer("ReshapeAttribute", (("shape", is_vec),))
slice_attrs = AttrSerializer("SliceAttribute", (("start", is_vec), ("size", is_vec)))
tile_attrs = AttrSerializer("TileAttribute", (("multiplies", is_vec),))
resize_attrs = AttrSerializer(
    "ResizeAttribute",
    (("output_size", is_vec), ("stride", is_vec), ("offset", is_vec), "shift"),
)
clamp_attrs = AttrSerializer("ClampAttribute", ("min_int", "max_int"))
rescale_attrs = AttrSerializer(
    "RescaleAttribute",
    (
        "input_zp",
        "output_zp",
        ("multiplier", is_vec),
        ("shift", is_vec),
        "scale32",
        "double_round",
        "per_channel",
    ),
)
mul_attrs = AttrSerializer("MulAttribute", ("shift",))
ars_attrs = AttrSerializer("ArithmeticRightShiftAttribute", ("round",))
condif_attrs = AttrSerializer("CondIfAttribute", ("then_branch", "else_branch"))  # TODO these are references
while_attrs = AttrSerializer("WhileLoopAttribute", ("cond_branch", "body_branch"))  # TODO these are references

# TOSA operators that have no entry in tosa_operator_map
unsupported_tosa_operators = {
    TosaOp.UNKNOWN,
    TosaOp.ARGMAX,
    TosaOp.CONV3D,
    TosaOp.MATMUL,
    TosaOp.TRANSPOSE_CONV2D,
    TosaOp.SIGMOID,
    TosaOp.TANH,
    TosaOp.BITWISE_AND,
    TosaOp.BITWISE_OR,
    TosaOp.BITWISE_XOR,
    TosaOp.INTDIV,
    TosaOp.LOGICAL_AND,
    TosaOp.LOGICAL_LEFT_SHIFT,
    TosaOp.LOGICAL_RIGHT_SHIFT,
    TosaOp.LOGICAL_OR,
    TosaOp.LOGICAL_XOR,
    TosaOp.MAXIMUM,
    TosaOp.MINIMUM,
    TosaOp.POW,
    TosaOp.ABS,
    TosaOp.BITWISE_NOT,
    TosaOp.CEIL,
    TosaOp.CLZ,
    TosaOp.EXP,
    TosaOp.FLOOR,
    TosaOp.LOG,
    TosaOp.LOGICAL_NOT,
    TosaOp.NEGATE,
    TosaOp.RECIPROCAL,
    TosaOp.RSQRT,
    TosaOp.SELECT,
    TosaOp.EQUAL,
    TosaOp.GREATER,
    TosaOp.GREATER_EQUAL,
    TosaOp.REDUCE_ANY,
    TosaOp.REDUCE_ALL,
    TosaOp.REDUCE_MAX,
    TosaOp.REDUCE_MIN,
    TosaOp.REDUCE_PRODUCT,
    TosaOp.REDUCE_SUM,
    TosaOp.REVERSE,
    TosaOp.TILE,
    TosaOp.GATHER,
    TosaOp.SCATTER,
    TosaOp.RESIZE,
    TosaOp.CAST,
    TosaOp.CUSTOM,
    TosaOp.COND_IF,
    TosaOp.WHILE_LOOP,
}


TOSA_NO_INDICES = TensorIndices([], [], [])
TOSA_IFM_INDICES = TensorIndices([0], [], [])
# TOSA_IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
TOSA_IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
TOSA_IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
# TOSA_CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
# TOSA_TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
TOSA_CONCAT_INDICES = TensorIndices([1, 2], [], [])
# TOSA_SPLIT_IFM_INDICES = TensorIndices([1], [], [])
# TOSA_BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])

# TOSA operator -> (Vela Op, attribute serializer or None, None, tensor indices)
tosa_operator_map = {
    # TosaOp.UNKNOWN: (),
    # TODO TosaOp.ARGMAX: (Op.ArgMax, axis_attrs, None),
    TosaOp.AVG_POOL2D: (Op.AvgPool, pool_attrs, None, TOSA_IFM_INDICES),
    TosaOp.CONV2D: (Op.Conv2DBias, conv_attrs, None, TOSA_IFM_WEIGHTS_BIAS_INDICES),
    # TODO TosaOp.CONV3D:
    TosaOp.DEPTHWISE_CONV2D: (
        Op.DepthwiseConv2DBias,
        conv_attrs,
        None,
        TOSA_IFM_WEIGHTS_BIAS_INDICES,
    ),
    TosaOp.FULLY_CONNECTED: (
        Op.FullyConnected,
        fc_attrs,
        None,
        TOSA_IFM_WEIGHTS_BIAS_INDICES,
    ),
    # TODO TosaOp.MATMUL:
    TosaOp.MAX_POOL2D: (Op.MaxPool, pool_attrs, None, TOSA_IFM_INDICES),
    # TODO TosaOp.TRANSPOSE_CONV2D: (Op.Conv2DBackpropInput, transpose_conv_attrs, None)
    TosaOp.CLAMP: (Op.Clamp, clamp_attrs, None, TOSA_IFM_INDICES),
    # TODO: BUG: No longer a relu - presumably a clamp -
    # TosaOp.RELUN: (Op.ReluN, relun_attrs, None, TOSA_IFM_INDICES),
    # TODO TosaOp.SIGMOID
    # TODO TosaOp.TANH
    TosaOp.ADD: (Op.Add, None, None, TOSA_IFM_IFM2_INDICES),
    TosaOp.ARITHMETIC_RIGHT_SHIFT: (Op.SHR, ars_attrs, None, TOSA_IFM_IFM2_INDICES),
    # TODO TosaOp.BITWISE_AND
    # TODO TosaOp.BITWISE_OR
    # TODO TosaOp.BITWISE_XOR
    # TODO TosaOp.INTDIV
    # TODO TosaOp.LOGICAL_AND
    # TODO TosaOp.LOGICAL_LEFT_SHIFT
    # TODO TosaOp.LOGICAL_RIGHT_SHIFT
    # TODO TosaOp.LOGICAL_OR
    # TODO TosaOp.LOGICAL_XOR
    # TODO TosaOp.MAXIMUM
    # TODO TosaOp.MINIMUM
    TosaOp.MUL: (Op.Mul, mul_attrs, None, TOSA_IFM_IFM2_INDICES),
    # TODO TosaOp.POW
    TosaOp.SUB: (Op.Sub, None, None, TOSA_IFM_IFM2_INDICES),
    # TODO is table content in input[1] always constant?
    TosaOp.TABLE: (Op.Table, None, None, TOSA_IFM_INDICES),
    # TODO TosaOp.ABS
    # TODO TosaOp.BITWISE_NOT
    # TODO TosaOp.CEIL
    # TODO TosaOp.CLZ
    # TODO TosaOp.EXP
    # TODO TosaOp.FLOOR
    # TODO TosaOp.LOG
    # TODO TosaOp.LOGICAL_NOT
    # TODO TosaOp.NEGATE
    # TODO TosaOp.RECIPROCAL
    # TODO TosaOp.RSQRT
    # TODO TosaOp.SELECT
    # TODO TosaOp.EQUAL
    # TODO TosaOp.GREATER
    # TODO TosaOp.GREATER_EQUAL
    # TODO TosaOp.REDUCE_ANY
    # TODO TosaOp.REDUCE_ALL
    # TODO TosaOp.REDUCE_MAX
    # TODO TosaOp.REDUCE_MIN
    # TODO TosaOp.REDUCE_PRODUCT
    # TODO TosaOp.REDUCE_SUM
    TosaOp.CONCAT: (Op.Concat, axis_attrs, None, TOSA_CONCAT_INDICES),
    # TODO Is the padding intended to be dynamic input, TOSA spec state it as attribute
    # Handled as for TFLite for now
    TosaOp.PAD: (Op.Pad, None, None, TOSA_IFM_INDICES),
    TosaOp.RESHAPE: (Op.Reshape, reshape_attrs, None, TOSA_IFM_INDICES),
    # TODO TosaOp.REVERSE
    TosaOp.SLICE: (Op.SplitSliceRead, slice_attrs, None, TOSA_IFM_INDICES),
    # TODO TosaOp.TILE
    TosaOp.TRANSPOSE: (
        Op.Transpose,
        transpose_attrs,
        None,
        # TODO: why is this IFM2 indices but then overridden to TOSA_IFM_INDICES in _reader?
        # TOSA_IFM_IFM2_INDICES,
        TOSA_IFM_IFM2_INDICES,
    ),
    # TODO Is the perms intended to be dynamic input, TOSA spec state it as attribute
    # TODO TosaOp.GATHER
    # TODO TosaOp.SCATTER
    # TODO TosaOp.RESIZE
    # TODO TosaOp.CAST
    TosaOp.RESCALE: (Op.Rescale, rescale_attrs, None, TOSA_IFM_INDICES),
    TosaOp.CONST: (Op.Const, None, None, TOSA_NO_INDICES),
    TosaOp.IDENTITY: (Op.Identity, None, None, TOSA_IFM_INDICES),
    # TODO TosaOp.CUSTOM
    # TODO TosaOp.COND_IF
    # TODO TosaOp.WHILE_LOOP
}

# Vela Op -> (TOSA operator, attribute serializer or None)
tosa_operator_inv_map = {v[0]: (k, v[1]) for k, v in tosa_operator_map.items()}

# TOSA operator value -> TOSA operator name. Underscore-prefixed class attributes such as __module__ and
# __doc__ are skipped so that the map only holds the operator constants.
tosa_operator_name_map = {v: k for k, v in vars(TosaOp).items() if not k.startswith("_")}


# TODO will return UNKNOWN for the ones that have not yet been defined in tosa_operator_map
def optype_to_tosa_op_type(op_type: Op):
    """Return the TOSA operator name for a Vela Op, or TosaOp.UNKNOWN when the Op is not mapped.

    NOTE(review): the mapped case returns the operator name from tosa_operator_name_map while the
    fallback returns the TosaOp.UNKNOWN enum value itself, so the two cases differ in type. Left
    unchanged because callers are not visible from this file - confirm before unifying.
    """
    entry = tosa_operator_inv_map.get(op_type)
    if entry is None:
        return TosaOp.UNKNOWN
    return tosa_operator_name_map[entry[0]]