diff options
author | TatWai Chong <tatwai.chong@arm.com> | 2023-02-13 16:02:00 -0800 |
---|---|---|
committer | TatWai Chong <tatwai.chong@arm.com> | 2023-02-14 09:22:38 -0800 |
commit | 6bbd8403fd524acec1bf3e63314bf38040802968 (patch) | |
tree | ea50d0db4576d771e3006989253a74f04548fd8d /src | |
parent | f514588912a73bd5eafa49008730ddfc9b3969be (diff) | |
download | tosa_mlir_translator-6bbd8403fd524acec1bf3e63314bf38040802968.tar.gz |
Revert "Align the type of padding and pad_const with the spec"
This reverts commit e1dcee57d8f89cd192411bbec9e8a97b26833bb7.
Change-Id: I7b4c1b12da6530514f9cf09ff7b1e463834fa008
Diffstat (limited to 'src')
-rw-r--r-- | src/TosaSerialize.cpp | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 31f1cba..42ca41e 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -879,15 +879,18 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>( std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); - auto padding = - op.getAttr("padding").dyn_cast<mlir::DenseI32ArrayAttr>().asArrayRef(); - auto pad_const = - op.getAttr("pad_const").dyn_cast<mlir::DenseIntOrFPElementsAttr>(); + // Match padding tensor as compile-time constant attribute + mlir::ElementsAttr paddings_elems; + if (!matchPattern(op.getOperand(1), m_Constant(&paddings_elems))) + return nullptr; - assert(pad_const.getNumElements() == 1); + std::vector<int> paddings; + for (int32_t val : paddings_elems.getValues<int32_t>()) { + paddings.push_back(val); + } - TosaPadAttribute attribute(padding, *pad_const.value_begin<int32_t>(), - *pad_const.value_begin<float>()); + TosaPadAttribute attribute(paddings, 0 /* pad_const_int */, + 0.0f /* pad_const_fp */); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_PAD, Attribute_PadAttribute, &attribute, @@ -904,7 +907,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeOp>( std::string output_name = GetTensorName(op.getResult(0)); // Match perm tensor as compile-time constant attribute - // TODO: fix when MLIR dialect changes mlir::ElementsAttr perm_elems; if (!matchPattern(op.getOperand(1), m_Constant(&perm_elems))) return nullptr; @@ -1093,7 +1095,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TableOp>( std::string output_name = GetTensorName(op.getResult(0)); // Match table tensor as compile-time constant attribute - // TODO: fix when MLIR dialect changes mlir::ElementsAttr table_elems; if (!matchPattern(op.getOperand(1), m_Constant(&table_elems))) return nullptr; |