aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_mapping.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tosa_mapping.py')
-rw-r--r--ethosu/vela/tosa_mapping.py28
1 files changed, 14 insertions, 14 deletions
diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py
index 75f1c9c5..6710787e 100644
--- a/ethosu/vela/tosa_mapping.py
+++ b/ethosu/vela/tosa_mapping.py
@@ -26,19 +26,19 @@ from .tosa import ArithmeticRightShiftAttribute # noqa: F401
from .tosa import AxisAttribute # noqa: F401
from .tosa import ClampAttribute # noqa: F401
from .tosa import CondIfAttribute # noqa: F401
-from .tosa import Conv2dAttribute # noqa: F401
+from .tosa import ConvAttribute # noqa: F401
from .tosa import ConvQuantInfo # noqa: F401
from .tosa import MatMulQuantInfo # noqa: F401
from .tosa import MulAttribute # noqa: F401
from .tosa import PadQuantInfo # noqa: F401
-from .tosa import Pool2dAttribute # noqa: F401
+from .tosa import PoolAttribute # noqa: F401
from .tosa import ReluNAttribute # noqa: F401
from .tosa import RescaleAttribute # noqa: F401
from .tosa import ReshapeAttribute # noqa: F401
from .tosa import ResizeAttribute # noqa: F401
from .tosa import SliceAttribute # noqa: F401
from .tosa import TileAttribute # noqa: F401
-from .tosa import TransposeConv2dAttribute # noqa: F401
+from .tosa import TransposeConvAttribute # noqa: F401
from .tosa import UnaryQuantInfo # noqa: F401
from .tosa import WhileLoopAttribute # noqa: F401
from .tosa.DType import DType
@@ -148,10 +148,10 @@ class QuantSerializer:
is_vec = True
-pool2d_attrs = AttrSerializer("Pool2dAttribute", (("padding", is_vec), ("kernel", is_vec), ("stride", is_vec)))
-conv2d_attrs = AttrSerializer("Conv2dAttribute", (("padding", is_vec), ("stride", is_vec), ("dilation", is_vec)))
-transpose_conv2d_attrs = AttrSerializer(
- "TransposeConv2dAttribute", (("outpad", is_vec), ("stride", is_vec), ("dilation", is_vec), ("out_shape", is_vec))
+pool_attrs = AttrSerializer("PoolAttribute", (("padding", is_vec), ("kernel", is_vec), ("stride", is_vec)))
+conv_attrs = AttrSerializer("ConvAttribute", (("padding", is_vec), ("stride", is_vec), ("dilation", is_vec)))
+transpose_conv_attrs = AttrSerializer(
+ "TransposeConvAttribute", (("outpad", is_vec), ("stride", is_vec), ("dilation", is_vec), ("out_shape", is_vec))
)
relun_attrs = AttrSerializer("ReluNAttribute", ("max_int"))
axis_attrs = AttrSerializer("AxisAttribute", ("axis",))
@@ -187,7 +187,7 @@ unsupported_tosa_operators = {
TosaOp.BITWISE_AND,
TosaOp.BITWISE_OR,
TosaOp.BITWISE_XOR,
- TosaOp.DIV,
+ TosaOp.INTDIV,
TosaOp.LOGICAL_AND,
TosaOp.LOGICAL_LEFT_SHIFT,
TosaOp.LOGICAL_RIGHT_SHIFT,
@@ -244,14 +244,14 @@ TOSA_CONCAT_INDICES = TensorIndices([1, 2], [], [])
tosa_operator_map = {
# TosaOp.UNKNOWN: (),
# TODO TosaOp.ARGMAX: (Op.ArgMax, axis_attrs, None),
- TosaOp.AVG_POOL2D: (Op.AvgPool, pool2d_attrs, unary_quant_info, TOSA_IFM_INDICES),
- TosaOp.CONV2D: (Op.Conv2DBias, conv2d_attrs, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
+ TosaOp.AVG_POOL2D: (Op.AvgPool, pool_attrs, unary_quant_info, TOSA_IFM_INDICES),
+ TosaOp.CONV2D: (Op.Conv2DBias, conv_attrs, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
# TODO TosaOp.CONV3D:
- TosaOp.DEPTHWISE_CONV2D: (Op.DepthwiseConv2DBias, conv2d_attrs, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
+ TosaOp.DEPTHWISE_CONV2D: (Op.DepthwiseConv2DBias, conv_attrs, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
TosaOp.FULLY_CONNECTED: (Op.FullyConnected, None, conv_quant_info, TOSA_IFM_WEIGHTS_BIAS_INDICES),
# TODO TosaOp.MATMUL:
- TosaOp.MAX_POOL2D: (Op.MaxPool, pool2d_attrs, None, TOSA_IFM_INDICES),
- # TODO TosaOp.TRANSPOSE_CONV2D: (Op.Conv2DBackpropInput, transpose_conv2d_attrs, conv_quant_info)
+ TosaOp.MAX_POOL2D: (Op.MaxPool, pool_attrs, None, TOSA_IFM_INDICES),
+ # TODO TosaOp.TRANSPOSE_CONV2D: (Op.Conv2DBackpropInput, transpose_conv_attrs, conv_quant_info)
TosaOp.CLAMP: (Op.Clamp, clamp_attrs, None, TOSA_IFM_INDICES),
TosaOp.RELUN: (Op.ReluN, relun_attrs, None, TOSA_IFM_INDICES),
# TODO TosaOp.SIGMOID
@@ -261,7 +261,7 @@ tosa_operator_map = {
# TODO TosaOp.BITWISE_AND
# TODO TosaOp.BITWISE_OR
# TODO TosaOp.BITWISE_XOR
- # TODO TosaOp.DIV
+ # TODO TosaOp.INTDIV
# TODO TosaOp.LOGICAL_AND
# TODO TosaOp.LOGICAL_LEFT_SHIFT
# TODO TosaOp.LOGICAL_RIGHT_SHIFT