author     Louis Verhaard <louis.verhaard@arm.com>   2020-08-13 11:47:36 +0200
committer  Louis Verhaard <louis.verhaard@arm.com>   2020-08-24 11:39:18 +0200
commit     b9fc33c194036973273604d5fd7af9e814133238 (patch)
tree       6073239bcd0cbb53838d363cc1d380f2db05ce6b
parent     e9194df0b003db774fc73d841945f213b714016b (diff)
download   ethos-u-vela-b9fc33c194036973273604d5fd7af9e814133238.tar.gz
MLBEDSW-2688: LeakyRelu rewrite to LUT or MUL/MAX
Replaces LeakyRelu operations with a LUT activation function when possible,
otherwise with a combination of multiplication/maximization.

Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Change-Id: I3d2eb2dba7145997c3cc711d0ef18ab355fbb416
-rw-r--r--  ethosu/vela/graph_optimiser.py | 159
-rw-r--r--  ethosu/vela/lut.py             |  16
-rw-r--r--  ethosu/vela/mark_tensors.py    |   2
-rw-r--r--  ethosu/vela/tensor.py          |   3
4 files changed, 176 insertions(+), 4 deletions(-)
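Background for the patch below: quantized LeakyRelu keeps values at or above the zero point unchanged and scales values below it by alpha, which is exactly the per-value table the new convert_lrelu_to_lut builds. A minimal standalone sketch of that 8-bit table (not part of the patch; plain NumPy, and it assumes identical IFM/OFM quantization, as the LUT path requires):

import numpy as np

def leaky_relu_lut_int8(alpha, zero_point):
    # One table entry per possible int8 input value (-128..127).
    ix = np.arange(-128, 128)
    # At or above the zero point: identity. Below it: scale towards the
    # zero point by alpha, then round (mirrors the uint8/int8 branch of
    # convert_lrelu_to_lut in the patch).
    return np.where(ix >= zero_point, ix,
                    np.round(zero_point - alpha * (zero_point - ix)).astype(np.int64))

lut = leaky_relu_lut_int8(alpha=0.1, zero_point=0)
assert lut[128 + 100] == 100 and lut[128 - 100] == -10  # index = value + 128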
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 78c0dcd4..8d920d83 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -20,6 +20,7 @@ import math
import numpy as np
+from . import lut
from . import rewrite_graph
from .data_type import DataType
from .errors import UnsupportedFeatureError
@@ -585,6 +586,12 @@ def convert_mul_max_to_abs_or_lrelu(op, arch):
# make sure the Mul doesn't have a faf
if mul.attrs["fused_activation_function"]:
return op
+ ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
+ if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
+ return op
+ if not ifm.is_scaling_equal(ofm):
+ # rewrite to LeakyRelu currently only makes sense if the quantization is identical
+ return op
# finds the branched input that goes to both the Max and the Mul
shared = set(op.inputs) & set(mul.inputs)
@@ -599,6 +606,8 @@ def convert_mul_max_to_abs_or_lrelu(op, arch):
# check that it is a constant
if const.type != "Const":
return op
+ # Remove the Mul from the shared input's consumers
+ shared_in.consumer_list.remove(mul)
else:
return op
@@ -618,6 +627,147 @@ def convert_mul_max_to_abs_or_lrelu(op, arch):
return op
+def convert_lrelu_to_mul_max(op, arch):
+ # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
+ # (the opposite of convert_mul_max_to_abs_or_lrelu)
+ ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
+
+ # Add multiplication with alpha
+ mul_alpha = Operation("MulAct", op.name + "_mul_alpha")
+ mul_alpha.add_input_tensor(ifm)
+ # Create const tensor containing alpha as scalar
+ alpha = op.attrs["alpha"]
+ quantization = ifm.quantization.clone()
+ quantization.min = 0
+ quantization.max = alpha * (quantization.quant_max - quantization.quant_min)
+ quantization.scale_f32 = alpha
+ quantization.zero_point = 0
+ alpha_tens = create_const_tensor(op.name + "_alpha_scalar", [], ifm.dtype, [1], np.int8, quantization=quantization)
+ mul_alpha.add_input_tensor(alpha_tens)
+ fm_alpha = ofm.clone(op.name + "_alpha")
+ mul_alpha.set_output_tensor(fm_alpha)
+
+ if ifm.is_scaling_equal(ofm):
+ # No identity multiplication is needed
+ fm_id = ifm
+ else:
+ # Add multiplication with identity
+ mul_identity = Operation("MulAct", op.name + "_mul_identity")
+ mul_identity.add_input_tensor(ifm)
+ # Create const tensor containing identity as scalar
+ quantization = ifm.quantization.clone()
+ quantization.min = 0
+ quantization.max = quantization.quant_max - quantization.quant_min
+ quantization.scale_f32 = 1
+ quantization.zero_point = 0
+ identity_tens = create_const_tensor(
+ op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
+ )
+ mul_identity.add_input_tensor(identity_tens)
+ fm_id = ofm.clone(op.name + "_id")
+ mul_identity.set_output_tensor(fm_id)
+
+ # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
+ op.type = "Maximum"
+ op.name = op.name.replace("LeakyRelu", "Maximum")
+ op.inputs = []
+ ifm.consumer_list.remove(op)
+ op.add_input_tensor(fm_alpha)
+ op.add_input_tensor(fm_id)
+ return op
+
+
+def convert_lrelu_to_lut(op, arch):
+ ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
+ # Rewrite LeakyRelu by Add with scalar 0 + LUT activation
+ op.type = "AddAct"
+ op.name = op.name + "_add"
+ op.attrs.update({"npu_block_type": NpuBlockType.ElementWise})
+ # Mark as no-op to enable potential fusing optimizations
+ op.attrs["is_nop"] = True
+ # Create an input tensor containing scalar zero
+ quantization = QuantizationParameters(0.0, 255.0)
+ quantization.scale_f32 = 1.0
+ quantization.zero_point = 0
+ tens = create_const_tensor(op.inputs[0].name + "_add", [], ifm.dtype, [0], np.uint8, quantization=quantization)
+ op.add_input_tensor(tens)
+ alpha = op.attrs["alpha"]
+ zp = ofm.quantization.zero_point
+ # Generate the LUT
+ if ifm.dtype.size_in_bytes() == 1:
+ dtype = DataType.int8
+ ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
+ values = [int(x) if x >= zp else int(round(zp - alpha * (zp - x))) for x in ix]
+ else:
+ # int16
+ dtype = DataType.int32
+ values = []
+ for ix in range(512):
+ x = (ix - 256) * 128
+ if x >= zp:
+ base = x
+ slope = 128
+ else:
+ base = int(round(zp - alpha * (zp - x)))
+ next_base = int(round(zp - alpha * (zp - (x + 127))))
+ slope = int(round(128 * (next_base - base) / 127))
+ value = ((slope << 16) & 0xFFFF0000) + (base & 0xFFFF)
+ values.append(value)
+ lut_tensor = lut.create_lut_tensor(op.name + "_lut", values, dtype)
+ op.set_activation_lut(lut_tensor)
+ return op
+
+
+def convert_lrelu(op, arch):
+ # Converts LeakyRelu to a LUT based solution if possible, otherwise a mul + max
+ if op.type != "LeakyRelu":
+ return op
+ ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
+ use_lut = (ifm.is_scaling_equal(ofm)) and (ifm.dtype == ofm.dtype) and ifm.dtype in (DataType.uint8, DataType.int8)
+ if use_lut:
+ return convert_lrelu_to_lut(op, arch)
+ return convert_lrelu_to_mul_max(op, arch)
+
+
+def fuse_activation_function_with_prev(op, arch):
+ # if op is a no-op: attempts to move the activation function to the preceding op
+ if not op.attrs.get("is_nop", False) or op.attrs.get("fused_activation_function", None) is None:
+ return op
+ ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
+ # finds the input(s) to the operation
+ prev_op = ifm.ops[0]
+ # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
+ fuse = (
+ prev_op.run_on_npu
+ and prev_op.attrs["npu_block_type"] != NpuBlockType.Default
+ and len(ifm.ops) == 1
+ and len(prev_op.outputs[0].consumers()) == 1
+ and prev_op.attrs.get("fused_activation_function", None) is None
+ and ifm.is_scaling_equal(ofm)
+ )
+ if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
+ # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
+ # LUT currently only works correctly for elementwise ops
+ fuse = False
+ if fuse and op.activation_lut is not None:
+ # Check if LUT can be used with prev_op
+ prev_ifm, prev_ifm2, _, _ = prev_op.get_ifm_ifm2_weights_ofm()
+ fuse = prev_ifm is not None and prev_ifm.quantization is not None and prev_ifm.is_scaling_equal(ifm)
+ if prev_ifm2 is not None:
+ fuse = fuse and prev_ifm2.quantization is not None and prev_ifm2.is_scaling_equal(ifm)
+ if not fuse:
+ return op
+ # Move the fused activation function + corresponding info to prev_op
+ for attr in ("fused_activation_function", "alpha"):
+ if attr in op.attrs:
+ prev_op.attrs[attr] = op.attrs[attr]
+ if op.activation_lut is not None:
+ prev_op.set_activation_lut(op.activation_lut)
+ # Bypass op
+ prev_op.set_output_tensor(op.outputs[0])
+ return op
+
+
def add_attrs_to_resizebilinear(op, arch):
if op.type == "ResizeBilinear" and op.run_on_npu:
input_tensor = op.inputs[0]
@@ -679,7 +829,8 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
reorder_depthwise_weights,
fixup_resizebilinear,
add_bias_tensor,
- # convert_mul_max_to_abs_or_lrelu # TODO: enable optimisation once quantisation issues are resolved
+ convert_mul_max_to_abs_or_lrelu,
+ convert_lrelu,
]
for idx, sg in enumerate(nng.subgraphs):
@@ -689,8 +840,10 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
)
for idx, sg in enumerate(nng.subgraphs):
- # remove passthrough tensors
- nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(sg, arch, [remove_passthrough_tensor], [])
+ # remove passthrough tensors and attempt further optimizations
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev]
+ )
if verbose_graph:
nng.print_graph()
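For int16 input, convert_lrelu_to_lut above does not store one entry per input value; each of the 512 entries covers a 128-wide segment of the int16 range and packs a 16-bit base and a 16-bit slope into one 32-bit word. A minimal sketch of just that packing (not part of the patch; how the hardware interpolates within a segment is outside its scope):

def pack_lut_entry(base, slope):
    # Low half-word: base value at the start of the segment.
    # High half-word: slope applied across the 128-wide segment.
    return ((slope << 16) & 0xFFFF0000) + (base & 0xFFFF)

def unpack_lut_entry(value):
    # Recover both half-words as signed 16-bit numbers.
    base = value & 0xFFFF
    slope = (value >> 16) & 0xFFFF
    if base >= 0x8000:
        base -= 0x10000
    if slope >= 0x8000:
        slope -= 0x10000
    return base, slope

assert unpack_lut_entry(pack_lut_entry(-1000, 128)) == (-1000, 128)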
diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py
index 39101fac..0e8dcc95 100644
--- a/ethosu/vela/lut.py
+++ b/ethosu/vela/lut.py
@@ -18,8 +18,11 @@
import uuid
from functools import lru_cache
+import numpy as np
+
from . import numeric_util
from .high_level_command_stream import CommandType
+from .tensor import create_const_tensor
from .tensor import TensorPurpose
@@ -85,6 +88,19 @@ def get_lut_index(arch, lut_tensor):
return slot
+def create_lut_tensor(name, values, dtype):
+ # Creates constant LUT tensor with the given values as lookup table.
+ # The tensor's equivalence_id is based on these values, so if multiple
+ # LUT tensors are created with identical values, they will get the same
+ # address in constant memory, and unnecessary DMA operations can be avoided.
+ sz = len(values)
+ assert sz in (256, 512)
+ ntype = np.uint8 if dtype.size_in_bytes() == 1 else np.uint32
+ tens = create_const_tensor(name, [1, 1, 1, sz], dtype, values, ntype, TensorPurpose.LUT)
+ tens.equivalence_id = create_equivalence_id(tuple(values))
+ return tens
+
+
def optimize_high_level_cmd_stream(sg, arch):
# - Allocates SHRAM address/lut index to LUT tensors
# - Removes unnecessary DMA operations of LUT-s that are already present in SHRAM from sg's command stream
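The create_lut_tensor helper above relies on create_equivalence_id returning the same id for the same tuple of values. A hedged sketch of one way such a helper can work (an assumption for illustration, not code shown in this diff; Vela's actual implementation may differ): memoise a random id on the value tuple, so identical LUT contents share one id and therefore one constant-memory address.

import uuid
from functools import lru_cache

@lru_cache(maxsize=None)
def create_equivalence_id(key):
    # Same (hashable) key -> same cached id on every call.
    return uuid.uuid4()

a = create_equivalence_id((1, 2, 3))
b = create_equivalence_id((1, 2, 3))
assert a == b  # identical LUT values -> one id -> no duplicate DMA of the table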
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index 40ce467b..03ab83fe 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -284,7 +284,7 @@ def mark_tensor_purpose(nng, arch, verbose_tensor_purpose=False):
)
for idx, tens in enumerate(op.inputs):
- purpose = input_purpose(op, idx)
+ purpose = input_purpose(op, idx) if tens.purpose == TensorPurpose.Unknown else tens.purpose
mark_tensor_helper(tens, purpose)
if op.type == "Reshape":
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 5fdea979..f0e7ea44 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -728,6 +728,9 @@ class Tensor:
return True
return False
+ def is_scaling_equal(self, tens):
+ return self.quantization.is_scaling_equal(tens.quantization)
+
def equivalent(self, tens):
return self.equivalence_id == tens.equivalence_id