diff options
-rw-r--r-- | src/TosaSerialize.cpp | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 31f1cba..42ca41e 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -879,15 +879,18 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>( std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); - auto padding = - op.getAttr("padding").dyn_cast<mlir::DenseI32ArrayAttr>().asArrayRef(); - auto pad_const = - op.getAttr("pad_const").dyn_cast<mlir::DenseIntOrFPElementsAttr>(); + // Match padding tensor as compile-time constant attribute + mlir::ElementsAttr paddings_elems; + if (!matchPattern(op.getOperand(1), m_Constant(&paddings_elems))) + return nullptr; - assert(pad_const.getNumElements() == 1); + std::vector<int> paddings; + for (int32_t val : paddings_elems.getValues<int32_t>()) { + paddings.push_back(val); + } - TosaPadAttribute attribute(padding, *pad_const.value_begin<int32_t>(), - *pad_const.value_begin<float>()); + TosaPadAttribute attribute(paddings, 0 /* pad_const_int */, + 0.0f /* pad_const_fp */); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_PAD, Attribute_PadAttribute, &attribute, @@ -904,7 +907,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeOp>( std::string output_name = GetTensorName(op.getResult(0)); // Match perm tensor as compile-time constant attribute - // TODO: fix when MLIR dialect changes mlir::ElementsAttr perm_elems; if (!matchPattern(op.getOperand(1), m_Constant(&perm_elems))) return nullptr; @@ -1093,7 +1095,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TableOp>( std::string output_name = GetTensorName(op.getResult(0)); // Match table tensor as compile-time constant attribute - // TODO: fix when MLIR dialect changes mlir::ElementsAttr table_elems; if (!matchPattern(op.getOperand(1), m_Constant(&table_elems))) return nullptr; |