diff options
author | Tai Ly <tai.ly@arm.com> | 2024-03-08 18:32:46 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2024-03-17 19:55:42 -0700 |
commit | 909d4d159ee12c6bc8113974d76f46249b6fd7fb (patch) | |
tree | 7e6320d8f74ba6478a654404ccc74cca2ff3219f | |
parent | f983e51df5030facfd1c5bf59dcc67a32a1913a8 (diff) | |
download | tosa_mlir_translator-909d4d159ee12c6bc8113974d76f46249b6fd7fb.tar.gz |
[tosa_mlir_translator] Use new Clamp and Pad fbs attributes
This implements changes required for new Tosa Flatbuffer
schema where Clamp and Pad attributes have changed to use
ubyte arrays to store int or float values.
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I2aa2025422fda4aacaf6d80727060a01a30cee89
-rw-r--r-- | src/TosaDeserialize.cpp | 78 | ||||
-rw-r--r-- | src/TosaSerialize.cpp | 47 | ||||
m--------- | third_party/serialization_lib | 0 |
3 files changed, 82 insertions, 43 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index d69f005..b12d22a 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -946,29 +946,44 @@ TosaMlirOperatorBuilder::build<Op_CLAMP>(TosaSerializationOperator *op) const { mlir::Attribute min_val_attr, max_val_attr; if (input_element_type.isa<mlir::FloatType>()) { - min_val_attr = op_builder->getFloatAttr(input_element_type, attr->min_fp()); - max_val_attr = op_builder->getFloatAttr(input_element_type, attr->max_fp()); - } else if (input_element_type.isUnsignedInteger()) { - if (input_element_type.isUnsignedInteger(8)) { - uint8_t min_val = attr->min_int(); - uint8_t max_val = attr->max_int(); - min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); - max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); - } else if (input_element_type.isUnsignedInteger(16)) { - uint16_t min_val = attr->min_int(); - uint16_t max_val = attr->max_int(); - min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); - max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); - } else { - llvm::errs() << "ERROR: " << get_string(op) - << " contains unsupported unsigned int element data type.\n"; - return {}; - } - } else { + std::vector<float> min_float_data, max_float_data; + TosaSerializationHandler::ConvertU8toF32(attr->min_val(), /* size = */ 1, + min_float_data); + TosaSerializationHandler::ConvertU8toF32(attr->max_val(), /* size = */ 1, + max_float_data); min_val_attr = - op_builder->getIntegerAttr(input_element_type, attr->min_int()); + op_builder->getFloatAttr(input_element_type, min_float_data[0]); max_val_attr = - op_builder->getIntegerAttr(input_element_type, attr->max_int()); + op_builder->getFloatAttr(input_element_type, max_float_data[0]); + } else { + std::vector<int32_t> min_int_data, max_int_data; + TosaSerializationHandler::ConvertU8toI32(attr->min_val(), /* size = */ 1, + min_int_data); + TosaSerializationHandler::ConvertU8toI32(attr->max_val(), /* size = */ 1, + max_int_data); + if (input_element_type.isUnsignedInteger()) { + if (input_element_type.isUnsignedInteger(8)) { + uint8_t min_val = min_int_data[0]; + uint8_t max_val = max_int_data[0]; + min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); + max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); + } else if (input_element_type.isUnsignedInteger(16)) { + uint16_t min_val = min_int_data[0]; + uint16_t max_val = max_int_data[0]; + min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); + max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); + } else { + llvm::errs() + << "ERROR: " << get_string(op) + << " contains unsupported unsigned int element data type.\n"; + return {}; + } + } else { + min_val_attr = + op_builder->getIntegerAttr(input_element_type, min_int_data[0]); + max_val_attr = + op_builder->getIntegerAttr(input_element_type, max_int_data[0]); + } } mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ClampOp>( @@ -1075,19 +1090,32 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const { tensor_type_map->at(op->GetInputTensorNames()[0]); mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); + const auto element_type = + input_val.getType().cast<mlir::ShapedType>().getElementType(); assert(op->GetAttributeType() == Attribute_PadAttribute); // double check attribute type TosaPadAttribute *attr = static_cast<TosaPadAttribute *>(op->GetAttribute()); - auto pad_const_int = attr->pad_const_int(); - auto pad_const_fp = attr->pad_const_fp(); + float pad_const_fp = 0.0f; + int32_t pad_const_int = 0; + + if (element_type.isa<mlir::FloatType>()) { + std::vector<float> float_data; + TosaSerializationHandler::ConvertU8toF32(attr->pad_const(), + /* size = */ 1, float_data); + pad_const_fp = float_data[0]; + } else { + std::vector<int32_t> int32_data; + TosaSerializationHandler::ConvertU8toI32(attr->pad_const(), + /* size = */ 1, int32_data); + pad_const_int = int32_data[0]; + } + // todo: int input_zp = attr->pad_input_zp(); mlir::Operation *mlir_op; mlir::Value pad_const_value; - const auto element_type = - input_val.getType().cast<mlir::ShapedType>().getElementType(); bool isBoolType = element_type.isInteger(1); // First handle boolean type. diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 54a1d28..55a11fd 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -954,10 +954,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>( mlir::Operation &op) const { auto min_val_attr = op.getAttr("min_val"); auto max_val_attr = op.getAttr("max_val"); - float min_fp = 0; - float max_fp = 0; - int32_t min_int = 0; - int32_t max_int = 0; mlir::Type input_element_type = llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType(); @@ -966,24 +962,33 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>( input_element_type = quantType.getStorageType(); } + std::vector<uint8_t> min_val, max_val; if (input_element_type.isa<mlir::FloatType>()) { - min_fp = + auto min_fp = mlir::cast<mlir::FloatAttr>(min_val_attr).getValue().convertToFloat(); - max_fp = + auto max_fp = mlir::cast<mlir::FloatAttr>(max_val_attr).getValue().convertToFloat(); - } else if (input_element_type.isUnsignedInteger()) { - min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getUInt(); - max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getUInt(); + TosaSerializationHandler::ConvertF32toU8({min_fp}, min_val); + TosaSerializationHandler::ConvertF32toU8({max_fp}, max_val); } else { - assert(input_element_type.isa<mlir::IntegerType>()); - min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt(); - max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt(); + int32_t min_int = 0; + int32_t max_int = 0; + if (input_element_type.isUnsignedInteger()) { + min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getUInt(); + max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getUInt(); + } else { + assert(input_element_type.isa<mlir::IntegerType>()); + min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt(); + max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt(); + } + TosaSerializationHandler::ConvertI32toU8({min_int}, min_val); + TosaSerializationHandler::ConvertI32toU8({max_int}, max_val); } std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); - TosaClampAttribute attribute(min_int, max_int, min_fp, max_fp); + TosaClampAttribute attribute(min_val, max_val); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CLAMP, Attribute_ClampAttribute, &attribute, @@ -1128,7 +1133,15 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>( } } - TosaPadAttribute attribute({}, pad_const_int, pad_const_fp); + std::vector<uint8_t> pad_const; + mlir::Type input_element_type = + llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType(); + if (input_element_type.isa<mlir::FloatType>()) { + TosaSerializationHandler::ConvertF32toU8({pad_const_fp}, pad_const); + } else { + TosaSerializationHandler::ConvertI32toU8({pad_const_int}, pad_const); + } + TosaPadAttribute attribute(pad_const); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_PAD, Attribute_PadAttribute, &attribute, @@ -1386,10 +1399,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>( std::string shift_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(output); - TosaRescaleAttribute attribute(input_zp, output_zp, - /* multiplier = */ {}, /* shift = */ {}, - scale32, double_round, per_channel, - input_unsigned, output_unsigned); + TosaRescaleAttribute attribute(input_zp, output_zp, scale32, double_round, + per_channel, input_unsigned, output_unsigned); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_RESCALE, Attribute_RescaleAttribute, &attribute, diff --git a/third_party/serialization_lib b/third_party/serialization_lib -Subproject 758e73e117c5cef17f8f0b1c543efc1df953b2f +Subproject 0b6d7c271af1e6593e6a2cf14b32acea765f4b6 |