aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_mapping.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tosa_mapping.py')
-rw-r--r--ethosu/vela/tosa_mapping.py124
1 files changed, 66 insertions, 58 deletions
diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py
index 97fdd20..2dafd81 100644
--- a/ethosu/vela/tosa_mapping.py
+++ b/ethosu/vela/tosa_mapping.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -28,19 +28,16 @@ from .tosa import AxisAttribute # noqa: F401
from .tosa import ClampAttribute # noqa: F401
from .tosa import CondIfAttribute # noqa: F401
from .tosa import ConvAttribute # noqa: F401
-from .tosa import ConvQuantInfo # noqa: F401
-from .tosa import MatMulQuantInfo # noqa: F401
+from .tosa import FullyConnectedAttribute # noqa: F401
from .tosa import MulAttribute # noqa: F401
-from .tosa import PadQuantInfo # noqa: F401
from .tosa import PoolAttribute # noqa: F401
-from .tosa import ReluNAttribute # noqa: F401
from .tosa import RescaleAttribute # noqa: F401
from .tosa import ReshapeAttribute # noqa: F401
from .tosa import ResizeAttribute # noqa: F401
from .tosa import SliceAttribute # noqa: F401
from .tosa import TileAttribute # noqa: F401
+from .tosa import TransposeAttribute # noqa: F401
from .tosa import TransposeConvAttribute # noqa: F401
-from .tosa import UnaryQuantInfo # noqa: F401
from .tosa import WhileLoopAttribute # noqa: F401
from .tosa.DType import DType
from .tosa.Op import Op as TosaOp
@@ -54,7 +51,7 @@ datatype_map = {
DType.INT16: DataType.int16,
DType.INT32: DataType.int32,
DType.INT48: DataType.int48,
- DType.FLOAT: DataType.float32,
+ DType.FP32: DataType.float32,
}
datatype_map_numpy = {
@@ -63,7 +60,7 @@ datatype_map_numpy = {
DType.INT8: np.int8,
DType.INT16: np.int16,
DType.INT32: np.int32,
- DType.FLOAT: np.float32,
+ DType.FP32: np.float32,
}
@@ -118,65 +115,64 @@ class AttrSerializer:
return attrs
-class QuantSerializer:
- def __init__(self, name, members=None):
- self.name = name
- self.module = globals()[self.name]
- self.cls = getattr(self.module, self.name)
- self.members = []
- if members is not None:
- for mem in members:
- deserialize = identity
- underscore_mem = mem
- camelcase_mem = underscore_to_camel_case(mem)
- self.members.append((underscore_mem, camelcase_mem, deserialize))
-
- def deserialize(self, op_data):
- quant_info_type = op_data.QuantInfoType()
- quant_info = op_data.QuantInfo()
- quant = {}
- if quant_info_type:
- tosa_quant = self.cls()
- tosa_quant.Init(quant_info.Bytes, quant_info.Pos)
- for underscore_mem, camelcase_mem, deserialize in self.members:
- attr = getattr(tosa_quant, camelcase_mem)()
- try:
- quant[underscore_mem] = deserialize(attr)
- except TypeError:
- print("Warning: {0} could not read quant info '{1}'.".format(self.name, underscore_mem))
-
- return quant
-
-
is_vec = True
-pool_attrs = AttrSerializer("PoolAttribute", (("padding", is_vec), ("kernel", is_vec), ("stride", is_vec)))
-conv_attrs = AttrSerializer("ConvAttribute", (("padding", is_vec), ("stride", is_vec), ("dilation", is_vec)))
+pool_attrs = AttrSerializer(
+ "PoolAttribute",
+ (
+ ("pad", is_vec),
+ ("kernel", is_vec),
+ ("stride", is_vec),
+ ("input_zp"),
+ ("output_zp"),
+ ),
+)
+conv_attrs = AttrSerializer(
+ "ConvAttribute",
+ (
+ ("pad", is_vec),
+ ("stride", is_vec),
+ ("dilation", is_vec),
+ ("input_zp"),
+ ("weight_zp"),
+ ),
+)
+fc_attrs = AttrSerializer("FullyConnectedAttribute", (("input_zp"), ("weight_zp")))
transpose_conv_attrs = AttrSerializer(
- "TransposeConvAttribute", (("outpad", is_vec), ("stride", is_vec), ("dilation", is_vec), ("out_shape", is_vec))
+ "TransposeConvAttribute",
+ (
+ ("outpad", is_vec),
+ ("stride", is_vec),
+ ("dilation", is_vec),
+ ("out_shape", is_vec),
+ ),
)
-relun_attrs = AttrSerializer("ReluNAttribute", ("max_int"))
+transpose_attrs = AttrSerializer("TransposeAttribute", (("perms", is_vec),))
axis_attrs = AttrSerializer("AxisAttribute", ("axis",))
reshape_attrs = AttrSerializer("ReshapeAttribute", (("shape", is_vec),))
-slice_attrs = AttrSerializer("SliceAttribute", (("begin", is_vec), ("size", is_vec)))
+slice_attrs = AttrSerializer("SliceAttribute", (("start", is_vec), ("size", is_vec)))
tile_attrs = AttrSerializer("TileAttribute", (("multiplies", is_vec),))
resize_attrs = AttrSerializer(
- "ResizeAttribute", (("output_size", is_vec), ("stride", is_vec), ("offset", is_vec), ("shift"))
+ "ResizeAttribute",
+ (("output_size", is_vec), ("stride", is_vec), ("offset", is_vec), ("shift")),
)
clamp_attrs = AttrSerializer("ClampAttribute", (("min_int"), ("max_int")))
rescale_attrs = AttrSerializer(
"RescaleAttribute",
- ("input_zp", "output_zp", ("multiplier", is_vec), ("shift", is_vec), "scale32", "double_round", "per_channel"),
+ (
+ "input_zp",
+ "output_zp",
+ ("multiplier", is_vec),
+ ("shift", is_vec),
+ "scale32",
+ "double_round",
+ "per_channel",
+ ),
)
mul_attrs = AttrSerializer("MulAttribute", ("shift",))
ars_attrs = AttrSerializer("ArithmeticRightShiftAttribute", ("round",))
condif_attrs = AttrSerializer("CondIfAttribute", (("then_branch"), ("else_branch"))) # TODO these are references
while_attrs = AttrSerializer("WhileLoopAttribute", (("cond_branch"), ("body_branch"))) # TODO these are references
-unary_quant_info = QuantSerializer("UnaryQuantInfo", ("input_zp", "output_zp"))
-conv_quant_info = QuantSerializer("ConvQuantInfo", ("input_zp", "weight_zp"))
-matmul_quant_info = QuantSerializer("MatMulQuantInfo", ("a_zp", "b_zp"))
-pad_quant_info = QuantSerializer("PadQuantInfo", ("input_zp",))
-
unsupported_tosa_operators = {
TosaOp.UNKNOWN,
TosaOp.ARGMAX,
@@ -245,16 +241,26 @@ TOSA_CONCAT_INDICES = TensorIndices([1, 2], [], [])
tosa_operator_map = {
# TosaOp.UNKNOWN: (),
# TODO TosaOp.ARGMAX: (Op.ArgMax, axis_attrs, None),
- TosaOp.AVG_POOL2D: (Op.AvgPool, pool_attrs, unary_quant_info, TOSA_IFM_INDICES),
- TosaOp.CONV2D: (Op.Conv2DBias, conv_attrs, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
+ TosaOp.AVG_POOL2D: (Op.AvgPool, pool_attrs, None, TOSA_IFM_INDICES),
+ TosaOp.CONV2D: (Op.Conv2DBias, conv_attrs, None, TOSA_IFM_WEIGHTS_BIAS_INDICES),
# TODO TosaOp.CONV3D:
- TosaOp.DEPTHWISE_CONV2D: (Op.DepthwiseConv2DBias, conv_attrs, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
- TosaOp.FULLY_CONNECTED: (Op.FullyConnected, None, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
+ TosaOp.DEPTHWISE_CONV2D: (
+ Op.DepthwiseConv2DBias,
+ conv_attrs,
+ None,
+ TOSA_IFM_WEIGHTS_BIAS_INDICES,
+ ),
+ TosaOp.FULLY_CONNECTED: (
+ Op.FullyConnected,
+ fc_attrs,
+ None,
+ TOSA_IFM_WEIGHTS_BIAS_INDICES,
+ ),
# TODO TosaOp.MATMUL:
TosaOp.MAX_POOL2D: (Op.MaxPool, pool_attrs, None, TOSA_IFM_INDICES),
- # TODO TosaOp.TRANSPOSE_CONV2D: (Op.Conv2DBackpropInput, transpose_conv_attrs, conv_quant_info)
+ # TODO TosaOp.TRANSPOSE_CONV2D: (Op.Conv2DBackpropInput, transpose_conv_attrs, None)
TosaOp.CLAMP: (Op.Clamp, clamp_attrs, None, TOSA_IFM_INDICES),
- TosaOp.RELUN: (Op.ReluN, relun_attrs, None, TOSA_IFM_INDICES),
+ # TODO: BUG: No longer a relu - presumably a clamp - TosaOp.RELUN: (Op.ReluN, relun_attrs, None, TOSA_IFM_INDICES),
# TODO TosaOp.SIGMOID
# TODO TosaOp.TANH
TosaOp.ADD: (Op.Add, None, None, TOSA_IFM_IFM2_INDICES),
@@ -299,15 +305,17 @@ tosa_operator_map = {
TosaOp.CONCAT: (Op.Concat, axis_attrs, None, TOSA_CONCAT_INDICES),
# TODO Is the padding intended to be dynamic input, TOSA spec state it as attribute
# Handled as for TFLite for now
- TosaOp.PAD: (Op.Pad, None, pad_quant_info, TOSA_IFM_INDICES),
+ TosaOp.PAD: (Op.Pad, None, None, TOSA_IFM_INDICES),
TosaOp.RESHAPE: (Op.Reshape, reshape_attrs, None, TOSA_IFM_INDICES),
# TODO TosaOp.REVERSE
TosaOp.SLICE: (Op.SplitSliceRead, slice_attrs, None, TOSA_IFM_INDICES),
# TODO TosaOp.TILE
TosaOp.TRANSPOSE: (
Op.Transpose,
+ transpose_attrs,
None,
- None,
+ # TODO: why is this IFM2 indices but then overridden to TOSA_IFM_INDICES in _reader?
+ # TOSA_IFM_IFM2_INDICES,
TOSA_IFM_IFM2_INDICES,
), # TODO Is the perms intended to be dynamic input, TOSA spec state it as attribute
# TODO TosaOp.GATHER