diff options
Diffstat (limited to 'ethosu/vela/tosa_reader.py')
-rw-r--r-- | ethosu/vela/tosa_reader.py | 26 |
1 files changed, 24 insertions, 2 deletions
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py index dfed035d..eb317169 100644 --- a/ethosu/vela/tosa_reader.py +++ b/ethosu/vela/tosa_reader.py @@ -30,6 +30,7 @@ from .reader_util import clone_and_reshape_tensor from .reader_util import decode_str from .reader_util import fixup_tensors from .tensor import QuantizationParameters +from .tensor import shape_num_elements from .tensor import Tensor from .tflite_mapping import DataType from .tosa.TosaGraph import TosaGraph as TG @@ -135,6 +136,22 @@ class TosaSubgraph: if attr_serializer is not None: op.attrs = attr_serializer.deserialize(op_data) + if "padding" in op.attrs: + padding = op.attrs["padding"] # [top, bottom, left, right] + op.attrs["explicit_padding"] = ( + padding[0], + padding[2], + padding[1], + padding[3], + ) # [top, left, bottom, right] + if "stride" in op.attrs: + stride = op.attrs["stride"] + if len(stride) == 2: + op.attrs["strides"] = (1, stride[0], stride[1], 1) + else: + # TODO CONV3D more to be done.... + print("Unsupported kernel dimensions: ", len(stride)) + assert False if "dilation" in op.attrs: dilation = op.attrs["dilation"] if len(dilation) == 2: @@ -160,7 +177,7 @@ class TosaSubgraph: self.set_tensor_zp(op.ifm, quant_info["input_zp"]) if "weight_zp" in quant_info: self.set_tensor_zp(op.weights, quant_info["weight_zp"]) - if "ouput_zp" in quant_info: + if "output_zp" in quant_info: self.set_tensor_zp(op.ofm, quant_info["output_zp"]) if "a_zp" in quant_info: self.set_tensor_zp(op.ifm, quant_info["a_zp"]) @@ -194,7 +211,12 @@ class TosaSubgraph: data_as_numpy = tens_data.DataAsNumpy() if tens_dtype in datatype_map_numpy: np_dtype = datatype_map_numpy[tens_dtype] - tens.values = np.array(data_as_numpy.view(np_dtype).reshape(shape)) + + # TOSA pads the tensor data + shape_elements = shape_num_elements(shape) + values = np.array(data_as_numpy.view(np_dtype)) + values = values[0:shape_elements] + tens.values = values.reshape(shape) else: # int48 is only expected as an accumulated data/output format, int4 not supported print(f"Error: unsupported/unexpected Tensor type {dtype}, with data") |