aboutsummaryrefslogtreecommitdiff
path: root/src/TosaDeserialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r--src/TosaDeserialize.cpp6
1 files changed, 4 insertions, 2 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index a4b7eda..153f16f 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -809,10 +809,12 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
padding_value);
} else {
// create a const value for pad_const input
+ const auto input_element_type =
+ input_val.getType().cast<mlir::ShapedType>().getElementType();
mlir::Value pad_const_value;
if (pad_const_int != 0) {
auto pad_const_int_type =
- mlir::RankedTensorType::get({}, op_builder->getI32Type());
+ mlir::RankedTensorType::get({}, input_element_type);
auto pad_const_int_attr =
mlir::DenseElementsAttr::get(pad_const_int_type, {pad_const_int});
mlir::Operation *pad_const_int_op =
@@ -822,7 +824,7 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
pad_const_value = pad_const_int_op->getResult(0);
} else if (pad_const_fp != 0) {
auto pad_const_fp_type =
- mlir::RankedTensorType::get({}, op_builder->getF32Type());
+ mlir::RankedTensorType::get({}, input_element_type);
auto pad_const_fp_attr =
mlir::DenseElementsAttr::get(pad_const_fp_type, {pad_const_fp});
mlir::Operation *pad_const_fp_op =