aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/TosaSerialize.cpp28
1 files changed, 25 insertions, 3 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 42ca41e..4bc67ea 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -878,10 +878,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>(
mlir::Operation &op) const {
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));
+ auto pad_op = llvm::cast<mlir::tosa::PadOp>(op);
// Match padding tensor as compile-time constant attribute
mlir::ElementsAttr paddings_elems;
- if (!matchPattern(op.getOperand(1), m_Constant(&paddings_elems)))
+ if (!matchPattern(pad_op.getPadding(), m_Constant(&paddings_elems)))
return nullptr;
std::vector<int> paddings;
@@ -889,8 +890,29 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>(
paddings.push_back(val);
}
- TosaPadAttribute attribute(paddings, 0 /* pad_const_int */,
- 0.0f /* pad_const_fp */);
+ auto quant_info = pad_op.getQuantizationInfoAttr();
+ // pad_const includes the zero point if the tensor uses a zero point.
+ int32_t pad_const_int = quant_info ? quant_info.getInputZp() : 0;
+ float pad_const_fp = 0.f;
+
+ if (auto tensor = pad_op.getPadConst()) {
+ // Match pad_const tensor as compile-time constant attribute if present.
+ mlir::DenseElementsAttr attr;
+ if (!matchPattern(tensor, m_Constant(&attr)))
+ return nullptr;
+
+ assert(attr.getNumElements() == 1);
+ auto elementTy = attr.getElementType();
+
+ if (elementTy.isa<mlir::IntegerType>()) {
+ pad_const_int = quant_info ? *attr.value_begin<int8_t>()
+ : *attr.value_begin<int32_t>();
+ } else if (elementTy.isa<mlir::FloatType>()) {
+ pad_const_fp = *attr.value_begin<float>();
+ }
+ }
+
+ TosaPadAttribute attribute(paddings, pad_const_int, pad_const_fp);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_PAD, Attribute_PadAttribute, &attribute,