diff options
Diffstat (limited to 'ethosu/vela/graph_optimiser_util.py')
-rw-r--r-- | ethosu/vela/graph_optimiser_util.py | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py index dafd2849..d2d3d833 100644 --- a/ethosu/vela/graph_optimiser_util.py +++ b/ethosu/vela/graph_optimiser_util.py @@ -19,6 +19,7 @@ from typing import Tuple import numpy as np +from . import lut from .data_type import DataType from .debug_database import DebugDatabase from .errors import UnsupportedFeatureError @@ -26,6 +27,8 @@ from .errors import VelaError from .operation import Op from .operation_util import create_avgpool_nop from .shape4d import Shape4D +from .tensor import create_const_tensor +from .tensor import QuantizationParameters memory_only_ops = ( Op.Reshape, @@ -320,3 +323,31 @@ def convert_depthwise_to_conv(op, arch, nng): ) DebugDatabase.add_optimised(op, op) return op + + +def convert_to_lut(op, lut_values, lut_name): + # Rewrite the operation by Add with scalar 0 + LUT activation + ifm = op.inputs[0] + if ifm is None: + return op + assert ifm.dtype.size_in_bytes() == 1 + op.type = Op.Add + op.name = op.name + "_lut_" + lut_name + # Mark as no-op to enable potential fusing optimizations + op.attrs["is_nop"] = True + # Create an input tensor containing scalar zero + quantization = QuantizationParameters(0.0, 255.0) + quantization.scale_f32 = ifm.quantization.scale_f32 + quantization.zero_point = 0 + tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization) + op.add_input_tensor(tens) + op.ifm_shapes.append(Shape4D(tens.shape)) # TODO no shape? + + # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale), + # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions + # should be the same as the IFM + op.forced_output_quantization = ifm.quantization + lut_tensor = lut.create_lut_tensor(op.name + "_values", lut_values, DataType.int8) + op.set_activation_lut(lut_tensor) + op.set_ifm_ofm_shapes() + return op |