From ad78daaf0fa1e41742cbed314459c3dbbb483c20 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Wed, 13 Mar 2024 18:52:45 +0000 Subject: [serialization_lib] Add acc_type to Conv Attrs This adds acc_type to ConvAttribute and TransposeConvAttribute Signed-off-by: Tai Ly Change-Id: I73bab71b2eb90f6451fadee21d5bed1811ecbfd7 --- include/attribute.def | 10 +++++---- include/tosa_generated.h | 40 ++++++++++++++++++++++++++++------- python/serializer/tosa_serializer.py | 8 +++++-- python/tosa/ConvAttribute.py | 15 ++++++++++++- python/tosa/TransposeConvAttribute.py | 15 ++++++++++++- schema/tosa.fbs | 2 ++ 6 files changed, 74 insertions(+), 16 deletions(-) diff --git a/include/attribute.def b/include/attribute.def index 2176f47..723543e 100644 --- a/include/attribute.def +++ b/include/attribute.def @@ -34,21 +34,23 @@ DEF_ATTRIBUTE(Pool, 6, int32_t, S, output_zp, DType, S, acc_type) -DEF_ATTRIBUTE(Conv, 6, +DEF_ATTRIBUTE(Conv, 7, int32_t, V, pad, int32_t, V, stride, int32_t, V, dilation, int32_t, S, input_zp, int32_t, S, weight_zp, - bool, S, local_bound) + bool, S, local_bound, + DType, S, acc_type) -DEF_ATTRIBUTE(TransposeConv, 6, +DEF_ATTRIBUTE(TransposeConv, 7, int32_t, V, out_pad, int32_t, V, stride, int32_t, V, output_shape, int32_t, S, input_zp, int32_t, S, weight_zp, - bool, S, local_bound) + bool, S, local_bound, + DType, S, acc_type) DEF_ATTRIBUTE(Pad, 1, uint8_t, V, pad_const) diff --git a/include/tosa_generated.h b/include/tosa_generated.h index 64d54bc..20f6993 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -759,7 +759,8 @@ struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_DILATION = 8, VT_INPUT_ZP = 10, VT_WEIGHT_ZP = 12, - VT_LOCAL_BOUND = 14 + VT_LOCAL_BOUND = 14, + VT_ACC_TYPE = 16 }; const ::flatbuffers::Vector *pad() const { return GetPointer *>(VT_PAD); @@ -779,6 +780,9 @@ struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { bool local_bound() const { return GetField(VT_LOCAL_BOUND, 0) != 0; } + tosa::DType acc_type() const { + return static_cast(GetField(VT_ACC_TYPE, 0)); + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_PAD) && @@ -790,6 +794,7 @@ struct ConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VerifyField(verifier, VT_INPUT_ZP, 4) && VerifyField(verifier, VT_WEIGHT_ZP, 4) && VerifyField(verifier, VT_LOCAL_BOUND, 1) && + VerifyField(verifier, VT_ACC_TYPE, 4) && verifier.EndTable(); } }; @@ -816,6 +821,9 @@ struct ConvAttributeBuilder { void add_local_bound(bool local_bound) { fbb_.AddElement(ConvAttribute::VT_LOCAL_BOUND, static_cast(local_bound), 0); } + void add_acc_type(tosa::DType acc_type) { + fbb_.AddElement(ConvAttribute::VT_ACC_TYPE, static_cast(acc_type), 0); + } explicit ConvAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -834,8 +842,10 @@ inline ::flatbuffers::Offset CreateConvAttribute( ::flatbuffers::Offset<::flatbuffers::Vector> dilation = 0, int32_t input_zp = 0, int32_t weight_zp = 0, - bool local_bound = false) { + bool local_bound = false, + tosa::DType acc_type = tosa::DType_UNKNOWN) { ConvAttributeBuilder builder_(_fbb); + builder_.add_acc_type(acc_type); builder_.add_weight_zp(weight_zp); builder_.add_input_zp(input_zp); builder_.add_dilation(dilation); @@ -852,7 +862,8 @@ inline ::flatbuffers::Offset CreateConvAttributeDirect( const std::vector *dilation = nullptr, int32_t input_zp = 0, int32_t weight_zp = 0, - bool local_bound = false) { + bool local_bound = false, + tosa::DType acc_type = tosa::DType_UNKNOWN) { auto pad__ = pad ? _fbb.CreateVector(*pad) : 0; auto stride__ = stride ? _fbb.CreateVector(*stride) : 0; auto dilation__ = dilation ? _fbb.CreateVector(*dilation) : 0; @@ -863,7 +874,8 @@ inline ::flatbuffers::Offset CreateConvAttributeDirect( dilation__, input_zp, weight_zp, - local_bound); + local_bound, + acc_type); } struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -874,7 +886,8 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::T VT_OUTPUT_SHAPE = 8, VT_INPUT_ZP = 10, VT_WEIGHT_ZP = 12, - VT_LOCAL_BOUND = 14 + VT_LOCAL_BOUND = 14, + VT_ACC_TYPE = 16 }; const ::flatbuffers::Vector *out_pad() const { return GetPointer *>(VT_OUT_PAD); @@ -894,6 +907,9 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::T bool local_bound() const { return GetField(VT_LOCAL_BOUND, 0) != 0; } + tosa::DType acc_type() const { + return static_cast(GetField(VT_ACC_TYPE, 0)); + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_OUT_PAD) && @@ -905,6 +921,7 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::T VerifyField(verifier, VT_INPUT_ZP, 4) && VerifyField(verifier, VT_WEIGHT_ZP, 4) && VerifyField(verifier, VT_LOCAL_BOUND, 1) && + VerifyField(verifier, VT_ACC_TYPE, 4) && verifier.EndTable(); } }; @@ -931,6 +948,9 @@ struct TransposeConvAttributeBuilder { void add_local_bound(bool local_bound) { fbb_.AddElement(TransposeConvAttribute::VT_LOCAL_BOUND, static_cast(local_bound), 0); } + void add_acc_type(tosa::DType acc_type) { + fbb_.AddElement(TransposeConvAttribute::VT_ACC_TYPE, static_cast(acc_type), 0); + } explicit TransposeConvAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -949,8 +969,10 @@ inline ::flatbuffers::Offset CreateTransposeConvAttribut ::flatbuffers::Offset<::flatbuffers::Vector> output_shape = 0, int32_t input_zp = 0, int32_t weight_zp = 0, - bool local_bound = false) { + bool local_bound = false, + tosa::DType acc_type = tosa::DType_UNKNOWN) { TransposeConvAttributeBuilder builder_(_fbb); + builder_.add_acc_type(acc_type); builder_.add_weight_zp(weight_zp); builder_.add_input_zp(input_zp); builder_.add_output_shape(output_shape); @@ -967,7 +989,8 @@ inline ::flatbuffers::Offset CreateTransposeConvAttribut const std::vector *output_shape = nullptr, int32_t input_zp = 0, int32_t weight_zp = 0, - bool local_bound = false) { + bool local_bound = false, + tosa::DType acc_type = tosa::DType_UNKNOWN) { auto out_pad__ = out_pad ? _fbb.CreateVector(*out_pad) : 0; auto stride__ = stride ? _fbb.CreateVector(*stride) : 0; auto output_shape__ = output_shape ? _fbb.CreateVector(*output_shape) : 0; @@ -978,7 +1001,8 @@ inline ::flatbuffers::Offset CreateTransposeConvAttribut output_shape__, input_zp, weight_zp, - local_bound); + local_bound, + acc_type); } struct PadAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index 2c7996a..9658edf 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -172,7 +172,9 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.ints.append((a.AddOutputZp, output_zp)) self.ints.append((a.AddAccType, acc_type)) - def ConvAttribute(self, pad, stride, dilation, input_zp, weight_zp, local_bound): + def ConvAttribute( + self, pad, stride, dilation, input_zp, weight_zp, local_bound, acc_type + ): from tosa import ConvAttribute as a, Attribute self.utype = Attribute.Attribute().ConvAttribute @@ -184,9 +186,10 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.ints.append((a.AddInputZp, input_zp)) self.ints.append((a.AddWeightZp, weight_zp)) self.bools.append((a.AddLocalBound, local_bound)) + self.ints.append((a.AddAccType, acc_type)) def TransposeConvAttribute( - self, outpad, stride, output_shape, input_zp, weight_zp, local_bound + self, outpad, stride, output_shape, input_zp, weight_zp, local_bound, acc_type ): from tosa import TransposeConvAttribute as a, Attribute @@ -199,6 +202,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.ints.append((a.AddInputZp, input_zp)) self.ints.append((a.AddWeightZp, weight_zp)) self.bools.append((a.AddLocalBound, local_bound)) + self.ints.append((a.AddAccType, acc_type)) def PadAttribute(self, serializer_builder, pad_const_val_as_bytes): from tosa import PadAttribute as a, Attribute diff --git a/python/tosa/ConvAttribute.py b/python/tosa/ConvAttribute.py index b35b67c..dfa75dc 100644 --- a/python/tosa/ConvAttribute.py +++ b/python/tosa/ConvAttribute.py @@ -130,8 +130,15 @@ class ConvAttribute(object): return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) return False + # ConvAttribute + def AccType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + def ConvAttributeStart(builder): - builder.StartObject(6) + builder.StartObject(7) def Start(builder): ConvAttributeStart(builder) @@ -190,6 +197,12 @@ def ConvAttributeAddLocalBound(builder, localBound): def AddLocalBound(builder, localBound): ConvAttributeAddLocalBound(builder, localBound) +def ConvAttributeAddAccType(builder, accType): + builder.PrependUint32Slot(6, accType, 0) + +def AddAccType(builder, accType): + ConvAttributeAddAccType(builder, accType) + def ConvAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/TransposeConvAttribute.py b/python/tosa/TransposeConvAttribute.py index a74a433..e5397a8 100644 --- a/python/tosa/TransposeConvAttribute.py +++ b/python/tosa/TransposeConvAttribute.py @@ -130,8 +130,15 @@ class TransposeConvAttribute(object): return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) return False + # TransposeConvAttribute + def AccType(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + def TransposeConvAttributeStart(builder): - builder.StartObject(6) + builder.StartObject(7) def Start(builder): TransposeConvAttributeStart(builder) @@ -190,6 +197,12 @@ def TransposeConvAttributeAddLocalBound(builder, localBound): def AddLocalBound(builder, localBound): TransposeConvAttributeAddLocalBound(builder, localBound) +def TransposeConvAttributeAddAccType(builder, accType): + builder.PrependUint32Slot(6, accType, 0) + +def AddAccType(builder, accType): + TransposeConvAttributeAddAccType(builder, accType) + def TransposeConvAttributeEnd(builder): return builder.EndObject() diff --git a/schema/tosa.fbs b/schema/tosa.fbs index 028765d..79b83b1 100644 --- a/schema/tosa.fbs +++ b/schema/tosa.fbs @@ -170,6 +170,7 @@ table ConvAttribute { input_zp: int32; weight_zp: int32; local_bound: bool; + acc_type: DType; } table TransposeConvAttribute { @@ -179,6 +180,7 @@ table TransposeConvAttribute { input_zp: int32; weight_zp: int32; local_bound: bool; + acc_type: DType; } table PadAttribute { -- cgit v1.2.1