aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2023-02-13 16:02:00 -0800
committerTatWai Chong <tatwai.chong@arm.com>2023-02-14 09:22:38 -0800
commit6bbd8403fd524acec1bf3e63314bf38040802968 (patch)
treeea50d0db4576d771e3006989253a74f04548fd8d /src/TosaSerialize.cpp
parentf514588912a73bd5eafa49008730ddfc9b3969be (diff)
downloadtosa_mlir_translator-6bbd8403fd524acec1bf3e63314bf38040802968.tar.gz
Revert "Align the type of padding and pad_const with the spec"
This reverts commit e1dcee57d8f89cd192411bbec9e8a97b26833bb7. Change-Id: I7b4c1b12da6530514f9cf09ff7b1e463834fa008
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp19
1 files changed, 10 insertions, 9 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 31f1cba..42ca41e 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -879,15 +879,18 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>(
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));
- auto padding =
- op.getAttr("padding").dyn_cast<mlir::DenseI32ArrayAttr>().asArrayRef();
- auto pad_const =
- op.getAttr("pad_const").dyn_cast<mlir::DenseIntOrFPElementsAttr>();
+ // Match padding tensor as compile-time constant attribute
+ mlir::ElementsAttr paddings_elems;
+ if (!matchPattern(op.getOperand(1), m_Constant(&paddings_elems)))
+ return nullptr;
- assert(pad_const.getNumElements() == 1);
+ std::vector<int> paddings;
+ for (int32_t val : paddings_elems.getValues<int32_t>()) {
+ paddings.push_back(val);
+ }
- TosaPadAttribute attribute(padding, *pad_const.value_begin<int32_t>(),
- *pad_const.value_begin<float>());
+ TosaPadAttribute attribute(paddings, 0 /* pad_const_int */,
+ 0.0f /* pad_const_fp */);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_PAD, Attribute_PadAttribute, &attribute,
@@ -904,7 +907,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeOp>(
std::string output_name = GetTensorName(op.getResult(0));
// Match perm tensor as compile-time constant attribute
- // TODO: fix when MLIR dialect changes
mlir::ElementsAttr perm_elems;
if (!matchPattern(op.getOperand(1), m_Constant(&perm_elems)))
return nullptr;
@@ -1093,7 +1095,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TableOp>(
std::string output_name = GetTensorName(op.getResult(0));
// Match table tensor as compile-time constant attribute
- // TODO: fix when MLIR dialect changes
mlir::ElementsAttr table_elems;
if (!matchPattern(op.getOperand(1), m_Constant(&table_elems)))
return nullptr;