diff options
Diffstat (limited to 'ethosu/vela/tosa_mapping.py')
-rw-r--r-- | ethosu/vela/tosa_mapping.py | 124 |
1 files changed, 66 insertions, 58 deletions
diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py index 97fdd20..2dafd81 100644 --- a/ethosu/vela/tosa_mapping.py +++ b/ethosu/vela/tosa_mapping.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com> # # SPDX-License-Identifier: Apache-2.0 # @@ -28,19 +28,16 @@ from .tosa import AxisAttribute # noqa: F401 from .tosa import ClampAttribute # noqa: F401 from .tosa import CondIfAttribute # noqa: F401 from .tosa import ConvAttribute # noqa: F401 -from .tosa import ConvQuantInfo # noqa: F401 -from .tosa import MatMulQuantInfo # noqa: F401 +from .tosa import FullyConnectedAttribute # noqa: F401 from .tosa import MulAttribute # noqa: F401 -from .tosa import PadQuantInfo # noqa: F401 from .tosa import PoolAttribute # noqa: F401 -from .tosa import ReluNAttribute # noqa: F401 from .tosa import RescaleAttribute # noqa: F401 from .tosa import ReshapeAttribute # noqa: F401 from .tosa import ResizeAttribute # noqa: F401 from .tosa import SliceAttribute # noqa: F401 from .tosa import TileAttribute # noqa: F401 +from .tosa import TransposeAttribute # noqa: F401 from .tosa import TransposeConvAttribute # noqa: F401 -from .tosa import UnaryQuantInfo # noqa: F401 from .tosa import WhileLoopAttribute # noqa: F401 from .tosa.DType import DType from .tosa.Op import Op as TosaOp @@ -54,7 +51,7 @@ datatype_map = { DType.INT16: DataType.int16, DType.INT32: DataType.int32, DType.INT48: DataType.int48, - DType.FLOAT: DataType.float32, + DType.FP32: DataType.float32, } datatype_map_numpy = { @@ -63,7 +60,7 @@ datatype_map_numpy = { DType.INT8: np.int8, DType.INT16: np.int16, DType.INT32: np.int32, - DType.FLOAT: np.float32, + DType.FP32: np.float32, } @@ -118,65 +115,64 @@ class AttrSerializer: return attrs -class QuantSerializer: - def __init__(self, name, members=None): - self.name = name - self.module = globals()[self.name] - self.cls = getattr(self.module, self.name) - self.members = [] - if members is not None: - for mem in members: - deserialize = identity - underscore_mem = mem - camelcase_mem = underscore_to_camel_case(mem) - self.members.append((underscore_mem, camelcase_mem, deserialize)) - - def deserialize(self, op_data): - quant_info_type = op_data.QuantInfoType() - quant_info = op_data.QuantInfo() - quant = {} - if quant_info_type: - tosa_quant = self.cls() - tosa_quant.Init(quant_info.Bytes, quant_info.Pos) - for underscore_mem, camelcase_mem, deserialize in self.members: - attr = getattr(tosa_quant, camelcase_mem)() - try: - quant[underscore_mem] = deserialize(attr) - except TypeError: - print("Warning: {0} could not read quant info '{1}'.".format(self.name, underscore_mem)) - - return quant - - is_vec = True -pool_attrs = AttrSerializer("PoolAttribute", (("padding", is_vec), ("kernel", is_vec), ("stride", is_vec))) -conv_attrs = AttrSerializer("ConvAttribute", (("padding", is_vec), ("stride", is_vec), ("dilation", is_vec))) +pool_attrs = AttrSerializer( + "PoolAttribute", + ( + ("pad", is_vec), + ("kernel", is_vec), + ("stride", is_vec), + ("input_zp"), + ("output_zp"), + ), +) +conv_attrs = AttrSerializer( + "ConvAttribute", + ( + ("pad", is_vec), + ("stride", is_vec), + ("dilation", is_vec), + ("input_zp"), + ("weight_zp"), + ), +) +fc_attrs = AttrSerializer("FullyConnectedAttribute", (("input_zp"), ("weight_zp"))) transpose_conv_attrs = AttrSerializer( - "TransposeConvAttribute", (("outpad", is_vec), ("stride", is_vec), ("dilation", is_vec), ("out_shape", is_vec)) + "TransposeConvAttribute", + ( + ("outpad", is_vec), + ("stride", is_vec), + ("dilation", is_vec), + ("out_shape", is_vec), + ), ) -relun_attrs = AttrSerializer("ReluNAttribute", ("max_int")) +transpose_attrs = AttrSerializer("TransposeAttribute", (("perms", is_vec),)) axis_attrs = AttrSerializer("AxisAttribute", ("axis",)) reshape_attrs = AttrSerializer("ReshapeAttribute", (("shape", is_vec),)) -slice_attrs = AttrSerializer("SliceAttribute", (("begin", is_vec), ("size", is_vec))) +slice_attrs = AttrSerializer("SliceAttribute", (("start", is_vec), ("size", is_vec))) tile_attrs = AttrSerializer("TileAttribute", (("multiplies", is_vec),)) resize_attrs = AttrSerializer( - "ResizeAttribute", (("output_size", is_vec), ("stride", is_vec), ("offset", is_vec), ("shift")) + "ResizeAttribute", + (("output_size", is_vec), ("stride", is_vec), ("offset", is_vec), ("shift")), ) clamp_attrs = AttrSerializer("ClampAttribute", (("min_int"), ("max_int"))) rescale_attrs = AttrSerializer( "RescaleAttribute", - ("input_zp", "output_zp", ("multiplier", is_vec), ("shift", is_vec), "scale32", "double_round", "per_channel"), + ( + "input_zp", + "output_zp", + ("multiplier", is_vec), + ("shift", is_vec), + "scale32", + "double_round", + "per_channel", + ), ) mul_attrs = AttrSerializer("MulAttribute", ("shift",)) ars_attrs = AttrSerializer("ArithmeticRightShiftAttribute", ("round",)) condif_attrs = AttrSerializer("CondIfAttribute", (("then_branch"), ("else_branch"))) # TODO these are references while_attrs = AttrSerializer("WhileLoopAttribute", (("cond_branch"), ("body_branch"))) # TODO these are references -unary_quant_info = QuantSerializer("UnaryQuantInfo", ("input_zp", "output_zp")) -conv_quant_info = QuantSerializer("ConvQuantInfo", ("input_zp", "weight_zp")) -matmul_quant_info = QuantSerializer("MatMulQuantInfo", ("a_zp", "b_zp")) -pad_quant_info = QuantSerializer("PadQuantInfo", ("input_zp",)) - unsupported_tosa_operators = { TosaOp.UNKNOWN, TosaOp.ARGMAX, @@ -245,16 +241,26 @@ TOSA_CONCAT_INDICES = TensorIndices([1, 2], [], []) tosa_operator_map = { # TosaOp.UNKNOWN: (), # TODO TosaOp.ARGMAX: (Op.ArgMax, axis_attrs, None), - TosaOp.AVG_POOL2D: (Op.AvgPool, pool_attrs, unary_quant_info, TOSA_IFM_INDICES), - TosaOp.CONV2D: (Op.Conv2DBias, conv_attrs, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES), + TosaOp.AVG_POOL2D: (Op.AvgPool, pool_attrs, None, TOSA_IFM_INDICES), + TosaOp.CONV2D: (Op.Conv2DBias, conv_attrs, None, TOSA_IFM_WEIGHTS_BIAS_INDICES), # TODO TosaOp.CONV3D: - TosaOp.DEPTHWISE_CONV2D: (Op.DepthwiseConv2DBias, conv_attrs, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES), - TosaOp.FULLY_CONNECTED: (Op.FullyConnected, None, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES), + TosaOp.DEPTHWISE_CONV2D: ( + Op.DepthwiseConv2DBias, + conv_attrs, + None, + TOSA_IFM_WEIGHTS_BIAS_INDICES, + ), + TosaOp.FULLY_CONNECTED: ( + Op.FullyConnected, + fc_attrs, + None, + TOSA_IFM_WEIGHTS_BIAS_INDICES, + ), # TODO TosaOp.MATMUL: TosaOp.MAX_POOL2D: (Op.MaxPool, pool_attrs, None, TOSA_IFM_INDICES), - # TODO TosaOp.TRANSPOSE_CONV2D: (Op.Conv2DBackpropInput, transpose_conv_attrs, conv_quant_info) + # TODO TosaOp.TRANSPOSE_CONV2D: (Op.Conv2DBackpropInput, transpose_conv_attrs, None) TosaOp.CLAMP: (Op.Clamp, clamp_attrs, None, TOSA_IFM_INDICES), - TosaOp.RELUN: (Op.ReluN, relun_attrs, None, TOSA_IFM_INDICES), + # TODO: BUG: No longer a relu - presumably a clamp - TosaOp.RELUN: (Op.ReluN, relun_attrs, None, TOSA_IFM_INDICES), # TODO TosaOp.SIGMOID # TODO TosaOp.TANH TosaOp.ADD: (Op.Add, None, None, TOSA_IFM_IFM2_INDICES), @@ -299,15 +305,17 @@ tosa_operator_map = { TosaOp.CONCAT: (Op.Concat, axis_attrs, None, TOSA_CONCAT_INDICES), # TODO Is the padding intended to be dynamic input, TOSA spec state it as attribute # Handled as for TFLite for now - TosaOp.PAD: (Op.Pad, None, pad_quant_info, TOSA_IFM_INDICES), + TosaOp.PAD: (Op.Pad, None, None, TOSA_IFM_INDICES), TosaOp.RESHAPE: (Op.Reshape, reshape_attrs, None, TOSA_IFM_INDICES), # TODO TosaOp.REVERSE TosaOp.SLICE: (Op.SplitSliceRead, slice_attrs, None, TOSA_IFM_INDICES), # TODO TosaOp.TILE TosaOp.TRANSPOSE: ( Op.Transpose, + transpose_attrs, None, - None, + # TODO: why is this IFM2 indices but then overridden to TOSA_IFM_INDICES in _reader? + # TOSA_IFM_IFM2_INDICES, TOSA_IFM_IFM2_INDICES, ), # TODO Is the perms intended to be dynamic input, TOSA spec state it as attribute # TODO TosaOp.GATHER |