diff options
author | TatWai Chong <tatwai.chong@arm.com> | 2023-02-22 15:35:28 -0800 |
---|---|---|
committer | TatWai Chong <tatwai.chong@arm.com> | 2023-02-24 14:49:52 -0800 |
commit | 10fd495c6d3d4034e1a30d2ef65e17a011da58ec (patch) | |
tree | 9ebcf6e437b623fe3a454d7705c8510bd533bcf7 /src | |
parent | 6bbd8403fd524acec1bf3e63314bf38040802968 (diff) | |
download | tosa_mlir_translator-10fd495c6d3d4034e1a30d2ef65e17a011da58ec.tar.gz |
[Fix] Explicit pad const hasn't been read during pad op construction
The op builder didn't read the value from operand but only set the
pad_const attribute to zero.
Change-Id: Ia6f16490c3fad42200884a9d3a5118fa5c152b53
Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
Diffstat (limited to 'src')
-rw-r--r-- | src/TosaSerialize.cpp | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 42ca41e..4bc67ea 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -878,10 +878,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>( mlir::Operation &op) const { std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); + auto pad_op = llvm::cast<mlir::tosa::PadOp>(op); // Match padding tensor as compile-time constant attribute mlir::ElementsAttr paddings_elems; - if (!matchPattern(op.getOperand(1), m_Constant(&paddings_elems))) + if (!matchPattern(pad_op.getPadding(), m_Constant(&paddings_elems))) return nullptr; std::vector<int> paddings; @@ -889,8 +890,29 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>( paddings.push_back(val); } - TosaPadAttribute attribute(paddings, 0 /* pad_const_int */, - 0.0f /* pad_const_fp */); + auto quant_info = pad_op.getQuantizationInfoAttr(); + // pad_const includes the zero point if the tensor uses a zero point. + int32_t pad_const_int = quant_info ? quant_info.getInputZp() : 0; + float pad_const_fp = 0.f; + + if (auto tensor = pad_op.getPadConst()) { + // Match pad_const tensor as compile-time constant attribute if present. + mlir::DenseElementsAttr attr; + if (!matchPattern(tensor, m_Constant(&attr))) + return nullptr; + + assert(attr.getNumElements() == 1); + auto elementTy = attr.getElementType(); + + if (elementTy.isa<mlir::IntegerType>()) { + pad_const_int = quant_info ? *attr.value_begin<int8_t>() + : *attr.value_begin<int32_t>(); + } else if (elementTy.isa<mlir::FloatType>()) { + pad_const_fp = *attr.value_begin<float>(); + } + } + + TosaPadAttribute attribute(paddings, pad_const_int, pad_const_fp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_PAD, Attribute_PadAttribute, &attribute, |