diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-08-10 13:56:34 +0200 |
---|---|---|
committer | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-08-11 12:27:37 +0200 |
commit | d15866c06c88b9ec2e6313cc19f89ea65b528f8a (patch) | |
tree | 26dd354d0e49038ae0314445b60cece6e051701b /ethosu | |
parent | ebb3b6fd78c54d6cf95b673aad3d868b798f91c7 (diff) | |
download | ethos-u-vela-d15866c06c88b9ec2e6313cc19f89ea65b528f8a.tar.gz |
MLBEDSW-4838 TOSA const data input changes
Adoptions related to changes for constant data
in TOSA.
Constant data not longer stored in .npy files, but
within the .tosa-file.
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Ia1148c2f8b783b3926a1ee0b9ad0a3aeff9d22f5
Diffstat (limited to 'ethosu')
-rw-r--r-- | ethosu/vela/tosa/TosaTensor.py | 24 | ||||
-rw-r--r-- | ethosu/vela/tosa_mapping.py | 11 | ||||
-rw-r--r-- | ethosu/vela/tosa_reader.py | 31 |
3 files changed, 48 insertions, 18 deletions
diff --git a/ethosu/vela/tosa/TosaTensor.py b/ethosu/vela/tosa/TosaTensor.py index 2a397db2..01ae7ece 100644 --- a/ethosu/vela/tosa/TosaTensor.py +++ b/ethosu/vela/tosa/TosaTensor.py @@ -55,16 +55,32 @@ class TosaTensor(object): return 0 # TosaTensor - def NpyFilename(self): + def Data(self, j): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) if o != 0: - return self._tab.String(o + self._tab.Pos) - return None + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Uint8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) + return 0 + + # TosaTensor + def DataAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint8Flags, o) + return 0 + + # TosaTensor + def DataLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.VectorLen(o) + return 0 def TosaTensorStart(builder): builder.StartObject(4) def TosaTensorAddName(builder, name): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) def TosaTensorAddShape(builder, shape): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0) def TosaTensorStartShapeVector(builder, numElems): return builder.StartVector(4, numElems, 4) def TosaTensorAddType(builder, type): builder.PrependUint32Slot(2, type, 0) -def TosaTensorAddNpyFilename(builder, npyFilename): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(npyFilename), 0) +def TosaTensorAddData(builder, data): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(data), 0) +def TosaTensorStartDataVector(builder, numElems): return builder.StartVector(1, numElems, 1) def TosaTensorEnd(builder): return builder.EndObject() diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py index 312ac92e..75ca43ef 100644 --- a/ethosu/vela/tosa_mapping.py +++ b/ethosu/vela/tosa_mapping.py @@ -17,6 +17,8 @@ # TOSA mapping functions used by reader. # Contains a mapping from the various TOSA enums and options structs, generated by the FlatBuffer code # generator, to Vela's internal format. +import numpy as np + from .data_type import DataType from .operation import Op from .operation import TensorIndices @@ -54,6 +56,15 @@ datatype_map = { DType.FLOAT: DataType.float32, } +datatype_map_numpy = { + DType.BOOL: np.bool, + DType.UINT8: np.uint8, + DType.INT8: np.int8, + DType.INT16: np.int16, + DType.INT32: np.int32, + DType.FLOAT: np.float32, +} + # TODO duplicate of tflite_mapping def underscore_to_camel_case(s): diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py index 364d9a63..dfed035d 100644 --- a/ethosu/vela/tosa_reader.py +++ b/ethosu/vela/tosa_reader.py @@ -34,18 +34,19 @@ from .tensor import Tensor from .tflite_mapping import DataType from .tosa.TosaGraph import TosaGraph as TG from .tosa_mapping import datatype_map +from .tosa_mapping import datatype_map_numpy from .tosa_mapping import tosa_operator_map from .tosa_mapping import unsupported_tosa_operators class TosaSubgraph: - def __init__(self, file_path, graph, block): + def __init__(self, graph, block): self.graph = graph self.name = decode_str(block.Name()) self.tensors = [] for idx in range(block.TensorsLength()): - self.tensors.append(self.parse_tensor(block.Tensors(idx), file_path)) + self.tensors.append(self.parse_tensor(block.Tensors(idx))) for idx in range(block.OperatorsLength()): self.parse_operator(idx, block.Operators(idx)) @@ -166,7 +167,7 @@ class TosaSubgraph: if "b_zp" in quant_info: self.set_tensor_zp(op.ifm2, quant_info["b_zp"]) - def parse_tensor(self, tens_data, file_path): + def parse_tensor(self, tens_data): name = decode_str(tens_data.Name()) np_shape = tens_data.ShapeAsNumpy() shape = list(np_shape) if type(np_shape) is np.ndarray else [] @@ -182,19 +183,22 @@ class TosaSubgraph: if dtype == DataType.uint8: tens.quantization.quant_min = 0 tens.quantization.quant_max = (1 << dtype.bits) - 1 - elif dtype in (DataType.int8, DataType.int16, DataType.int32, DataType.int64): + elif dtype in (DataType.int8, DataType.int16, DataType.int32, DataType.int48): tens.quantization.quant_min = -(1 << (dtype.bits - 1)) tens.quantization.quant_max = (1 << (dtype.bits - 1)) - 1 tens.values = None - if tens_data.NpyFilename() is not None: - try: - fname = decode_str(tens_data.NpyFilename()) - tens.values = np.load(os.path.join(file_path, fname)) - assert list(tens.values.shape) == tens.shape - except (struct.error, TypeError, RuntimeError) as e: - print(f'Error: Invalid npy file. Got "{e}" ') - sys.exit(1) + + data_length = tens_data.DataLength() + if data_length != 0: + data_as_numpy = tens_data.DataAsNumpy() + if tens_dtype in datatype_map_numpy: + np_dtype = datatype_map_numpy[tens_dtype] + tens.values = np.array(data_as_numpy.view(np_dtype).reshape(shape)) + else: + # int48 is only expected as an accumulated data/output format, int4 not supported + print(f"Error: unsupported/unexpected Tensor type {dtype}, with data") + assert False return tens @@ -227,11 +231,10 @@ class TosaGraph: self.check_version(tosa_graph) parsing_step = "parsing blocks length" - file_path = os.path.dirname(filename) self.subgraphs = [] for b_idx in range(tosa_graph.BlocksLength()): parsing_step = f"parsing block {b_idx}" - self.subgraphs.append(TosaSubgraph(file_path, self, tosa_graph.Blocks(b_idx))) + self.subgraphs.append(TosaSubgraph(self, tosa_graph.Blocks(b_idx))) self.nng = Graph(self.name, self.batch_size) for tosa_sg in self.subgraphs: |