aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_writer.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r--ethosu/vela/tflite_writer.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 44ce711..d4e24a2 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -105,9 +105,11 @@ class TFLiteSerialiser:
if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type == Op.FullyConnected:
# Op is run on CPU, make sure the original weight and bias tensors are written back
# instead of the cloned/reshaped (see tflite_reader)
- for idx, inp in enumerate(op.inputs):
- if inp != op.ifm and inp is not None and inp.src_tensor is not None:
- op.inputs[idx] = inp.src_tensor
+ # Do nothing when values are None (dynamic weights)
+ if op.inputs[1].values is not None:
+ for idx, inp in enumerate(op.inputs):
+ if inp != op.ifm and inp is not None and inp.src_tensor is not None:
+ op.inputs[idx] = inp.src_tensor
# list of tuple(Op, string, op.version); the custom code is only used for 3rd party custom operators
self.operator_codes = sorted(set((op.type, op.attrs.get("custom_code", ""), op.version) for op in all_ops))