diff options
-rw-r--r-- | src/TosaDeserialize.cpp | 6 | ||||
-rw-r--r-- | src/TosaSerialize.cpp | 16 |
2 files changed, 12 insertions, 10 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index a4b7eda..153f16f 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -809,10 +809,12 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const { padding_value); } else { // create a const value for pad_const input + const auto input_element_type = + input_val.getType().cast<mlir::ShapedType>().getElementType(); mlir::Value pad_const_value; if (pad_const_int != 0) { auto pad_const_int_type = - mlir::RankedTensorType::get({}, op_builder->getI32Type()); + mlir::RankedTensorType::get({}, input_element_type); auto pad_const_int_attr = mlir::DenseElementsAttr::get(pad_const_int_type, {pad_const_int}); mlir::Operation *pad_const_int_op = @@ -822,7 +824,7 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const { pad_const_value = pad_const_int_op->getResult(0); } else if (pad_const_fp != 0) { auto pad_const_fp_type = - mlir::RankedTensorType::get({}, op_builder->getF32Type()); + mlir::RankedTensorType::get({}, input_element_type); auto pad_const_fp_attr = mlir::DenseElementsAttr::get(pad_const_fp_type, {pad_const_fp}); mlir::Operation *pad_const_fp_op = diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index a3e21f9..fec9f17 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -470,8 +470,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>( op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>(); if (dense_attr) { - for (auto val : dense_attr.getValues<float>()) { - data.push_back(val); + for (auto val : dense_attr.getValues<mlir::APFloat>()) { + data.push_back(val.convertToFloat()); } } else if (val_attr) { data.push_back((float)val_attr.getValueAsDouble()); @@ -931,9 +931,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>( paddings.push_back(val); } - auto quant_info = pad_op.getQuantizationInfoAttr(); - // pad_const includes the zero point if the tensor uses a zero point. - int32_t pad_const_int = quant_info ? quant_info.getInputZp() : 0; + int32_t pad_const_int = 0; float pad_const_fp = 0.f; if (auto tensor = pad_op.getPadConst()) { @@ -946,10 +944,12 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>( auto elementTy = attr.getElementType(); if (elementTy.isa<mlir::IntegerType>()) { - pad_const_int = quant_info ? *attr.value_begin<int8_t>() - : *attr.value_begin<int32_t>(); + pad_const_int = (attr.getValues<mlir::APInt>()[0]).getSExtValue(); } else if (elementTy.isa<mlir::FloatType>()) { - pad_const_fp = *attr.value_begin<float>(); + pad_const_fp = (attr.getValues<mlir::APFloat>()[0]).convertToFloat(); + } else { + op.emitOpError("Unknown const attribute"); + return nullptr; } } |