From c92710d7259558fb0cd9e9b38d0c78da21c6e2d4 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Thu, 15 Sep 2022 12:16:07 +0100 Subject: Fix for CONST floats always truncated to integer values. Add way of saving CONSTs to numpy file for testing purposes. Signed-off-by: Jeremy Johnson Change-Id: Ic3a22da8ec3432037832374090b7ceff345d48de --- python/serializer/tosa_serializer.py | 20 ++++++++++++++++---- python/tosa/Version.py | 3 ++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index a71f3c9..acec4b7 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -362,12 +362,17 @@ class TosaSerializerTensor: self.shape = shape self.dtype = dtype + if dtype == DType.FLOAT: + fntype = np.float32 + else: + fntype = int + if isinstance(data, np.ndarray): - data = data.flatten().astype(int).tolist() - data = list(map(int, data)) + data = data.flatten().astype(fntype).tolist() + data = list(map(fntype, data)) self.data = data elif isinstance(data, list): - data = list(map(int, data)) + data = list(map(fntype, data)) self.data = data else: self.data = None @@ -569,7 +574,7 @@ class TensorDir(IntEnum): class TosaSerializer: - def __init__(self, pathPrefix): + def __init__(self, pathPrefix, saveConstsToFile=False): self.add_compat_methods() # Get the global TOSA version if not already defined @@ -579,6 +584,9 @@ class TosaSerializer: self.startBasicBlock("main") self.pathPrefix = pathPrefix + # Enables inspection of constant data outside of graph + self.saveConstsToFile = saveConstsToFile + # Indicies used for adding/naming tensors self.currInputIdx = 0 self.currConstIdx = 0 @@ -624,6 +632,10 @@ class TosaSerializer: # Add the operator now self.currBasicBlock.addOperator(TosaOp.Op().CONST, [], name) + if self.saveConstsToFile: + filename = "{}.npy".format(name) + np.save(os.path.join(self.pathPrefix, filename), vals, False) + return tens def addIntermediate(self, shape, dtype): diff --git a/python/tosa/Version.py b/python/tosa/Version.py index 0360695..1eaa351 100644 --- a/python/tosa/Version.py +++ b/python/tosa/Version.py @@ -59,6 +59,7 @@ class Version(object): def VersionStart(builder): builder.StartObject(4) def Start(builder): return VersionStart(builder) +def VersionAdd_major(builder, Major): builder.PrependInt32Slot(0, Major, 0) def Add_major(builder, Major): return VersionAdd_major(builder, Major) def VersionAdd_minor(builder, Minor): builder.PrependInt32Slot(1, Minor, 41) @@ -72,4 +73,4 @@ def Add_draft(builder, Draft): return VersionAdd_draft(builder, Draft) def VersionEnd(builder): return builder.EndObject() def End(builder): - return VersionEnd(builder) + return VersionEnd(builder) \ No newline at end of file -- cgit v1.2.1