aboutsummaryrefslogtreecommitdiff
path: root/ethosu
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu')
-rw-r--r--ethosu/vela/tflite_reader.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 18b61e75..061f3626 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -209,7 +209,12 @@ class TFLiteSubgraph:
# The original value is cached above in channel_multiplier
op.attrs["depth_multiplier"] = op.weights.shape[2] // op.ifm.shape[-1]
- faf = op.attrs.pop("fused_activation_function", None)
+ # The fused_activation_function attribute needs to be retained so that the
+ # tflite_writer can correctly pass through operators that run on the CPU.
+ # This is because the operator activation attribute is later converted to an
+ # NpuActivation which treats None and ReLU the same, thereby making it difficult
+ # for the tflite_writer to recover the original activation function.
+ faf = op.attrs.get("fused_activation_function", None)
if faf is not None:
op.activation = create_activation_function(faf)
if custom_code is not None: