aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/TosaDeserialize.cpp6
-rw-r--r--src/TosaSerialize.cpp16
2 files changed, 12 insertions, 10 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index a4b7eda..153f16f 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -809,10 +809,12 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
padding_value);
} else {
// create a const value for pad_const input
+ const auto input_element_type =
+ input_val.getType().cast<mlir::ShapedType>().getElementType();
mlir::Value pad_const_value;
if (pad_const_int != 0) {
auto pad_const_int_type =
- mlir::RankedTensorType::get({}, op_builder->getI32Type());
+ mlir::RankedTensorType::get({}, input_element_type);
auto pad_const_int_attr =
mlir::DenseElementsAttr::get(pad_const_int_type, {pad_const_int});
mlir::Operation *pad_const_int_op =
@@ -822,7 +824,7 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
pad_const_value = pad_const_int_op->getResult(0);
} else if (pad_const_fp != 0) {
auto pad_const_fp_type =
- mlir::RankedTensorType::get({}, op_builder->getF32Type());
+ mlir::RankedTensorType::get({}, input_element_type);
auto pad_const_fp_attr =
mlir::DenseElementsAttr::get(pad_const_fp_type, {pad_const_fp});
mlir::Operation *pad_const_fp_op =
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index a3e21f9..fec9f17 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -470,8 +470,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>();
if (dense_attr) {
- for (auto val : dense_attr.getValues<float>()) {
- data.push_back(val);
+ for (auto val : dense_attr.getValues<mlir::APFloat>()) {
+ data.push_back(val.convertToFloat());
}
} else if (val_attr) {
data.push_back((float)val_attr.getValueAsDouble());
@@ -931,9 +931,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>(
paddings.push_back(val);
}
- auto quant_info = pad_op.getQuantizationInfoAttr();
- // pad_const includes the zero point if the tensor uses a zero point.
- int32_t pad_const_int = quant_info ? quant_info.getInputZp() : 0;
+ int32_t pad_const_int = 0;
float pad_const_fp = 0.f;
if (auto tensor = pad_op.getPadConst()) {
@@ -946,10 +944,12 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>(
auto elementTy = attr.getElementType();
if (elementTy.isa<mlir::IntegerType>()) {
- pad_const_int = quant_info ? *attr.value_begin<int8_t>()
- : *attr.value_begin<int32_t>();
+ pad_const_int = (attr.getValues<mlir::APInt>()[0]).getSExtValue();
} else if (elementTy.isa<mlir::FloatType>()) {
- pad_const_fp = *attr.value_begin<float>();
+ pad_const_fp = (attr.getValues<mlir::APFloat>()[0]).convertToFloat();
+ } else {
+ op.emitOpError("Unknown const attribute");
+ return nullptr;
}
}