diff options
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r-- | src/TosaDeserialize.cpp | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 031c57f..612e8aa 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -282,13 +282,12 @@ BuildDenseI16ElementsAttr(mlir::OpBuilder *op_builder, template <class T> mlir::DenseElementsAttr BuildDenseI32ElementsAttr(mlir::OpBuilder *op_builder, + mlir::RankedTensorType &type, const std::vector<T> &values) { llvm::SmallVector<int32_t> vec; for (auto val : values) { vec.push_back(val); } - auto type = mlir::RankedTensorType::get({static_cast<int64_t>(vec.size())}, - op_builder->getI32Type()); return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec)); } @@ -932,6 +931,8 @@ template <> std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const { mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType input_type = + tensor_type_map->at(op->GetInputTensorNames()[0]); mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); @@ -939,9 +940,11 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const { Attribute_PadAttribute); // double check attribute type TosaPadAttribute *attr = static_cast<TosaPadAttribute *>(op->GetAttribute()); - auto padding_attr = BuildDenseI32ElementsAttr(op_builder, attr->padding()); - auto padding_type = mlir::RankedTensorType::get( - {static_cast<int64_t>(attr->padding().size())}, op_builder->getI32Type()); + // padding has shape {rank(input_type), 2} + auto padding_type = mlir::RankedTensorType::get({input_type.getRank(), 2}, + op_builder->getI32Type()); + auto padding_attr = + BuildDenseI32ElementsAttr(op_builder, padding_type, attr->padding()); mlir::Operation *mlir_const_op = op_builder->create<mlir::tosa::ConstOp>(loc, padding_type, padding_attr); block->push_back(mlir_const_op); @@ -1024,7 +1027,7 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE>( auto const_type = mlir::RankedTensorType::get( {static_cast<int64_t>(perms_values.size())}, op_builder->getI32Type()); mlir::DenseElementsAttr const_attr = - BuildDenseI32ElementsAttr(op_builder, perms_values); + BuildDenseI32ElementsAttr(op_builder, const_type, perms_values); mlir::Operation *mlir_const_op = op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr); auto perms_val = mlir_const_op->getResult(0); |