diff options
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r-- | ethosu/vela/tflite_writer.py | 18 |
1 files changed, 8 insertions, 10 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index f55d1ce5..1f072424 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -18,7 +18,13 @@ # Description: # Functions used to write to a TensorFlow Lite format file. Supports adding in file identifiers. +import numpy as np import flatbuffers +from flatbuffers.builder import UOffsetTFlags + +# ugh, the python flatbuffer interface is missing a method to add in file identifier. patching it in here: +import flatbuffers.number_types as N +from flatbuffers import encode from .tflite import Tensor from .tflite import QuantizationParameters @@ -28,22 +34,14 @@ from .tflite import OperatorCode from .tflite import Operator from .tflite import Buffer from .tflite import Metadata - -import numpy as np - from .tflite_mapping import datatype_inv_map, builtin_operator_inv_map, custom_prefix, BuiltinOperator from .nn_graph import PassPlacement from .tensor import TensorPurpose, MemArea -from flatbuffers.builder import UOffsetTFlags tflite_version = 3 tflite_file_identifier = "TFL" + str(tflite_version) -import flatbuffers.number_types as N -from flatbuffers import encode - - def FinishWithFileIdentifier(self, rootTable, fid): if fid is None or len(fid) != 4: raise Exception("fid must be 4 chars") @@ -163,8 +161,8 @@ class TFLiteSerialiser: tf_code, opt_serializer = builtin_operator_inv_map[code] except KeyError: print( - "Warning: Writing operation %s, which does not have a direct TensorFlow Lite mapping, as a custom operation" - % (code,) + "Warning: Writing operation %s, which does not have a direct TensorFlow Lite mapping," + "as a custom operation" % (code,) ) tf_code, opt_serializer = builtin_operator_inv_map[custom_prefix] |