diff options
author | Louis Verhaard <louis.verhaard@arm.com> | 2020-09-01 10:39:04 +0200 |
---|---|---|
committer | Louis Verhaard <louis.verhaard@arm.com> | 2020-09-08 09:02:49 +0200 |
commit | 98a3499ec73b26880c633caf9a43bfe80f9ec1ed (patch) | |
tree | 2a098625e57dcd75e7aafd1ee340f971c62ffed7 | |
parent | 515c956c9cc6d45493e45d57b822e30a7317d1ed (diff) | |
download | ethos-u-vela-98a3499ec73b26880c633caf9a43bfe80f9ec1ed.tar.gz |
MLBEDSW-2935: LUT fusing with preceding operator
Allows fusing of LUT with a preceding operator regardless of
input/output scale.
Change-Id: Ia378adbb3fe61d71299feb085f7313377e0efa39
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
-rw-r--r-- | ethosu/vela/graph_optimiser.py | 11 |
-rw-r--r-- | ethosu/vela/operation.py | 7 |
-rw-r--r-- | ethosu/vela/register_command_stream_generator.py | 6 |
-rw-r--r-- | ethosu/vela/weight_compressor.py | 4 |
4 files changed, 16 insertions, 12 deletions
```diff
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index a89f8e63..b9110b8b 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -823,28 +823,21 @@ def fuse_activation_function_with_prev(op, arch):
         and len(ifm.ops) == 1
         and len(prev_op.outputs[0].consumers()) == 1
         and prev_op.attrs.get("fused_activation_function", None) is None
-        and ifm.is_scaling_equal(ofm)
     )
     if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
         # TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
         # LUT currently only works correctly for elementwise ops
         fuse = False
-    if fuse and op.activation_lut is not None:
-        # Check if LUT can be used with prev_op
-        prev_ifm, prev_ifm2, _, _ = prev_op.get_ifm_ifm2_weights_ofm()
-        fuse = prev_ifm is not None and prev_ifm.quantization is not None and prev_ifm.is_scaling_equal(ifm)
-        if prev_ifm2 is not None:
-            fuse = fuse and prev_ifm2.quantization is not None and prev_ifm2.is_scaling_equal(ifm)
     if not fuse:
         return op
     # Move the fused activation function + corresponding info to prev_op
-    for attr in ("fused_activation_function", "alpha", "forced_output_quantization"):
+    for attr in ("fused_activation_function", "forced_output_quantization"):
         if attr in op.attrs:
             prev_op.attrs[attr] = op.attrs[attr]
     if op.activation_lut is not None:
         prev_op.set_activation_lut(op.activation_lut)
     # Bypass op
-    prev_op.set_output_tensor(op.outputs[0])
+    prev_op.set_output_tensor(ofm)
     return op
 
 
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 4b83b39b..e7fd97c4 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -200,6 +200,10 @@ input and output tensors, as well as an attribute dictionary."""
 
         return ifm_tensor, ifm2_tensor, weight_tensor, bias_tensor, ofm_tensor
 
+    def get_ofm(self):
+        _, _, _, ofm = self.get_ifm_ifm2_weights_ofm()
+        return ofm
+
     def is_concat_op(self):
         return self.type in ("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped")
 
@@ -361,3 +365,6 @@ input and output tensors, as well as an attribute dictionary."""
             "Conv2DBackpropInputSwitchedBias",
             "FullyConnectedAct",
         )
+
+    def get_output_quantization(self):
+        return self.attrs.get("forced_output_quantization", self.get_ofm().quantization)
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 0a356475..8f34e639 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -909,7 +909,11 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
                     if tens is None:
                         continue
 
-                    need_zero_point = (faf is not None) or (fmf == "ConcatSliceWrite") or fused_quantize
+                    need_zero_point = (
+                        (faf is not None and forced_ofm_quantization is None)
+                        or (fmf == "ConcatSliceWrite")
+                        or fused_quantize
+                    )
                     if (
                         (
                             primary_op.type in set(("AvgPool", "AvgPoolAct", "ResizeBilinear", "CLZ", "SHL"))
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 175646b8..2374cd42 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -416,13 +416,13 @@ def calc_scales_and_pack_biases(tens, arch, ofm_depth_step, rescale_for_faf=Fals
     first_consumer_op = tens.consumer_list[0]
     ifm_dtype = first_consumer_op.inputs[0].dtype
     ifm_scale = first_consumer_op.inputs[0].quantization.scale_f32
-    ofm_scale = first_consumer_op.outputs[0].quantization.scale_f32
+    ofm_scale = first_consumer_op.get_output_quantization().scale_f32
     weight_scales = first_consumer_op.inputs[1].quantization.scale_f32
 
     # biases can have multiple consumers for rnn cells. if so, then check that they are all the same
     for op in tens.consumer_list[1:]:
         assert ifm_scale == op.inputs[0].quantization.scale_f32
-        assert ofm_scale == op.outputs[0].quantization.scale_f32
+        assert ofm_scale == op.get_output_quantization().scale_f32
         assert weight_scales == op.inputs[1].quantization.scale_f32
 
     if not hasattr(weight_scales, "__iter__"):
```