aboutsummaryrefslogtreecommitdiff
path: root/src/TosaDeserialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r--src/TosaDeserialize.cpp15
1 files changed, 9 insertions, 6 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 031c57f..612e8aa 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -282,13 +282,12 @@ BuildDenseI16ElementsAttr(mlir::OpBuilder *op_builder,
template <class T>
mlir::DenseElementsAttr
BuildDenseI32ElementsAttr(mlir::OpBuilder *op_builder,
+ mlir::RankedTensorType &type,
const std::vector<T> &values) {
llvm::SmallVector<int32_t> vec;
for (auto val : values) {
vec.push_back(val);
}
- auto type = mlir::RankedTensorType::get({static_cast<int64_t>(vec.size())},
- op_builder->getI32Type());
return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec));
}
@@ -932,6 +931,8 @@ template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType input_type =
+ tensor_type_map->at(op->GetInputTensorNames()[0]);
mlir::RankedTensorType output_type =
tensor_type_map->at(op->GetOutputTensorNames()[0]);
@@ -939,9 +940,11 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
Attribute_PadAttribute); // double check attribute type
TosaPadAttribute *attr = static_cast<TosaPadAttribute *>(op->GetAttribute());
- auto padding_attr = BuildDenseI32ElementsAttr(op_builder, attr->padding());
- auto padding_type = mlir::RankedTensorType::get(
- {static_cast<int64_t>(attr->padding().size())}, op_builder->getI32Type());
+ // padding has shape {rank(input_type), 2}
+ auto padding_type = mlir::RankedTensorType::get({input_type.getRank(), 2},
+ op_builder->getI32Type());
+ auto padding_attr =
+ BuildDenseI32ElementsAttr(op_builder, padding_type, attr->padding());
mlir::Operation *mlir_const_op =
op_builder->create<mlir::tosa::ConstOp>(loc, padding_type, padding_attr);
block->push_back(mlir_const_op);
@@ -1024,7 +1027,7 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE>(
auto const_type = mlir::RankedTensorType::get(
{static_cast<int64_t>(perms_values.size())}, op_builder->getI32Type());
mlir::DenseElementsAttr const_attr =
- BuildDenseI32ElementsAttr(op_builder, perms_values);
+ BuildDenseI32ElementsAttr(op_builder, const_type, perms_values);
mlir::Operation *mlir_const_op =
op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr);
auto perms_val = mlir_const_op->getResult(0);