diff options
-rw-r--r-- | src/TosaSerialize.cpp | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 42ca41e..4bc67ea 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -878,10 +878,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>( mlir::Operation &op) const { std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); + auto pad_op = llvm::cast<mlir::tosa::PadOp>(op); // Match padding tensor as compile-time constant attribute mlir::ElementsAttr paddings_elems; - if (!matchPattern(op.getOperand(1), m_Constant(&paddings_elems))) + if (!matchPattern(pad_op.getPadding(), m_Constant(&paddings_elems))) return nullptr; std::vector<int> paddings; @@ -889,8 +890,29 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>( paddings.push_back(val); } - TosaPadAttribute attribute(paddings, 0 /* pad_const_int */, - 0.0f /* pad_const_fp */); + auto quant_info = pad_op.getQuantizationInfoAttr(); + // pad_const includes the zero point if the tensor uses a zero point. + int32_t pad_const_int = quant_info ? quant_info.getInputZp() : 0; + float pad_const_fp = 0.f; + + if (auto tensor = pad_op.getPadConst()) { + // Match pad_const tensor as compile-time constant attribute if present. + mlir::DenseElementsAttr attr; + if (!matchPattern(tensor, m_Constant(&attr))) + return nullptr; + + assert(attr.getNumElements() == 1); + auto elementTy = attr.getElementType(); + + if (elementTy.isa<mlir::IntegerType>()) { + pad_const_int = quant_info ? *attr.value_begin<int8_t>() + : *attr.value_begin<int32_t>(); + } else if (elementTy.isa<mlir::FloatType>()) { + pad_const_fp = *attr.value_begin<float>(); + } + } + + TosaPadAttribute attribute(paddings, pad_const_int, pad_const_fp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_PAD, Attribute_PadAttribute, &attribute, |