about summary refs log tree commit diff
path: root/ethosu/vela/graph_optimiser_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/graph_optimiser_util.py')
-rw-r--r--  ethosu/vela/graph_optimiser_util.py  31
1 file changed, 31 insertions(+), 0 deletions(-)
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index dafd2849..d2d3d833 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -19,6 +19,7 @@ from typing import Tuple
import numpy as np
+from . import lut
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
@@ -26,6 +27,8 @@ from .errors import VelaError
from .operation import Op
from .operation_util import create_avgpool_nop
from .shape4d import Shape4D
+from .tensor import create_const_tensor
+from .tensor import QuantizationParameters
memory_only_ops = (
Op.Reshape,
@@ -320,3 +323,31 @@ def convert_depthwise_to_conv(op, arch, nng):
)
DebugDatabase.add_optimised(op, op)
return op
+
+
+def convert_to_lut(op, lut_values, lut_name):
+ # Rewrite the operation by Add with scalar 0 + LUT activation
+ ifm = op.inputs[0]
+ if ifm is None:
+ return op
+ assert ifm.dtype.size_in_bytes() == 1
+ op.type = Op.Add
+ op.name = op.name + "_lut_" + lut_name
+ # Mark as no-op to enable potential fusing optimizations
+ op.attrs["is_nop"] = True
+ # Create an input tensor containing scalar zero
+ quantization = QuantizationParameters(0.0, 255.0)
+ quantization.scale_f32 = ifm.quantization.scale_f32
+ quantization.zero_point = 0
+ tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
+ op.add_input_tensor(tens)
+ op.ifm_shapes.append(Shape4D(tens.shape)) # TODO no shape?
+
+ # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
+ # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
+ # should be the same as the IFM
+ op.forced_output_quantization = ifm.quantization
+ lut_tensor = lut.create_lut_tensor(op.name + "_values", lut_values, DataType.int8)
+ op.set_activation_lut(lut_tensor)
+ op.set_ifm_ofm_shapes()
+ return op