diff options
-rw-r--r-- | ethosu/vela/tensor.py | 1 | ||||
-rw-r--r-- | ethosu/vela/tflite_writer.py | 21 |
2 files changed, 7 insertions, 15 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index 9ba6ab77..6ba331c4 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -506,6 +506,7 @@ class Tensor: res.name = res.name + suffix res.ops = [] res.consumer_list = [] + res.src_tensor = self return res diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index 625cf7cc..c8250c6e 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -90,8 +90,6 @@ class TFLiteSerialiser: self.ops_to_ignore = (Op.Const, Op.Placeholder, Op.SubgraphInput) - self.tensors_to_reshape = {} - self.subgraphs_to_write = [sg for sg in self.nng.subgraphs if sg.placement == PassPlacement.Cpu] all_ops = [] @@ -102,14 +100,12 @@ class TFLiteSerialiser: # swap from nng input indexing to TensorFlow Lite input indexing self.align_nng_inputs_to_tflite(op) all_ops.append(op) - if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op(): - # If values are None op has non-constant weights - if op.inputs[1].values is not None: - self.tensors_to_reshape[op.inputs[1]] = (3, 0, 1, 2) - if op.type == Op.FullyConnected: - # If values are None op has non-constant weights - if op.inputs[1].values is not None: - self.tensors_to_reshape[op.inputs[1]] = (1, 0) + if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type == Op.FullyConnected: + # Op is run on CPU, make sure original tensor are written back + # instead of the cloned/reshaped (see tflite_reader) + for idx, inp in enumerate(op.inputs): + if inp is not None and inp.src_tensor is not None: + op.inputs[idx] = inp.src_tensor # list of tuple(Op, string); the custom code is only used for 3rd party custom operators self.operator_codes = sorted(set((op.type, op.attrs.get("custom_code", "")) for op in all_ops)) @@ -259,11 +255,6 @@ class TFLiteSerialiser: tens_shape = tens.original_shape values = tens.values - if tens in self.tensors_to_reshape: - reorder = self.tensors_to_reshape[tens] - tens_shape = [tens_shape[idx] for idx in reorder] - values = values.transpose(reorder) - buf_id = self.buffer_map[tens] self.buffers_to_write[buf_id] = None if values is None else values.flatten().view(np.uint8) |