aboutsummaryrefslogtreecommitdiff
path: root/verif/tosa_serializer.py
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-06-17 16:01:59 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-06-24 16:37:58 -0700
commit82507d77056dd5510547438ba2064c1ee8bebc2c (patch)
treeca5c4bd49029430b1153091f30680866e97ccd2f /verif/tosa_serializer.py
parent2d60f0063eb91f6514b20a1817663ce0ddd3ff4a (diff)
downloadreference_model-82507d77056dd5510547438ba2064c1ee8bebc2c.tar.gz
Update to use new serialization_lib API.
- Constant tensors are now initialized from embedded u8 array instead from numpy. - Python unit test generator and built-in test hasn't been updated. Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I5cb86f8e5ec8f23fee5dcbf257874a0f204ede04
Diffstat (limited to 'verif/tosa_serializer.py')
-rw-r--r--verif/tosa_serializer.py82
1 files changed, 66 insertions, 16 deletions
diff --git a/verif/tosa_serializer.py b/verif/tosa_serializer.py
index 3b03252..c4de2a2 100644
--- a/verif/tosa_serializer.py
+++ b/verif/tosa_serializer.py
@@ -19,6 +19,7 @@ import sys
import json
import flatbuffers
import numpy as np
+import struct
from enum import Enum, IntEnum, unique
from tosa import (
TosaGraph,
@@ -41,6 +42,7 @@ import tosa
# With the way flatc generates its python types, there is no programatic way
# to get string names for the integer types. Manually maintain a string table
# here.
+DType = tosa.DType.DType()
DTypeNames = [
"UNKNOWN",
"BOOL",
@@ -53,6 +55,7 @@ DTypeNames = [
"FLOAT",
]
+ByteMask = np.uint64(0xFF)
def dtype_str_to_val(name):
@@ -337,7 +340,7 @@ class TosaSerializerTensor:
name,
shape,
dtype,
- filename=None,
+ data=None,
placeholderFilename=None,
):
self.name = name
@@ -349,8 +352,12 @@ class TosaSerializerTensor:
self.shape = shape
self.dtype = dtype
- # Filename for const tensors. This gets written to the .tosa serialization
- self.filename = filename
+ if isinstance(data, np.ndarray):
+ data = data.flatten().astype(int).tolist()
+ data = list(map(int, data))
+ self.data = data
+ else:
+ self.data = None
# Filename for placeholder tensors. These get generated by the test generation
# process and are written to disk, but are considered input tensors by the network
@@ -359,11 +366,10 @@ class TosaSerializerTensor:
self.placeholderFilename = placeholderFilename
def __str__(self):
- str = "TosaSerializerTensor name: {} shape: {} dtype: {} filename: {}".format(
+ str = "TosaSerializerTensor name: {} shape: {} dtype: {}".format(
self.name,
self.shape,
DTypeNames[self.dtype],
- self.filename,
)
return str
@@ -372,16 +378,56 @@ class TosaSerializerTensor:
def serialize(self, builder):
fb_name = builder.CreateString(self.name)
- if self.filename:
- fb_filename = builder.CreateString(self.filename)
fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
+ if self.data:
+ u8_data = list()
+ # little endianess
+ if self.dtype == DType.BOOL:
+ for val in self.data:
+ val_u8 = np.uint8(val)
+ u8_data.append(val_u8)
+ elif self.dtype == DType.INT8:
+ for val in self.data:
+ val_u8 = np.uint8(val)
+ u8_data.append(val_u8)
+ elif self.dtype == DType.INT16:
+ for val in self.data:
+ val_u16 = np.uint16(val)
+ b0 = val_u16 & ByteMask
+ b1 = (val_u16 >> np.uint16(8)) & ByteMask
+ u8_data.extend([b0, b1])
+ elif self.dtype == DType.INT32:
+ for val in self.data:
+ val_u32 = np.uint32(val)
+ b0 = val_u32 & ByteMask
+ b1 = (val_u32 >> np.uint32(8)) & ByteMask
+ b2 = (val_u32 >> np.uint32(16)) & ByteMask
+ b3 = (val_u32 >> np.uint32(32)) & ByteMask
+ u8_data.extend([b0, b1, b2, b3])
+ elif self.dtype == DType.INT48:
+ for val in self.data:
+ val_u64 = np.uint64(val)
+ b0 = val_u64 & ByteMask
+ b1 = (val_u64 >> np.uint64(8)) & ByteMask
+ b2 = (val_u64 >> np.uint64(16)) & ByteMask
+ b3 = (val_u64 >> np.uint64(24)) & ByteMask
+ b4 = (val_u64 >> np.uint64(32)) & ByteMask
+ b5 = (val_u64 >> np.uint64(40)) & ByteMask
+ u8_data.extend([b0, b1, b2, b3, b4, b5])
+ elif self.dtype == DType.FLOAT:
+ for val in self.data:
+ b = struct.pack('!f', val)
+ u8_data.extend([b[3], b[2], b[1], b[0]])
+ else:
+ raise Exception("unsupported data type {}".format(DTypeNames[self.dtype]))
+ fb_data = TosaSerializer.serializeUint8Vec(builder, u8_data)
TosaTensor.TosaTensorStart(builder)
TosaTensor.TosaTensorAddName(builder, fb_name)
TosaTensor.TosaTensorAddShape(builder, fb_shapes)
TosaTensor.TosaTensorAddType(builder, self.dtype)
- if self.filename:
- TosaTensor.TosaTensorAddNpyFilename(builder, fb_filename)
+ if self.data:
+ TosaTensor.TosaTensorAddData(builder, fb_data)
return TosaTensor.TosaTensorEnd(builder)
@@ -448,7 +494,7 @@ class TosaSerializerBasicBlock:
name,
shape,
dtype,
- filename=None,
+ data=None,
placeholderFilename=None,
):
try:
@@ -456,7 +502,7 @@ class TosaSerializerBasicBlock:
tens = self.tensors[name]
except KeyError:
self.tensors[name] = TosaSerializerTensor(
- name, shape, dtype, filename, placeholderFilename
+ name, shape, dtype, data, placeholderFilename
)
return self.tensors[name]
@@ -562,12 +608,10 @@ class TosaSerializer:
filename = "{}.npy".format(name)
self.currInputIdx = self.currInputIdx + 1
- tens = self.currBasicBlock.addTensor(name, shape, dtype, filename)
+ tens = self.currBasicBlock.addTensor(name, shape, dtype, vals)
# Add the operator now
self.currBasicBlock.addOperator(tosa.Op.Op().CONST, [], name)
- if vals is not None:
- np.save(os.path.join(self.pathPrefix, filename), vals, False)
return tens
def addIntermediate(self, shape, dtype):
@@ -576,10 +620,9 @@ class TosaSerializer:
raise Exception("addTensor called without valid basic block")
name = "layer-{}".format(self.currLayerIdx)
- filename = None # No file, so no filename
self.currLayerIdx = self.currLayerIdx + 1
- tens = self.currBasicBlock.addTensor(name, shape, dtype, filename)
+ tens = self.currBasicBlock.addTensor(name, shape, dtype, None)
return tens
@@ -683,6 +726,13 @@ class TosaSerializer:
return builder.EndVector(len(fb_strs))
@staticmethod
+ def serializeUint8Vec(builder, vec):
+ builder.StartVector(1, len(vec), 8)
+ for v in vec[::-1]:
+ builder.PrependUint8(v)
+ return builder.EndVector(len(vec))
+
+ @staticmethod
def serializeInt32Vec(builder, vec):
builder.StartVector(4, len(vec), 4)
for v in vec[::-1]: