-rw-r--r--  ethosu/vela/graph_optimiser.py                     | 11
-rw-r--r--  ethosu/vela/operation.py                           |  7
-rw-r--r--  ethosu/vela/register_command_stream_generator.py   |  6
-rw-r--r--  ethosu/vela/weight_compressor.py                   |  4
4 files changed, 16 insertions(+), 12 deletions(-)
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index a89f8e63..b9110b8b 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -823,28 +823,21 @@ def fuse_activation_function_with_prev(op, arch):
and len(ifm.ops) == 1
and len(prev_op.outputs[0].consumers()) == 1
and prev_op.attrs.get("fused_activation_function", None) is None
- and ifm.is_scaling_equal(ofm)
)
if op.activation_lut is not None and arch.shram_reserved_unused_banks == 0:
# TODO: if SHRAM LUT space is shared with SHRAM ACC (32, 64 MAC),
# LUT currently only works correctly for elementwise ops
fuse = False
- if fuse and op.activation_lut is not None:
- # Check if LUT can be used with prev_op
- prev_ifm, prev_ifm2, _, _ = prev_op.get_ifm_ifm2_weights_ofm()
- fuse = prev_ifm is not None and prev_ifm.quantization is not None and prev_ifm.is_scaling_equal(ifm)
- if prev_ifm2 is not None:
- fuse = fuse and prev_ifm2.quantization is not None and prev_ifm2.is_scaling_equal(ifm)
if not fuse:
return op
# Move the fused activation function + corresponding info to prev_op
- for attr in ("fused_activation_function", "alpha", "forced_output_quantization"):
+ for attr in ("fused_activation_function", "forced_output_quantization"):
if attr in op.attrs:
prev_op.attrs[attr] = op.attrs[attr]
if op.activation_lut is not None:
prev_op.set_activation_lut(op.activation_lut)
# Bypass op
- prev_op.set_output_tensor(op.outputs[0])
+ prev_op.set_output_tensor(ofm)
return op
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 4b83b39b..e7fd97c4 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -200,6 +200,10 @@ input and output tensors, as well as an attribute dictionary."""
return ifm_tensor, ifm2_tensor, weight_tensor, bias_tensor, ofm_tensor
+ def get_ofm(self):
+ _, _, _, ofm = self.get_ifm_ifm2_weights_ofm()
+ return ofm
+
def is_concat_op(self):
return self.type in ("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped")
@@ -361,3 +365,6 @@ input and output tensors, as well as an attribute dictionary."""
"Conv2DBackpropInputSwitchedBias",
"FullyConnectedAct",
)
+
+ def get_output_quantization(self):
+ return self.attrs.get("forced_output_quantization", self.get_ofm().quantization)
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 0a356475..8f34e639 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -909,7 +909,11 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
if tens is None:
continue
- need_zero_point = (faf is not None) or (fmf == "ConcatSliceWrite") or fused_quantize
+ need_zero_point = (
+ (faf is not None and forced_ofm_quantization is None)
+ or (fmf == "ConcatSliceWrite")
+ or fused_quantize
+ )
if (
(
primary_op.type in set(("AvgPool", "AvgPoolAct", "ResizeBilinear", "CLZ", "SHL"))
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 175646b8..2374cd42 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -416,13 +416,13 @@ def calc_scales_and_pack_biases(tens, arch, ofm_depth_step, rescale_for_faf=Fals
first_consumer_op = tens.consumer_list[0]
ifm_dtype = first_consumer_op.inputs[0].dtype
ifm_scale = first_consumer_op.inputs[0].quantization.scale_f32
- ofm_scale = first_consumer_op.outputs[0].quantization.scale_f32
+ ofm_scale = first_consumer_op.get_output_quantization().scale_f32
weight_scales = first_consumer_op.inputs[1].quantization.scale_f32
# biases can have multiple consumers for rnn cells. if so, then check that they are all the same
for op in tens.consumer_list[1:]:
assert ifm_scale == op.inputs[0].quantization.scale_f32
- assert ofm_scale == op.outputs[0].quantization.scale_f32
+ assert ofm_scale == op.get_output_quantization().scale_f32
assert weight_scales == op.inputs[1].quantization.scale_f32
if not hasattr(weight_scales, "__iter__"):
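Note: for context, a minimal standalone sketch (mock classes, not the real vela Operation) of the lookup order the new get_output_quantization() helper introduces in operation.py and that weight_compressor.py now relies on: the forced_output_quantization attribute (set when an activation is fused into the op) takes precedence, otherwise the OFM tensor's own quantization is used.

    class MockQuant:
        def __init__(self, scale_f32):
            self.scale_f32 = scale_f32

    class MockTensor:
        def __init__(self, quantization):
            self.quantization = quantization

    class MockOp:
        def __init__(self, ofm_quant, attrs=None):
            self.attrs = attrs or {}
            self._ofm = MockTensor(ofm_quant)

        def get_ofm(self):
            return self._ofm

        def get_output_quantization(self):
            # Prefer the quantization forced by a fused activation, if any;
            # otherwise fall back to the OFM tensor's own quantization.
            return self.attrs.get("forced_output_quantization", self.get_ofm().quantization)

    op = MockOp(MockQuant(0.5))
    assert op.get_output_quantization().scale_f32 == 0.5

    op.attrs["forced_output_quantization"] = MockQuant(1.0)
    assert op.get_output_quantization().scale_f32 == 1.0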