From 497ab5d8cfce56eaa15db8853c903ef4dcf13e42 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Fri, 21 Oct 2022 16:39:01 -0700 Subject: Add attributes to serialization of custom op Change-Id: I4a1d7675b247efcf34aadd59eac17b966e3705af Signed-off-by: Eric Kunze --- include/attribute.def | 7 ++- include/tosa_generated.h | 110 ++++++++++++++++++++++++++++++++++++++++++++--- python/tosa/Attribute.py | 1 + schema/tosa.fbs | 11 ++++- 4 files changed, 121 insertions(+), 8 deletions(-) diff --git a/include/attribute.def b/include/attribute.def index ebbf024..121bd89 100644 --- a/include/attribute.def +++ b/include/attribute.def @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -122,3 +122,8 @@ DEF_ATTRIBUTE(FullyConnected, 3, DEF_ATTRIBUTE(Negate, 2, int32_t, S, input1_zp, int32_t, S, output_zp) + +DEF_ATTRIBUTE(Custom, 3, + string, S, identifier, + string, S, config, + uint8_t, V, implementation_attrs) diff --git a/include/tosa_generated.h b/include/tosa_generated.h index 4d231b0..bb501be 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -68,6 +68,9 @@ struct FullyConnectedAttributeBuilder; struct NegateAttribute; struct NegateAttributeBuilder; +struct CustomAttribute; +struct CustomAttributeBuilder; + struct Version; struct VersionBuilder; @@ -432,11 +435,12 @@ enum Attribute : uint8_t { Attribute_MatMulAttribute = 18, Attribute_FullyConnectedAttribute = 19, Attribute_NegateAttribute = 20, + Attribute_CustomAttribute = 21, Attribute_MIN = Attribute_NONE, - Attribute_MAX = Attribute_NegateAttribute + Attribute_MAX = Attribute_CustomAttribute }; -inline const Attribute (&EnumValuesAttribute())[21] { +inline const Attribute (&EnumValuesAttribute())[22] { static const Attribute values[] = { Attribute_NONE, Attribute_PoolAttribute, @@ -458,13 +462,14 @@ inline const Attribute (&EnumValuesAttribute())[21] { Attribute_TableAttribute, Attribute_MatMulAttribute, Attribute_FullyConnectedAttribute, - Attribute_NegateAttribute + Attribute_NegateAttribute, + Attribute_CustomAttribute }; return values; } inline const char * const *EnumNamesAttribute() { - static const char * const names[22] = { + static const char * const names[23] = { "NONE", "PoolAttribute", "ConvAttribute", @@ -486,13 +491,14 @@ inline const char * const *EnumNamesAttribute() { "MatMulAttribute", "FullyConnectedAttribute", "NegateAttribute", + "CustomAttribute", nullptr }; return names; } inline const char *EnumNameAttribute(Attribute e) { - if (flatbuffers::IsOutRange(e, Attribute_NONE, Attribute_NegateAttribute)) return ""; + if (flatbuffers::IsOutRange(e, Attribute_NONE, Attribute_CustomAttribute)) return ""; const size_t index = static_cast(e); return EnumNamesAttribute()[index]; } @@ -581,6 +587,10 @@ template<> struct AttributeTraits { static const Attribute enum_value = Attribute_NegateAttribute; }; +template<> struct AttributeTraits { + static const Attribute enum_value = Attribute_CustomAttribute; +}; + bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, Attribute type); bool VerifyAttributeVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); @@ -1986,6 +1996,85 @@ inline flatbuffers::Offset CreateNegateAttribute( return builder_.Finish(); } +struct CustomAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef CustomAttributeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_IDENTIFIER = 4, + VT_CONFIG = 6, + VT_IMPLEMENTATION_ATTRS = 8 + }; + const flatbuffers::String *identifier() const { + return GetPointer(VT_IDENTIFIER); + } + const flatbuffers::String *config() const { + return GetPointer(VT_CONFIG); + } + const flatbuffers::Vector *implementation_attrs() const { + return GetPointer *>(VT_IMPLEMENTATION_ATTRS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_IDENTIFIER) && + verifier.VerifyString(identifier()) && + VerifyOffset(verifier, VT_CONFIG) && + verifier.VerifyString(config()) && + VerifyOffset(verifier, VT_IMPLEMENTATION_ATTRS) && + verifier.VerifyVector(implementation_attrs()) && + verifier.EndTable(); + } +}; + +struct CustomAttributeBuilder { + typedef CustomAttribute Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_identifier(flatbuffers::Offset identifier) { + fbb_.AddOffset(CustomAttribute::VT_IDENTIFIER, identifier); + } + void add_config(flatbuffers::Offset config) { + fbb_.AddOffset(CustomAttribute::VT_CONFIG, config); + } + void add_implementation_attrs(flatbuffers::Offset> implementation_attrs) { + fbb_.AddOffset(CustomAttribute::VT_IMPLEMENTATION_ATTRS, implementation_attrs); + } + explicit CustomAttributeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateCustomAttribute( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset identifier = 0, + flatbuffers::Offset config = 0, + flatbuffers::Offset> implementation_attrs = 0) { + CustomAttributeBuilder builder_(_fbb); + builder_.add_implementation_attrs(implementation_attrs); + builder_.add_config(config); + builder_.add_identifier(identifier); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateCustomAttributeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *identifier = nullptr, + const char *config = nullptr, + const std::vector *implementation_attrs = nullptr) { + auto identifier__ = identifier ? _fbb.CreateString(identifier) : 0; + auto config__ = config ? _fbb.CreateString(config) : 0; + auto implementation_attrs__ = implementation_attrs ? _fbb.CreateVector(*implementation_attrs) : 0; + return tosa::CreateCustomAttribute( + _fbb, + identifier__, + config__, + implementation_attrs__); +} + struct Version FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef VersionBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { @@ -2228,6 +2317,9 @@ struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const tosa::NegateAttribute *attribute_as_NegateAttribute() const { return attribute_type() == tosa::Attribute_NegateAttribute ? static_cast(attribute()) : nullptr; } + const tosa::CustomAttribute *attribute_as_CustomAttribute() const { + return attribute_type() == tosa::Attribute_CustomAttribute ? static_cast(attribute()) : nullptr; + } const flatbuffers::Vector> *inputs() const { return GetPointer> *>(VT_INPUTS); } @@ -2330,6 +2422,10 @@ template<> inline const tosa::NegateAttribute *TosaOperator::attribute_as inline const tosa::CustomAttribute *TosaOperator::attribute_as() const { + return attribute_as_CustomAttribute(); +} + struct TosaOperatorBuilder { typedef TosaOperator Table; flatbuffers::FlatBufferBuilder &fbb_; @@ -2721,6 +2817,10 @@ inline bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, At auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case Attribute_CustomAttribute: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return true; } } diff --git a/python/tosa/Attribute.py b/python/tosa/Attribute.py index ec0c7b0..7ada553 100644 --- a/python/tosa/Attribute.py +++ b/python/tosa/Attribute.py @@ -24,3 +24,4 @@ class Attribute(object): MatMulAttribute = 18 FullyConnectedAttribute = 19 NegateAttribute = 20 + CustomAttribute = 21 diff --git a/schema/tosa.fbs b/schema/tosa.fbs index 4d5c611..093c235 100644 --- a/schema/tosa.fbs +++ b/schema/tosa.fbs @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -161,7 +161,8 @@ union Attribute { TableAttribute, MatMulAttribute, FullyConnectedAttribute, - NegateAttribute + NegateAttribute, + CustomAttribute } table PoolAttribute { @@ -281,6 +282,12 @@ table NegateAttribute { output_zp: int32; } +table CustomAttribute { + identifier:string; + config:string; + implementation_attrs:[ubyte]; +} + table Version { _major: int32 = 0; _minor: int32 = 51; -- cgit v1.2.1