aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-07-07 23:36:04 +0000
committerEric Kunze <eric.kunze@arm.com>2023-07-12 15:20:51 +0000
commit787606544db8664ff42a31958f554fe7088427b5 (patch)
tree1abcacc8fb205bf313c66c197c3e99fa0764818b
parent9a57b9fe6f9832fa0406daac367fd3fc09afa018 (diff)
downloadtosa_mlir_translator-787606544db8664ff42a31958f554fe7088427b5.tar.gz
[tosa_mlir_translator] Fix Pad const attr type
This fixes serialization and deserialization of Pad Operator's constant value attributes that is integer data types other than I32 also fixed serialization of F16 constants to not crash, found in testing Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Ic3987aad0e11b612de591eaeecd308d599d174e1
-rw-r--r--src/TosaDeserialize.cpp6
-rw-r--r--src/TosaSerialize.cpp16
2 files changed, 12 insertions, 10 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index a4b7eda..153f16f 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -809,10 +809,12 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
padding_value);
} else {
// create a const value for pad_const input
+ const auto input_element_type =
+ input_val.getType().cast<mlir::ShapedType>().getElementType();
mlir::Value pad_const_value;
if (pad_const_int != 0) {
auto pad_const_int_type =
- mlir::RankedTensorType::get({}, op_builder->getI32Type());
+ mlir::RankedTensorType::get({}, input_element_type);
auto pad_const_int_attr =
mlir::DenseElementsAttr::get(pad_const_int_type, {pad_const_int});
mlir::Operation *pad_const_int_op =
@@ -822,7 +824,7 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
pad_const_value = pad_const_int_op->getResult(0);
} else if (pad_const_fp != 0) {
auto pad_const_fp_type =
- mlir::RankedTensorType::get({}, op_builder->getF32Type());
+ mlir::RankedTensorType::get({}, input_element_type);
auto pad_const_fp_attr =
mlir::DenseElementsAttr::get(pad_const_fp_type, {pad_const_fp});
mlir::Operation *pad_const_fp_op =
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index a3e21f9..fec9f17 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -470,8 +470,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>();
if (dense_attr) {
- for (auto val : dense_attr.getValues<float>()) {
- data.push_back(val);
+ for (auto val : dense_attr.getValues<mlir::APFloat>()) {
+ data.push_back(val.convertToFloat());
}
} else if (val_attr) {
data.push_back((float)val_attr.getValueAsDouble());
@@ -931,9 +931,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>(
paddings.push_back(val);
}
- auto quant_info = pad_op.getQuantizationInfoAttr();
- // pad_const includes the zero point if the tensor uses a zero point.
- int32_t pad_const_int = quant_info ? quant_info.getInputZp() : 0;
+ int32_t pad_const_int = 0;
float pad_const_fp = 0.f;
if (auto tensor = pad_op.getPadConst()) {
@@ -946,10 +944,12 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>(
auto elementTy = attr.getElementType();
if (elementTy.isa<mlir::IntegerType>()) {
- pad_const_int = quant_info ? *attr.value_begin<int8_t>()
- : *attr.value_begin<int32_t>();
+ pad_const_int = (attr.getValues<mlir::APInt>()[0]).getSExtValue();
} else if (elementTy.isa<mlir::FloatType>()) {
- pad_const_fp = *attr.value_begin<float>();
+ pad_const_fp = (attr.getValues<mlir::APFloat>()[0]).convertToFloat();
+ } else {
+ op.emitOpError("Unknown const attribute");
+ return nullptr;
}
}