author    Johan Alfven <johan.alfven@arm.com>  2023-04-24 13:35:40 +0200
committer Rickard Bolin <rickard.bolin@arm.com>  2023-05-02 11:03:37 +0000
commit    ce5027328d2330d33bdfc5d5b016d171e4f8a2fc (patch)
tree      16528c451c430d14d3221bfe968e94a54c6af33a /ethosu/vela/lut.py
parent    463f74b2ed65753811ddf7460e760aaf983ef5bb (diff)
MLBEDSW-2082: Add Exp support
- Added int8 and int16 Exp support, implemented as LUT.
- Added generic 8bit and 16bit LUT table functions following the
  implementation in the latest reference. If new ops are added by the
  reference, they can easily be implemented in Vela using the generic
  functions.
- Moved convert_to_lut to lut.py to have all LUT related code in one file.
- Updated SUPPORTED_OPS.md

Change-Id: I388e76ea4b39162313599a5341cfb9bad71a782c
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
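As a usage sketch (not part of this patch): with the generic table builders, a new unary op from the reference can be lowered in the graph optimiser in a few lines. The hook below is hypothetical; only create_lut_8bit_op and create_lut_int16_op come from this change.

```python
import math

from ethosu.vela.data_type import DataType
from ethosu.vela.lut import create_lut_8bit_op, create_lut_int16_op
from ethosu.vela.operation import Op


def convert_exp_to_lut(op, arch, nng):
    # Hypothetical optimiser hook: lower Exp to a LUT-based Add using the
    # generic table builders introduced in this patch.
    if op.type == Op.Exp:
        if op.ifm.dtype in (DataType.int8, DataType.uint8):
            op = create_lut_8bit_op(op, math.exp, "exp")
        elif op.ifm.dtype == DataType.int16:
            op = create_lut_int16_op(op, math.exp, "exp")
    return op
```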
Diffstat (limited to 'ethosu/vela/lut.py')
-rw-r--r--  ethosu/vela/lut.py | 114
1 file changed, 114 insertions(+), 0 deletions(-)
diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py
index d0ac9706..c8fb7bc0 100644
--- a/ethosu/vela/lut.py
+++ b/ethosu/vela/lut.py
@@ -21,10 +21,15 @@ import uuid
 
 import numpy as np
 
 from . import numeric_util
+from .data_type import DataType
+from .debug_database import DebugDatabase
 from .high_level_command_stream import DMA
 from .high_level_command_stream import NpuStripe
+from .numeric_util import round_away_zero
+from .operation import Op
 from .tensor import create_const_tensor
 from .tensor import create_equivalence_id
+from .tensor import QuantizationParameters
 from .tensor import TensorPurpose
 
@@ -88,6 +93,8 @@ def create_lut_tensor(name, values, dtype):
     # address in constant memory, and unnecessary DMA operations can be avoided.
     sz = len(values)
     assert sz in (256, 512)
+    # int16 lut uses uint32 lut with base + slope
+    dtype = DataType.uint32 if dtype == DataType.int16 else dtype
     tens = create_const_tensor(name, [1, 1, 1, sz], dtype, values, TensorPurpose.LUT)
     tens.equivalence_id = create_equivalence_id(tuple(values))
     return tens
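
A quick sketch of the dtype promotion above (assuming the package is importable as ethosu.vela; the tensor names are made up):

```python
from ethosu.vela.data_type import DataType
from ethosu.vela.lut import create_lut_tensor

# A 256-entry 8-bit table keeps its dtype
lut8 = create_lut_tensor("lut8", list(range(256)), DataType.uint8)
assert lut8.dtype == DataType.uint8

# A 512-entry int16 table is promoted to uint32, since each entry will hold
# a packed base + slope word (see create_lut_int16_op below)
lut16 = create_lut_tensor("lut16", [0] * 512, DataType.int16)
assert lut16.dtype == DataType.uint32
```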
@@ -128,3 +135,110 @@ def optimize_high_level_cmd_stream(sg, arch):
             lut_state = lut_state.put(lut_tens)
         cmd_stream.append(cmd)
     sg.high_level_command_stream = cmd_stream
+
+
+def convert_to_lut(op, lut_values, lut_name):
+    # Rewrite the operation as an Add with a scalar-0 input plus a LUT activation
+    ifm = op.ifm
+    ofm = op.ofm
+    if ifm is None:
+        return op
+    assert ifm.dtype in (DataType.int8, DataType.uint8, DataType.int16)
+    op.type = Op.Add
+    op.name = f"{op.name}_lut_{lut_name}"
+    # Mark as no-op to enable potential fusing optimizations
+    op.attrs["is_nop"] = True
+    # Create an input tensor containing scalar zero
+    _max = 65536.0 if ifm.dtype == DataType.int16 else 255.0
+    quantization = QuantizationParameters(0.0, _max)
+    quantization.scale_f32 = ifm.quantization.scale_f32
+    quantization.zero_point = 0
+    tens = create_const_tensor(ifm.name + "_scalar0", [], ifm.dtype, [0], quantization=quantization)
+    op.add_input_tensor(tens)
+
+    # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
+    # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
+    # should be the same as for the IFM
+    op.forced_output_quantization = ifm.quantization
+
+    # The LUT tensor datatype needs to match both the OFM datatype (because these are the values
+    # output) and the datatype used to generate the LUT values (which is probably the IFM datatype),
+    # because we want to avoid any potential overflow errors in create_lut_tensor() caused by
+    # converting a Python int (which could represent a uint) to a NumPy int. This can be guaranteed
+    # by checking that the IFM and OFM datatypes are the same
+    assert ifm.dtype == ofm.dtype
+    lut_tensor = create_lut_tensor(op.name + "_values", lut_values, ofm.dtype)
+    op.set_activation_lut(lut_tensor)
+    op.set_ifm_ofm_shapes()
+    DebugDatabase.add_optimised(op, op)
+    return op
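
convert_to_lut can also be fed a hand-built table. A toy sketch, assuming an existing int8 elementwise op whose IFM and OFM share the same quantization (the |x| table is illustrative, not from this patch):

```python
from ethosu.vela.lut import convert_to_lut

# The table is indexed in IFM value order, for int8: -128..127
values = [min(abs(x), 127) for x in range(-128, 128)]
op = convert_to_lut(op, values, "abs")  # op is rewritten as Add + LUT activation
```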
+
+
+def create_lut_8bit_op(op, lut_fn, fn_name):
+    ifm_scale = op.ifm.quantization.scale_f32
+    ofm_scale = op.ofm.quantization.scale_f32
+    zp_in = op.ifm.quantization.zero_point
+    zp_out = op.ofm.quantization.zero_point
+
+    values = []
+    ix = range(256) if op.ifm.dtype == DataType.uint8 else range(-128, 128)
+    quantized_min = min(ix)
+    quantized_max = max(ix)
+    for x in ix:
+        x_real = ifm_scale * (x - zp_in)
+        y_real = lut_fn(x_real)
+        lut_result = round_away_zero(y_real / ofm_scale) + zp_out
+        lut_result = min(quantized_max, max(quantized_min, lut_result))
+        values.append(lut_result)
+
+    return convert_to_lut(op, values, fn_name)
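
To make the per-entry arithmetic concrete, here is one int8 table entry computed by hand, with illustrative quantization parameters (not taken from the patch):

```python
import math

from ethosu.vela.numeric_util import round_away_zero

ifm_scale, zp_in = 0.05, 0      # assumed int8 input quantization
ofm_scale, zp_out = 0.02, -128  # assumed output quantization

x = 100                                                    # quantized input value
x_real = ifm_scale * (x - zp_in)                           # 5.0
y_real = math.exp(x_real)                                  # ~148.41
lut_result = round_away_zero(y_real / ofm_scale) + zp_out  # 7421 - 128 = 7293
lut_result = min(127, max(-128, lut_result))               # saturates to 127
```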
+
+
+def create_lut_int16_op(op, lut_fn, fn_name):
+    ifm_scale = op.ifm.quantization.scale_f32
+    ofm_scale = op.ofm.quantization.scale_f32
+    zp_in = op.ifm.quantization.zero_point
+    zp_out = op.ofm.quantization.zero_point
+
+    input_min = ifm_scale * (np.iinfo(np.int16).min - zp_in)
+    input_max = ifm_scale * (np.iinfo(np.int16).max - zp_in)
+    output_min = ofm_scale * (np.iinfo(np.int16).min - zp_out)
+    output_max = ofm_scale * (np.iinfo(np.int16).max - zp_out)
+
+    # Create 16bit lut following the reference
+    nbr_steps = 512
+    step = (input_max - input_min) / nbr_steps
+    half_step = step / 2
+    output_scaling_inv = (np.iinfo(np.int16).max - np.iinfo(np.int16).min + 1) / (output_max - output_min)
+
+    table_min = np.iinfo(np.int16).min
+    table_max = np.iinfo(np.int16).max
+
+    values = []
+    for i in range(nbr_steps):
+        val = lut_fn(input_min + i * step)
+        val_midpoint = lut_fn(input_min + i * step + half_step)
+        val_next = lut_fn(input_min + (i + 1) * step)
+
+        sample_val = round_away_zero(val * output_scaling_inv)
+        midpoint_interp_val = round_away_zero(
+            (val_next * output_scaling_inv + round_away_zero(val * output_scaling_inv)) / 2
+        )
+        midpoint_val = round_away_zero(val_midpoint * output_scaling_inv)
+        midpoint_err = midpoint_interp_val - midpoint_val
+        bias = round_away_zero(midpoint_err / 2)
+
+        lut_result = min(max(sample_val - bias, table_min), table_max)
+        values.append(lut_result)
+
+    val = round_away_zero(lut_fn(input_max) * output_scaling_inv)
+    lut_result = min(max(val, table_min), table_max)
+    values.append(lut_result)
+
+    # Convert to hardware 16bit lut with base and slope
+    lut = [0] * nbr_steps
+    for i in range(nbr_steps):
+        slope = (int(values[i + 1]) - int(values[i])) << 16
+        base = int(values[i])
+        lut[i] = slope + base
+
+    return convert_to_lut(op, lut, fn_name)
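
The bias term halves the error of linear interpolation at each segment midpoint. A worked example with made-up numbers (output_scaling_inv and the sampled values are illustrative):

```python
from ethosu.vela.numeric_util import round_away_zero

output_scaling_inv = 100.0
val, val_midpoint, val_next = 1.0, 1.6, 2.0  # lut_fn at segment start/middle/end

sample_val = round_away_zero(val * output_scaling_inv)             # 100
midpoint_interp_val = round_away_zero(
    (val_next * output_scaling_inv + round_away_zero(val * output_scaling_inv)) / 2
)                                                                  # 150
midpoint_val = round_away_zero(val_midpoint * output_scaling_inv)  # 160
midpoint_err = midpoint_interp_val - midpoint_val                  # -10
bias = round_away_zero(midpoint_err / 2)                           # -5
base = sample_val - bias  # 105: nudged toward the curve, halving the midpoint error
```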