aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/tensor.py1
-rw-r--r--ethosu/vela/tflite_writer.py21
2 files changed, 7 insertions, 15 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 9ba6ab77..6ba331c4 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -506,6 +506,7 @@ class Tensor:
res.name = res.name + suffix
res.ops = []
res.consumer_list = []
+ res.src_tensor = self
return res
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 625cf7cc..c8250c6e 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -90,8 +90,6 @@ class TFLiteSerialiser:
self.ops_to_ignore = (Op.Const, Op.Placeholder, Op.SubgraphInput)
- self.tensors_to_reshape = {}
-
self.subgraphs_to_write = [sg for sg in self.nng.subgraphs if sg.placement == PassPlacement.Cpu]
all_ops = []
@@ -102,14 +100,12 @@ class TFLiteSerialiser:
# swap from nng input indexing to TensorFlow Lite input indexing
self.align_nng_inputs_to_tflite(op)
all_ops.append(op)
- if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
- # If values are None op has non-constant weights
- if op.inputs[1].values is not None:
- self.tensors_to_reshape[op.inputs[1]] = (3, 0, 1, 2)
- if op.type == Op.FullyConnected:
- # If values are None op has non-constant weights
- if op.inputs[1].values is not None:
- self.tensors_to_reshape[op.inputs[1]] = (1, 0)
+ if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type == Op.FullyConnected:
+ # Op is run on CPU, make sure original tensor are written back
+ # instead of the cloned/reshaped (see tflite_reader)
+ for idx, inp in enumerate(op.inputs):
+ if inp is not None and inp.src_tensor is not None:
+ op.inputs[idx] = inp.src_tensor
# list of tuple(Op, string); the custom code is only used for 3rd party custom operators
self.operator_codes = sorted(set((op.type, op.attrs.get("custom_code", "")) for op in all_ops))
@@ -259,11 +255,6 @@ class TFLiteSerialiser:
tens_shape = tens.original_shape
values = tens.values
- if tens in self.tensors_to_reshape:
- reorder = self.tensors_to_reshape[tens]
- tens_shape = [tens_shape[idx] for idx in reorder]
- values = values.transpose(reorder)
-
buf_id = self.buffer_map[tens]
self.buffers_to_write[buf_id] = None if values is None else values.flatten().view(np.uint8)