aboutsummaryrefslogtreecommitdiff
path: root/python/serializer/tosa_serializer.py
diff options
context:
space:
mode:
authorJames Ward <james.ward@arm.com>2022-12-07 15:38:01 +0000
committerEric Kunze <eric.kunze@arm.com>2023-01-18 01:04:54 +0000
commitc15f7d52aa4f360eba2344449baa418b7608ac7c (patch)
treeb0322cb02004e9e0a325c847c0bd332051b8389b /python/serializer/tosa_serializer.py
parent5e268097917825ddaa00a86ee95a4a6c4f50124b (diff)
downloadserialization_lib-c15f7d52aa4f360eba2344449baa418b7608ac7c.tar.gz
Schema changes for CLAMP, PAD float attributes
* Float attributes now serialized as uint8 vectors, but treated as floats at input/output to serialization Signed-off-by: James Ward <james.ward@arm.com> Change-Id: I417b0fabe0ef11fea263fe937b57d49bbfdb00da
Diffstat (limited to 'python/serializer/tosa_serializer.py')
-rw-r--r--python/serializer/tosa_serializer.py38
1 files changed, 29 insertions, 9 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index e8311ce..f579df2 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -13,10 +13,11 @@
# limitations under the License.
import os
+import struct
+import serializer.tosa_serializer as ts
import json
import flatbuffers
import numpy as np
-import struct
from enum import IntEnum, unique
from tosa import (
TosaGraph,
@@ -197,7 +198,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.ints.append((a.AddWeightZp, weight_zp))
self.ints.append((a.AddAccumDtype, accum_dtype))
- def PadAttribute(self, padding, pad_const_int, pad_const_fp):
+ def PadAttribute(self, serializer_builder, padding, pad_const_int, pad_const_fp):
from tosa import PadAttribute as a, Attribute
self.utype = Attribute.Attribute().PadAttribute
@@ -205,7 +206,14 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.intvecs.append((a.AddPadding, padding))
self.ints.append((a.AddPadConstInt, pad_const_int))
- self.floats.append((a.AddPadConstFp, pad_const_fp))
+
+ # pad_const_fp attribute serialized as uint8 vector
+ pad_const_float_as_bytes = struct.pack("<f", pad_const_fp)
+ serialized_pad_const_fp = ts.TosaSerializer.serializeUint8Vec(
+ serializer_builder, pad_const_float_as_bytes
+ )
+
+ self.floats.append((a.AddPadConstFp, serialized_pad_const_fp))
def AxisAttribute(self, axis):
from tosa import AxisAttribute as a, Attribute
@@ -251,7 +259,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.int16vecs.append((a.AddBorder, border))
self.ints.append((a.AddMode, mode))
- def ClampAttribute(self, minint, maxint, minfp, maxfp):
+ def ClampAttribute(self, serializer_builder, minint, maxint, minfp, maxfp):
from tosa import ClampAttribute as a, Attribute
self.utype = Attribute.Attribute().ClampAttribute
@@ -260,8 +268,18 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.ints.append((a.AddMinInt, minint))
self.ints.append((a.AddMaxInt, maxint))
- self.ints.append((a.AddMinFp, minfp))
- self.ints.append((a.AddMaxFp, maxfp))
+ # min/max float attributes serialized as uint8 vectors
+ minfp_bytes = struct.pack("<f", minfp)
+ maxfp_bytes = struct.pack("<f", maxfp)
+ serialized_minfp_bytes = ts.TosaSerializer.serializeUint8Vec(
+ serializer_builder, minfp_bytes
+ )
+ serialized_maxfp_bytes = ts.TosaSerializer.serializeUint8Vec(
+ serializer_builder, maxfp_bytes
+ )
+
+ self.floats.append((a.AddMinFp, serialized_minfp_bytes))
+ self.floats.append((a.AddMaxFp, serialized_maxfp_bytes))
def RescaleAttribute(
self, input_zp, output_zp, multiplier, shift, scale32, double_round, per_channel
@@ -477,9 +495,11 @@ class TosaSerializerTensor:
np_arr = np.array(self.data, dtype=np.float16)
u8_data.extend(np_arr.view(np.uint8))
elif self.dtype == DType.FP32 or self.dtype == DType.BF16:
- for val in self.data:
- b = struct.pack("!f", val)
- u8_data.extend([b[3], b[2], b[1], b[0]])
+ # for val in self.data:
+ # b = struct.pack("!f", val)
+ # u8_data.extend([b[3], b[2], b[1], b[0]])
+ np_arr = np.array(self.data, dtype=np.float32)
+ u8_data.extend(np_arr.view(np.uint8))
elif self.dtype == TosaDType.DType:
# Serialize DType enum data as uint8 bytes
for val in self.data: