diff options
-rw-r--r-- | src/TosaDeserialize.cpp | 78 | ||||
-rw-r--r-- | src/TosaSerialize.cpp | 47 | ||||
m--------- | third_party/serialization_lib | 0 |
3 files changed, 82 insertions, 43 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index d69f005..b12d22a 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -946,29 +946,44 @@ TosaMlirOperatorBuilder::build<Op_CLAMP>(TosaSerializationOperator *op) const { mlir::Attribute min_val_attr, max_val_attr; if (input_element_type.isa<mlir::FloatType>()) { - min_val_attr = op_builder->getFloatAttr(input_element_type, attr->min_fp()); - max_val_attr = op_builder->getFloatAttr(input_element_type, attr->max_fp()); - } else if (input_element_type.isUnsignedInteger()) { - if (input_element_type.isUnsignedInteger(8)) { - uint8_t min_val = attr->min_int(); - uint8_t max_val = attr->max_int(); - min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); - max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); - } else if (input_element_type.isUnsignedInteger(16)) { - uint16_t min_val = attr->min_int(); - uint16_t max_val = attr->max_int(); - min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); - max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); - } else { - llvm::errs() << "ERROR: " << get_string(op) - << " contains unsupported unsigned int element data type.\n"; - return {}; - } - } else { + std::vector<float> min_float_data, max_float_data; + TosaSerializationHandler::ConvertU8toF32(attr->min_val(), /* size = */ 1, + min_float_data); + TosaSerializationHandler::ConvertU8toF32(attr->max_val(), /* size = */ 1, + max_float_data); min_val_attr = - op_builder->getIntegerAttr(input_element_type, attr->min_int()); + op_builder->getFloatAttr(input_element_type, min_float_data[0]); max_val_attr = - op_builder->getIntegerAttr(input_element_type, attr->max_int()); + op_builder->getFloatAttr(input_element_type, max_float_data[0]); + } else { + std::vector<int32_t> min_int_data, max_int_data; + TosaSerializationHandler::ConvertU8toI32(attr->min_val(), /* size = */ 1, + min_int_data); + TosaSerializationHandler::ConvertU8toI32(attr->max_val(), /* size = */ 1, + max_int_data); + if (input_element_type.isUnsignedInteger()) { + if (input_element_type.isUnsignedInteger(8)) { + uint8_t min_val = min_int_data[0]; + uint8_t max_val = max_int_data[0]; + min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); + max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); + } else if (input_element_type.isUnsignedInteger(16)) { + uint16_t min_val = min_int_data[0]; + uint16_t max_val = max_int_data[0]; + min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); + max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); + } else { + llvm::errs() + << "ERROR: " << get_string(op) + << " contains unsupported unsigned int element data type.\n"; + return {}; + } + } else { + min_val_attr = + op_builder->getIntegerAttr(input_element_type, min_int_data[0]); + max_val_attr = + op_builder->getIntegerAttr(input_element_type, max_int_data[0]); + } } mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ClampOp>( @@ -1075,19 +1090,32 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const { tensor_type_map->at(op->GetInputTensorNames()[0]); mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); + const auto element_type = + input_val.getType().cast<mlir::ShapedType>().getElementType(); assert(op->GetAttributeType() == Attribute_PadAttribute); // double check attribute type TosaPadAttribute *attr = static_cast<TosaPadAttribute *>(op->GetAttribute()); - auto pad_const_int = attr->pad_const_int(); - auto pad_const_fp = attr->pad_const_fp(); + float pad_const_fp = 0.0f; + int32_t pad_const_int = 0; + + if (element_type.isa<mlir::FloatType>()) { + std::vector<float> float_data; + TosaSerializationHandler::ConvertU8toF32(attr->pad_const(), + /* size = */ 1, float_data); + pad_const_fp = float_data[0]; + } else { + std::vector<int32_t> int32_data; + TosaSerializationHandler::ConvertU8toI32(attr->pad_const(), + /* size = */ 1, int32_data); + pad_const_int = int32_data[0]; + } + // todo: int input_zp = attr->pad_input_zp(); mlir::Operation *mlir_op; mlir::Value pad_const_value; - const auto element_type = - input_val.getType().cast<mlir::ShapedType>().getElementType(); bool isBoolType = element_type.isInteger(1); // First handle boolean type. diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 54a1d28..55a11fd 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -954,10 +954,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>( mlir::Operation &op) const { auto min_val_attr = op.getAttr("min_val"); auto max_val_attr = op.getAttr("max_val"); - float min_fp = 0; - float max_fp = 0; - int32_t min_int = 0; - int32_t max_int = 0; mlir::Type input_element_type = llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType(); @@ -966,24 +962,33 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>( input_element_type = quantType.getStorageType(); } + std::vector<uint8_t> min_val, max_val; if (input_element_type.isa<mlir::FloatType>()) { - min_fp = + auto min_fp = mlir::cast<mlir::FloatAttr>(min_val_attr).getValue().convertToFloat(); - max_fp = + auto max_fp = mlir::cast<mlir::FloatAttr>(max_val_attr).getValue().convertToFloat(); - } else if (input_element_type.isUnsignedInteger()) { - min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getUInt(); - max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getUInt(); + TosaSerializationHandler::ConvertF32toU8({min_fp}, min_val); + TosaSerializationHandler::ConvertF32toU8({max_fp}, max_val); } else { - assert(input_element_type.isa<mlir::IntegerType>()); - min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt(); - max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt(); + int32_t min_int = 0; + int32_t max_int = 0; + if (input_element_type.isUnsignedInteger()) { + min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getUInt(); + max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getUInt(); + } else { + assert(input_element_type.isa<mlir::IntegerType>()); + min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt(); + max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt(); + } + TosaSerializationHandler::ConvertI32toU8({min_int}, min_val); + TosaSerializationHandler::ConvertI32toU8({max_int}, max_val); } std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); - TosaClampAttribute attribute(min_int, max_int, min_fp, max_fp); + TosaClampAttribute attribute(min_val, max_val); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CLAMP, Attribute_ClampAttribute, &attribute, @@ -1128,7 +1133,15 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>( } } - TosaPadAttribute attribute({}, pad_const_int, pad_const_fp); + std::vector<uint8_t> pad_const; + mlir::Type input_element_type = + llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType(); + if (input_element_type.isa<mlir::FloatType>()) { + TosaSerializationHandler::ConvertF32toU8({pad_const_fp}, pad_const); + } else { + TosaSerializationHandler::ConvertI32toU8({pad_const_int}, pad_const); + } + TosaPadAttribute attribute(pad_const); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_PAD, Attribute_PadAttribute, &attribute, @@ -1386,10 +1399,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>( std::string shift_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(output); - TosaRescaleAttribute attribute(input_zp, output_zp, - /* multiplier = */ {}, /* shift = */ {}, - scale32, double_round, per_channel, - input_unsigned, output_unsigned); + TosaRescaleAttribute attribute(input_zp, output_zp, scale32, double_round, + per_channel, input_unsigned, output_unsigned); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_RESCALE, Attribute_RescaleAttribute, &attribute, diff --git a/third_party/serialization_lib b/third_party/serialization_lib -Subproject 758e73e117c5cef17f8f0b1c543efc1df953b2f +Subproject 0b6d7c271af1e6593e6a2cf14b32acea765f4b6 |