aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2022-09-15 12:16:07 +0100
committerJeremy Johnson <jeremy.johnson@arm.com>2022-09-15 12:57:21 +0100
commitc92710d7259558fb0cd9e9b38d0c78da21c6e2d4 (patch)
tree657a3cf2921a8bd8cc9a360cc85eee4ff8aeeb67
parent4381b3d7fcb7cab975f46c62c86a35c53ade047f (diff)
downloadserialization_lib-c92710d7259558fb0cd9e9b38d0c78da21c6e2d4.tar.gz
Fix for CONST floats always truncated to integer values.
Add way of saving CONSTs to numpy file for testing purposes. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ic3a22da8ec3432037832374090b7ceff345d48de
-rw-r--r--python/serializer/tosa_serializer.py20
-rw-r--r--python/tosa/Version.py3
2 files changed, 18 insertions, 5 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index a71f3c9..acec4b7 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -362,12 +362,17 @@ class TosaSerializerTensor:
self.shape = shape
self.dtype = dtype
+ if dtype == DType.FLOAT:
+ fntype = np.float32
+ else:
+ fntype = int
+
if isinstance(data, np.ndarray):
- data = data.flatten().astype(int).tolist()
- data = list(map(int, data))
+ data = data.flatten().astype(fntype).tolist()
+ data = list(map(fntype, data))
self.data = data
elif isinstance(data, list):
- data = list(map(int, data))
+ data = list(map(fntype, data))
self.data = data
else:
self.data = None
@@ -569,7 +574,7 @@ class TensorDir(IntEnum):
class TosaSerializer:
- def __init__(self, pathPrefix):
+ def __init__(self, pathPrefix, saveConstsToFile=False):
self.add_compat_methods()
# Get the global TOSA version if not already defined
@@ -579,6 +584,9 @@ class TosaSerializer:
self.startBasicBlock("main")
self.pathPrefix = pathPrefix
+ # Enables inspection of constant data outside of graph
+ self.saveConstsToFile = saveConstsToFile
+
# Indicies used for adding/naming tensors
self.currInputIdx = 0
self.currConstIdx = 0
@@ -624,6 +632,10 @@ class TosaSerializer:
# Add the operator now
self.currBasicBlock.addOperator(TosaOp.Op().CONST, [], name)
+ if self.saveConstsToFile:
+ filename = "{}.npy".format(name)
+ np.save(os.path.join(self.pathPrefix, filename), vals, False)
+
return tens
def addIntermediate(self, shape, dtype):
diff --git a/python/tosa/Version.py b/python/tosa/Version.py
index 0360695..1eaa351 100644
--- a/python/tosa/Version.py
+++ b/python/tosa/Version.py
@@ -59,6 +59,7 @@ class Version(object):
def VersionStart(builder): builder.StartObject(4)
def Start(builder):
return VersionStart(builder)
+def VersionAdd_major(builder, Major): builder.PrependInt32Slot(0, Major, 0)
def Add_major(builder, Major):
return VersionAdd_major(builder, Major)
def VersionAdd_minor(builder, Minor): builder.PrependInt32Slot(1, Minor, 41)
@@ -72,4 +73,4 @@ def Add_draft(builder, Draft):
return VersionAdd_draft(builder, Draft)
def VersionEnd(builder): return builder.EndObject()
def End(builder):
- return VersionEnd(builder)
+ return VersionEnd(builder) \ No newline at end of file