aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_reader.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tosa_reader.py')
-rw-r--r--ethosu/vela/tosa_reader.py26
1 files changed, 24 insertions, 2 deletions
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index dfed035d..eb317169 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -30,6 +30,7 @@ from .reader_util import clone_and_reshape_tensor
from .reader_util import decode_str
from .reader_util import fixup_tensors
from .tensor import QuantizationParameters
+from .tensor import shape_num_elements
from .tensor import Tensor
from .tflite_mapping import DataType
from .tosa.TosaGraph import TosaGraph as TG
@@ -135,6 +136,22 @@ class TosaSubgraph:
if attr_serializer is not None:
op.attrs = attr_serializer.deserialize(op_data)
+ if "padding" in op.attrs:
+ padding = op.attrs["padding"] # [top, bottom, left, right]
+ op.attrs["explicit_padding"] = (
+ padding[0],
+ padding[2],
+ padding[1],
+ padding[3],
+ ) # [top, left, bottom, right]
+ if "stride" in op.attrs:
+ stride = op.attrs["stride"]
+ if len(stride) == 2:
+ op.attrs["strides"] = (1, stride[0], stride[1], 1)
+ else:
+ # TODO CONV3D more to be done....
+ print("Unsupported kernel dimensions: ", len(stride))
+ assert False
if "dilation" in op.attrs:
dilation = op.attrs["dilation"]
if len(dilation) == 2:
@@ -160,7 +177,7 @@ class TosaSubgraph:
self.set_tensor_zp(op.ifm, quant_info["input_zp"])
if "weight_zp" in quant_info:
self.set_tensor_zp(op.weights, quant_info["weight_zp"])
- if "ouput_zp" in quant_info:
+ if "output_zp" in quant_info:
self.set_tensor_zp(op.ofm, quant_info["output_zp"])
if "a_zp" in quant_info:
self.set_tensor_zp(op.ifm, quant_info["a_zp"])
@@ -194,7 +211,12 @@ class TosaSubgraph:
data_as_numpy = tens_data.DataAsNumpy()
if tens_dtype in datatype_map_numpy:
np_dtype = datatype_map_numpy[tens_dtype]
- tens.values = np.array(data_as_numpy.view(np_dtype).reshape(shape))
+
+ # TOSA pads the tensor data
+ shape_elements = shape_num_elements(shape)
+ values = np.array(data_as_numpy.view(np_dtype))
+ values = values[0:shape_elements]
+ tens.values = values.reshape(shape)
else:
# int48 is only expected as an accumulated data/output format, int4 not supported
print(f"Error: unsupported/unexpected Tensor type {dtype}, with data")