aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/TosaDeserialize.cpp78
-rw-r--r--src/TosaSerialize.cpp47
m---------third_party/serialization_lib0
3 files changed, 82 insertions, 43 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index d69f005..b12d22a 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -946,29 +946,44 @@ TosaMlirOperatorBuilder::build<Op_CLAMP>(TosaSerializationOperator *op) const {
mlir::Attribute min_val_attr, max_val_attr;
if (input_element_type.isa<mlir::FloatType>()) {
- min_val_attr = op_builder->getFloatAttr(input_element_type, attr->min_fp());
- max_val_attr = op_builder->getFloatAttr(input_element_type, attr->max_fp());
- } else if (input_element_type.isUnsignedInteger()) {
- if (input_element_type.isUnsignedInteger(8)) {
- uint8_t min_val = attr->min_int();
- uint8_t max_val = attr->max_int();
- min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
- max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
- } else if (input_element_type.isUnsignedInteger(16)) {
- uint16_t min_val = attr->min_int();
- uint16_t max_val = attr->max_int();
- min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
- max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
- } else {
- llvm::errs() << "ERROR: " << get_string(op)
- << " contains unsupported unsigned int element data type.\n";
- return {};
- }
- } else {
+ std::vector<float> min_float_data, max_float_data;
+ TosaSerializationHandler::ConvertU8toF32(attr->min_val(), /* size = */ 1,
+ min_float_data);
+ TosaSerializationHandler::ConvertU8toF32(attr->max_val(), /* size = */ 1,
+ max_float_data);
min_val_attr =
- op_builder->getIntegerAttr(input_element_type, attr->min_int());
+ op_builder->getFloatAttr(input_element_type, min_float_data[0]);
max_val_attr =
- op_builder->getIntegerAttr(input_element_type, attr->max_int());
+ op_builder->getFloatAttr(input_element_type, max_float_data[0]);
+ } else {
+ std::vector<int32_t> min_int_data, max_int_data;
+ TosaSerializationHandler::ConvertU8toI32(attr->min_val(), /* size = */ 1,
+ min_int_data);
+ TosaSerializationHandler::ConvertU8toI32(attr->max_val(), /* size = */ 1,
+ max_int_data);
+ if (input_element_type.isUnsignedInteger()) {
+ if (input_element_type.isUnsignedInteger(8)) {
+ uint8_t min_val = min_int_data[0];
+ uint8_t max_val = max_int_data[0];
+ min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
+ max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
+ } else if (input_element_type.isUnsignedInteger(16)) {
+ uint16_t min_val = min_int_data[0];
+ uint16_t max_val = max_int_data[0];
+ min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
+ max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
+ } else {
+ llvm::errs()
+ << "ERROR: " << get_string(op)
+ << " contains unsupported unsigned int element data type.\n";
+ return {};
+ }
+ } else {
+ min_val_attr =
+ op_builder->getIntegerAttr(input_element_type, min_int_data[0]);
+ max_val_attr =
+ op_builder->getIntegerAttr(input_element_type, max_int_data[0]);
+ }
}
mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ClampOp>(
@@ -1075,19 +1090,32 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
tensor_type_map->at(op->GetInputTensorNames()[0]);
mlir::RankedTensorType output_type =
tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ const auto element_type =
+ input_val.getType().cast<mlir::ShapedType>().getElementType();
assert(op->GetAttributeType() ==
Attribute_PadAttribute); // double check attribute type
TosaPadAttribute *attr = static_cast<TosaPadAttribute *>(op->GetAttribute());
- auto pad_const_int = attr->pad_const_int();
- auto pad_const_fp = attr->pad_const_fp();
+ float pad_const_fp = 0.0f;
+ int32_t pad_const_int = 0;
+
+ if (element_type.isa<mlir::FloatType>()) {
+ std::vector<float> float_data;
+ TosaSerializationHandler::ConvertU8toF32(attr->pad_const(),
+ /* size = */ 1, float_data);
+ pad_const_fp = float_data[0];
+ } else {
+ std::vector<int32_t> int32_data;
+ TosaSerializationHandler::ConvertU8toI32(attr->pad_const(),
+ /* size = */ 1, int32_data);
+ pad_const_int = int32_data[0];
+ }
+
// todo: int input_zp = attr->pad_input_zp();
mlir::Operation *mlir_op;
mlir::Value pad_const_value;
- const auto element_type =
- input_val.getType().cast<mlir::ShapedType>().getElementType();
bool isBoolType = element_type.isInteger(1);
// First handle boolean type.
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 54a1d28..55a11fd 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -954,10 +954,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>(
mlir::Operation &op) const {
auto min_val_attr = op.getAttr("min_val");
auto max_val_attr = op.getAttr("max_val");
- float min_fp = 0;
- float max_fp = 0;
- int32_t min_int = 0;
- int32_t max_int = 0;
mlir::Type input_element_type =
llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType();
@@ -966,24 +962,33 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>(
input_element_type = quantType.getStorageType();
}
+ std::vector<uint8_t> min_val, max_val;
if (input_element_type.isa<mlir::FloatType>()) {
- min_fp =
+ auto min_fp =
mlir::cast<mlir::FloatAttr>(min_val_attr).getValue().convertToFloat();
- max_fp =
+ auto max_fp =
mlir::cast<mlir::FloatAttr>(max_val_attr).getValue().convertToFloat();
- } else if (input_element_type.isUnsignedInteger()) {
- min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getUInt();
- max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getUInt();
+ TosaSerializationHandler::ConvertF32toU8({min_fp}, min_val);
+ TosaSerializationHandler::ConvertF32toU8({max_fp}, max_val);
} else {
- assert(input_element_type.isa<mlir::IntegerType>());
- min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt();
- max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt();
+ int32_t min_int = 0;
+ int32_t max_int = 0;
+ if (input_element_type.isUnsignedInteger()) {
+ min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getUInt();
+ max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getUInt();
+ } else {
+ assert(input_element_type.isa<mlir::IntegerType>());
+ min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt();
+ max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt();
+ }
+ TosaSerializationHandler::ConvertI32toU8({min_int}, min_val);
+ TosaSerializationHandler::ConvertI32toU8({max_int}, max_val);
}
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));
- TosaClampAttribute attribute(min_int, max_int, min_fp, max_fp);
+ TosaClampAttribute attribute(min_val, max_val);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_CLAMP, Attribute_ClampAttribute, &attribute,
@@ -1128,7 +1133,15 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>(
}
}
- TosaPadAttribute attribute({}, pad_const_int, pad_const_fp);
+ std::vector<uint8_t> pad_const;
+ mlir::Type input_element_type =
+ llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType();
+ if (input_element_type.isa<mlir::FloatType>()) {
+ TosaSerializationHandler::ConvertF32toU8({pad_const_fp}, pad_const);
+ } else {
+ TosaSerializationHandler::ConvertI32toU8({pad_const_int}, pad_const);
+ }
+ TosaPadAttribute attribute(pad_const);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_PAD, Attribute_PadAttribute, &attribute,
@@ -1386,10 +1399,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>(
std::string shift_name = GetTensorName(op.getOperand(2));
std::string output_name = GetTensorName(output);
- TosaRescaleAttribute attribute(input_zp, output_zp,
- /* multiplier = */ {}, /* shift = */ {},
- scale32, double_round, per_channel,
- input_unsigned, output_unsigned);
+ TosaRescaleAttribute attribute(input_zp, output_zp, scale32, double_round,
+ per_channel, input_unsigned, output_unsigned);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_RESCALE, Attribute_RescaleAttribute, &attribute,
diff --git a/third_party/serialization_lib b/third_party/serialization_lib
-Subproject 758e73e117c5cef17f8f0b1c543efc1df953b2f
+Subproject 0b6d7c271af1e6593e6a2cf14b32acea765f4b6