From 909d4d159ee12c6bc8113974d76f46249b6fd7fb Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Fri, 8 Mar 2024 18:32:46 +0000 Subject: [tosa_mlir_translator] Use new Clamp and Pad fbs attributes This implements changes required for new Tosa Flatbuffer schema where Clamp and Pad attributes have changed to use ubyte arrays to store int or float values. Signed-off-by: Tai Ly Change-Id: I2aa2025422fda4aacaf6d80727060a01a30cee89 --- src/TosaDeserialize.cpp | 78 +++++++++++++++++++++++++++++-------------- src/TosaSerialize.cpp | 47 ++++++++++++++++---------- third_party/serialization_lib | 2 +- 3 files changed, 83 insertions(+), 44 deletions(-) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index d69f005..b12d22a 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -946,29 +946,44 @@ TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { mlir::Attribute min_val_attr, max_val_attr; if (input_element_type.isa()) { - min_val_attr = op_builder->getFloatAttr(input_element_type, attr->min_fp()); - max_val_attr = op_builder->getFloatAttr(input_element_type, attr->max_fp()); - } else if (input_element_type.isUnsignedInteger()) { - if (input_element_type.isUnsignedInteger(8)) { - uint8_t min_val = attr->min_int(); - uint8_t max_val = attr->max_int(); - min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); - max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); - } else if (input_element_type.isUnsignedInteger(16)) { - uint16_t min_val = attr->min_int(); - uint16_t max_val = attr->max_int(); - min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); - max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); - } else { - llvm::errs() << "ERROR: " << get_string(op) - << " contains unsupported unsigned int element data type.\n"; - return {}; - } - } else { + std::vector min_float_data, max_float_data; + TosaSerializationHandler::ConvertU8toF32(attr->min_val(), /* size = */ 1, + min_float_data); + TosaSerializationHandler::ConvertU8toF32(attr->max_val(), /* size = */ 1, + max_float_data); min_val_attr = - op_builder->getIntegerAttr(input_element_type, attr->min_int()); + op_builder->getFloatAttr(input_element_type, min_float_data[0]); max_val_attr = - op_builder->getIntegerAttr(input_element_type, attr->max_int()); + op_builder->getFloatAttr(input_element_type, max_float_data[0]); + } else { + std::vector min_int_data, max_int_data; + TosaSerializationHandler::ConvertU8toI32(attr->min_val(), /* size = */ 1, + min_int_data); + TosaSerializationHandler::ConvertU8toI32(attr->max_val(), /* size = */ 1, + max_int_data); + if (input_element_type.isUnsignedInteger()) { + if (input_element_type.isUnsignedInteger(8)) { + uint8_t min_val = min_int_data[0]; + uint8_t max_val = max_int_data[0]; + min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); + max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); + } else if (input_element_type.isUnsignedInteger(16)) { + uint16_t min_val = min_int_data[0]; + uint16_t max_val = max_int_data[0]; + min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); + max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); + } else { + llvm::errs() + << "ERROR: " << get_string(op) + << " contains unsupported unsigned int element data type.\n"; + return {}; + } + } else { + min_val_attr = + op_builder->getIntegerAttr(input_element_type, min_int_data[0]); + max_val_attr = + op_builder->getIntegerAttr(input_element_type, max_int_data[0]); + } } mlir::Operation *mlir_op = op_builder->create( @@ -1075,19 +1090,32 @@ TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { tensor_type_map->at(op->GetInputTensorNames()[0]); mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); + const auto element_type = + input_val.getType().cast().getElementType(); assert(op->GetAttributeType() == Attribute_PadAttribute); // double check attribute type TosaPadAttribute *attr = static_cast(op->GetAttribute()); - auto pad_const_int = attr->pad_const_int(); - auto pad_const_fp = attr->pad_const_fp(); + float pad_const_fp = 0.0f; + int32_t pad_const_int = 0; + + if (element_type.isa()) { + std::vector float_data; + TosaSerializationHandler::ConvertU8toF32(attr->pad_const(), + /* size = */ 1, float_data); + pad_const_fp = float_data[0]; + } else { + std::vector int32_data; + TosaSerializationHandler::ConvertU8toI32(attr->pad_const(), + /* size = */ 1, int32_data); + pad_const_int = int32_data[0]; + } + // todo: int input_zp = attr->pad_input_zp(); mlir::Operation *mlir_op; mlir::Value pad_const_value; - const auto element_type = - input_val.getType().cast().getElementType(); bool isBoolType = element_type.isInteger(1); // First handle boolean type. diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 54a1d28..55a11fd 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -954,10 +954,6 @@ TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { auto min_val_attr = op.getAttr("min_val"); auto max_val_attr = op.getAttr("max_val"); - float min_fp = 0; - float max_fp = 0; - int32_t min_int = 0; - int32_t max_int = 0; mlir::Type input_element_type = llvm::cast(op.getOperand(0).getType()).getElementType(); @@ -966,24 +962,33 @@ TosaSerializationOperatorBuilder::build( input_element_type = quantType.getStorageType(); } + std::vector min_val, max_val; if (input_element_type.isa()) { - min_fp = + auto min_fp = mlir::cast(min_val_attr).getValue().convertToFloat(); - max_fp = + auto max_fp = mlir::cast(max_val_attr).getValue().convertToFloat(); - } else if (input_element_type.isUnsignedInteger()) { - min_int = mlir::cast(min_val_attr).getUInt(); - max_int = mlir::cast(max_val_attr).getUInt(); + TosaSerializationHandler::ConvertF32toU8({min_fp}, min_val); + TosaSerializationHandler::ConvertF32toU8({max_fp}, max_val); } else { - assert(input_element_type.isa()); - min_int = mlir::cast(min_val_attr).getInt(); - max_int = mlir::cast(max_val_attr).getInt(); + int32_t min_int = 0; + int32_t max_int = 0; + if (input_element_type.isUnsignedInteger()) { + min_int = mlir::cast(min_val_attr).getUInt(); + max_int = mlir::cast(max_val_attr).getUInt(); + } else { + assert(input_element_type.isa()); + min_int = mlir::cast(min_val_attr).getInt(); + max_int = mlir::cast(max_val_attr).getInt(); + } + TosaSerializationHandler::ConvertI32toU8({min_int}, min_val); + TosaSerializationHandler::ConvertI32toU8({max_int}, max_val); } std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); - TosaClampAttribute attribute(min_int, max_int, min_fp, max_fp); + TosaClampAttribute attribute(min_val, max_val); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CLAMP, Attribute_ClampAttribute, &attribute, @@ -1128,7 +1133,15 @@ TosaSerializationOperatorBuilder::build( } } - TosaPadAttribute attribute({}, pad_const_int, pad_const_fp); + std::vector pad_const; + mlir::Type input_element_type = + llvm::cast(op.getOperand(0).getType()).getElementType(); + if (input_element_type.isa()) { + TosaSerializationHandler::ConvertF32toU8({pad_const_fp}, pad_const); + } else { + TosaSerializationHandler::ConvertI32toU8({pad_const_int}, pad_const); + } + TosaPadAttribute attribute(pad_const); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_PAD, Attribute_PadAttribute, &attribute, @@ -1386,10 +1399,8 @@ TosaSerializationOperatorBuilder::build( std::string shift_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(output); - TosaRescaleAttribute attribute(input_zp, output_zp, - /* multiplier = */ {}, /* shift = */ {}, - scale32, double_round, per_channel, - input_unsigned, output_unsigned); + TosaRescaleAttribute attribute(input_zp, output_zp, scale32, double_round, + per_channel, input_unsigned, output_unsigned); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_RESCALE, Attribute_RescaleAttribute, &attribute, diff --git a/third_party/serialization_lib b/third_party/serialization_lib index 758e73e..0b6d7c2 160000 --- a/third_party/serialization_lib +++ b/third_party/serialization_lib @@ -1 +1 @@ -Subproject commit 758e73e117c5cef17f8f0b1c543efc1df953b2fa +Subproject commit 0b6d7c271af1e6593e6a2cf14b32acea765f4b64 -- cgit v1.2.1