aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_mapping.py
diff options
context:
space:
mode:
authorRob Elliott <robert.elliott@arm.com>2023-08-17 14:27:06 +0000
committerRickard Bolin <rickard.bolin@arm.com>2023-08-21 16:14:51 +0000
commit00a15db3e1a188b25065d095152d701f4394cdc5 (patch)
tree96761b9f7ac3ad759f9f0ffbf63a6d0ef115ad14 /ethosu/vela/tosa_mapping.py
parent8ea90edb75e5d2353aa91c264356fc9d460ca308 (diff)
downloadethos-u-vela-00a15db3e1a188b25065d095152d701f4394cdc5.tar.gz
Moving Vela to use TOSA v0.80.0 specification
* Using serialization_lib main branch to update statically copied files sha 5f920211ac23393a7b98a0d358bfbfc3232d5c8f (v0.80.0) * All files within the ethosu/vela/tosa are copied from that revision * Note: hope to move to serialization_lib as a pip module in future * Modified the ethosu/vela/{tosa_mapping,tosa_reader}.py to use v0.80.0 TOSA FlatBuffers implementation * These are the additional changes made to support this new version, with changes in the format of the FlatBuffers file and where various values are stored. Either changing from input to attribute, or moving to different attributes. Signed-off-by: Rob Elliott <robert.elliott@arm.com> Change-Id: I5e1fcc2a9964148619be3477adf1e88e84cbae2d
Diffstat (limited to 'ethosu/vela/tosa_mapping.py')
-rw-r--r--ethosu/vela/tosa_mapping.py124
1 files changed, 66 insertions, 58 deletions
diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py
index 97fdd207..2dafd81d 100644
--- a/ethosu/vela/tosa_mapping.py
+++ b/ethosu/vela/tosa_mapping.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -28,19 +28,16 @@ from .tosa import AxisAttribute # noqa: F401
from .tosa import ClampAttribute # noqa: F401
from .tosa import CondIfAttribute # noqa: F401
from .tosa import ConvAttribute # noqa: F401
-from .tosa import ConvQuantInfo # noqa: F401
-from .tosa import MatMulQuantInfo # noqa: F401
+from .tosa import FullyConnectedAttribute # noqa: F401
from .tosa import MulAttribute # noqa: F401
-from .tosa import PadQuantInfo # noqa: F401
from .tosa import PoolAttribute # noqa: F401
-from .tosa import ReluNAttribute # noqa: F401
from .tosa import RescaleAttribute # noqa: F401
from .tosa import ReshapeAttribute # noqa: F401
from .tosa import ResizeAttribute # noqa: F401
from .tosa import SliceAttribute # noqa: F401
from .tosa import TileAttribute # noqa: F401
+from .tosa import TransposeAttribute # noqa: F401
from .tosa import TransposeConvAttribute # noqa: F401
-from .tosa import UnaryQuantInfo # noqa: F401
from .tosa import WhileLoopAttribute # noqa: F401
from .tosa.DType import DType
from .tosa.Op import Op as TosaOp
@@ -54,7 +51,7 @@ datatype_map = {
DType.INT16: DataType.int16,
DType.INT32: DataType.int32,
DType.INT48: DataType.int48,
- DType.FLOAT: DataType.float32,
+ DType.FP32: DataType.float32,
}
datatype_map_numpy = {
@@ -63,7 +60,7 @@ datatype_map_numpy = {
DType.INT8: np.int8,
DType.INT16: np.int16,
DType.INT32: np.int32,
- DType.FLOAT: np.float32,
+ DType.FP32: np.float32,
}
@@ -118,65 +115,64 @@ class AttrSerializer:
return attrs
-class QuantSerializer:
- def __init__(self, name, members=None):
- self.name = name
- self.module = globals()[self.name]
- self.cls = getattr(self.module, self.name)
- self.members = []
- if members is not None:
- for mem in members:
- deserialize = identity
- underscore_mem = mem
- camelcase_mem = underscore_to_camel_case(mem)
- self.members.append((underscore_mem, camelcase_mem, deserialize))
-
- def deserialize(self, op_data):
- quant_info_type = op_data.QuantInfoType()
- quant_info = op_data.QuantInfo()
- quant = {}
- if quant_info_type:
- tosa_quant = self.cls()
- tosa_quant.Init(quant_info.Bytes, quant_info.Pos)
- for underscore_mem, camelcase_mem, deserialize in self.members:
- attr = getattr(tosa_quant, camelcase_mem)()
- try:
- quant[underscore_mem] = deserialize(attr)
- except TypeError:
- print("Warning: {0} could not read quant info '{1}'.".format(self.name, underscore_mem))
-
- return quant
-
-
is_vec = True
-pool_attrs = AttrSerializer("PoolAttribute", (("padding", is_vec), ("kernel", is_vec), ("stride", is_vec)))
-conv_attrs = AttrSerializer("ConvAttribute", (("padding", is_vec), ("stride", is_vec), ("dilation", is_vec)))
+pool_attrs = AttrSerializer(
+ "PoolAttribute",
+ (
+ ("pad", is_vec),
+ ("kernel", is_vec),
+ ("stride", is_vec),
+ ("input_zp"),
+ ("output_zp"),
+ ),
+)
+conv_attrs = AttrSerializer(
+ "ConvAttribute",
+ (
+ ("pad", is_vec),
+ ("stride", is_vec),
+ ("dilation", is_vec),
+ ("input_zp"),
+ ("weight_zp"),
+ ),
+)
+fc_attrs = AttrSerializer("FullyConnectedAttribute", (("input_zp"), ("weight_zp")))
transpose_conv_attrs = AttrSerializer(
- "TransposeConvAttribute", (("outpad", is_vec), ("stride", is_vec), ("dilation", is_vec), ("out_shape", is_vec))
+ "TransposeConvAttribute",
+ (
+ ("outpad", is_vec),
+ ("stride", is_vec),
+ ("dilation", is_vec),
+ ("out_shape", is_vec),
+ ),
)
-relun_attrs = AttrSerializer("ReluNAttribute", ("max_int"))
+transpose_attrs = AttrSerializer("TransposeAttribute", (("perms", is_vec),))
axis_attrs = AttrSerializer("AxisAttribute", ("axis",))
reshape_attrs = AttrSerializer("ReshapeAttribute", (("shape", is_vec),))
-slice_attrs = AttrSerializer("SliceAttribute", (("begin", is_vec), ("size", is_vec)))
+slice_attrs = AttrSerializer("SliceAttribute", (("start", is_vec), ("size", is_vec)))
tile_attrs = AttrSerializer("TileAttribute", (("multiplies", is_vec),))
resize_attrs = AttrSerializer(
- "ResizeAttribute", (("output_size", is_vec), ("stride", is_vec), ("offset", is_vec), ("shift"))
+ "ResizeAttribute",
+ (("output_size", is_vec), ("stride", is_vec), ("offset", is_vec), ("shift")),
)
clamp_attrs = AttrSerializer("ClampAttribute", (("min_int"), ("max_int")))
rescale_attrs = AttrSerializer(
"RescaleAttribute",
- ("input_zp", "output_zp", ("multiplier", is_vec), ("shift", is_vec), "scale32", "double_round", "per_channel"),
+ (
+ "input_zp",
+ "output_zp",
+ ("multiplier", is_vec),
+ ("shift", is_vec),
+ "scale32",
+ "double_round",
+ "per_channel",
+ ),
)
mul_attrs = AttrSerializer("MulAttribute", ("shift",))
ars_attrs = AttrSerializer("ArithmeticRightShiftAttribute", ("round",))
condif_attrs = AttrSerializer("CondIfAttribute", (("then_branch"), ("else_branch"))) # TODO these are references
while_attrs = AttrSerializer("WhileLoopAttribute", (("cond_branch"), ("body_branch"))) # TODO these are references
-unary_quant_info = QuantSerializer("UnaryQuantInfo", ("input_zp", "output_zp"))
-conv_quant_info = QuantSerializer("ConvQuantInfo", ("input_zp", "weight_zp"))
-matmul_quant_info = QuantSerializer("MatMulQuantInfo", ("a_zp", "b_zp"))
-pad_quant_info = QuantSerializer("PadQuantInfo", ("input_zp",))
-
unsupported_tosa_operators = {
TosaOp.UNKNOWN,
TosaOp.ARGMAX,
@@ -245,16 +241,26 @@ TOSA_CONCAT_INDICES = TensorIndices([1, 2], [], [])
tosa_operator_map = {
# TosaOp.UNKNOWN: (),
# TODO TosaOp.ARGMAX: (Op.ArgMax, axis_attrs, None),
- TosaOp.AVG_POOL2D: (Op.AvgPool, pool_attrs, unary_quant_info, TOSA_IFM_INDICES),
- TosaOp.CONV2D: (Op.Conv2DBias, conv_attrs, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
+ TosaOp.AVG_POOL2D: (Op.AvgPool, pool_attrs, None, TOSA_IFM_INDICES),
+ TosaOp.CONV2D: (Op.Conv2DBias, conv_attrs, None, TOSA_IFM_WEIGHTS_BIAS_INDICES),
# TODO TosaOp.CONV3D:
- TosaOp.DEPTHWISE_CONV2D: (Op.DepthwiseConv2DBias, conv_attrs, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
- TosaOp.FULLY_CONNECTED: (Op.FullyConnected, None, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
+ TosaOp.DEPTHWISE_CONV2D: (
+ Op.DepthwiseConv2DBias,
+ conv_attrs,
+ None,
+ TOSA_IFM_WEIGHTS_BIAS_INDICES,
+ ),
+ TosaOp.FULLY_CONNECTED: (
+ Op.FullyConnected,
+ fc_attrs,
+ None,
+ TOSA_IFM_WEIGHTS_BIAS_INDICES,
+ ),
# TODO TosaOp.MATMUL:
TosaOp.MAX_POOL2D: (Op.MaxPool, pool_attrs, None, TOSA_IFM_INDICES),
- # TODO TosaOp.TRANSPOSE_CONV2D: (Op.Conv2DBackpropInput, transpose_conv_attrs, conv_quant_info)
+ # TODO TosaOp.TRANSPOSE_CONV2D: (Op.Conv2DBackpropInput, transpose_conv_attrs, None)
TosaOp.CLAMP: (Op.Clamp, clamp_attrs, None, TOSA_IFM_INDICES),
- TosaOp.RELUN: (Op.ReluN, relun_attrs, None, TOSA_IFM_INDICES),
+ # TODO: BUG: No longer a relu - presumably a clamp - TosaOp.RELUN: (Op.ReluN, relun_attrs, None, TOSA_IFM_INDICES),
# TODO TosaOp.SIGMOID
# TODO TosaOp.TANH
TosaOp.ADD: (Op.Add, None, None, TOSA_IFM_IFM2_INDICES),
@@ -299,15 +305,17 @@ tosa_operator_map = {
TosaOp.CONCAT: (Op.Concat, axis_attrs, None, TOSA_CONCAT_INDICES),
# TODO Is the padding intended to be dynamic input, TOSA spec state it as attribute
# Handled as for TFLite for now
- TosaOp.PAD: (Op.Pad, None, pad_quant_info, TOSA_IFM_INDICES),
+ TosaOp.PAD: (Op.Pad, None, None, TOSA_IFM_INDICES),
TosaOp.RESHAPE: (Op.Reshape, reshape_attrs, None, TOSA_IFM_INDICES),
# TODO TosaOp.REVERSE
TosaOp.SLICE: (Op.SplitSliceRead, slice_attrs, None, TOSA_IFM_INDICES),
# TODO TosaOp.TILE
TosaOp.TRANSPOSE: (
Op.Transpose,
+ transpose_attrs,
None,
- None,
+ # TODO: why is this IFM2 indices but then overridden to TOSA_IFM_INDICES in _reader?
+ # TOSA_IFM_IFM2_INDICES,
TOSA_IFM_IFM2_INDICES,
), # TODO Is the perms intended to be dynamic input, TOSA spec state it as attribute
# TODO TosaOp.GATHER