diff options
-rw-r--r-- | src/TosaDeserialize.cpp | 81 |
1 files changed, 53 insertions, 28 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 301d0da..5ad7e6e 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -1056,39 +1056,64 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const { // todo: int input_zp = attr->pad_input_zp(); mlir::Operation *mlir_op; + mlir::Value pad_const_value; + const auto element_type = + input_val.getType().cast<mlir::ShapedType>().getElementType(); + + bool isBoolType = element_type.isInteger(1); + // First handle boolean type. + if (isBoolType) { + mlir::Type boolType = op_builder->getIntegerType(1); + auto pad_const_type = mlir::RankedTensorType::get({}, boolType); + // Treat zero integer is `false`, and any non-zero integner evaluates to + // `true`. + bool pad_const = pad_const_int == 0 ? false : true; + auto pad_const_attr = + mlir::DenseElementsAttr::get(pad_const_type, {pad_const}); + mlir::Operation *pad_const_op = op_builder->create<mlir::tosa::ConstOp>( + loc, pad_const_type, pad_const_attr); + + block->push_back(pad_const_op); + pad_const_value = pad_const_op->getResult(0); + mlir_op = op_builder->create<mlir::tosa::PadOp>( + loc, output_type, input_val, padding_val, pad_const_value); + + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); + } + + // Second handle the cases where no explicit pad_const input. if (pad_const_int == 0 && pad_const_fp == 0.0f) { - // no pad_const input mlir_op = op_builder->create<mlir::tosa::PadOp>(loc, output_type, input_val, padding_val); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); + } + + // Then handle explicit numerical pad_const cases. + if (pad_const_int != 0) { + assert(pad_const_fp == 0.0f && llvm::isa<IntegerType>(element_type)); + auto pad_const_int_type = mlir::RankedTensorType::get({}, element_type); + auto pad_const_int_attr = + mlir::DenseElementsAttr::get(pad_const_int_type, {pad_const_int}); + mlir::Operation *pad_const_int_op = op_builder->create<mlir::tosa::ConstOp>( + loc, pad_const_int_type, pad_const_int_attr); + block->push_back(pad_const_int_op); + pad_const_value = pad_const_int_op->getResult(0); } else { - // create a const value for pad_const input - const auto input_element_type = - input_val.getType().cast<mlir::ShapedType>().getElementType(); - mlir::Value pad_const_value; - if (pad_const_int != 0) { - auto pad_const_int_type = - mlir::RankedTensorType::get({}, input_element_type); - auto pad_const_int_attr = - mlir::DenseElementsAttr::get(pad_const_int_type, {pad_const_int}); - mlir::Operation *pad_const_int_op = - op_builder->create<mlir::tosa::ConstOp>(loc, pad_const_int_type, - pad_const_int_attr); - block->push_back(pad_const_int_op); - pad_const_value = pad_const_int_op->getResult(0); - } else if (pad_const_fp != 0) { - auto pad_const_fp_type = - mlir::RankedTensorType::get({}, input_element_type); - auto pad_const_fp_attr = - mlir::DenseElementsAttr::get(pad_const_fp_type, {pad_const_fp}); - mlir::Operation *pad_const_fp_op = - op_builder->create<mlir::tosa::ConstOp>(loc, pad_const_fp_type, - pad_const_fp_attr); - block->push_back(pad_const_fp_op); - pad_const_value = pad_const_fp_op->getResult(0); - } - mlir_op = op_builder->create<mlir::tosa::PadOp>( - loc, output_type, input_val, padding_val, pad_const_value); + assert(pad_const_fp != 0 && llvm::isa<FloatType>(element_type)); + auto pad_const_fp_type = mlir::RankedTensorType::get({}, element_type); + auto pad_const_fp_attr = + mlir::DenseElementsAttr::get(pad_const_fp_type, {pad_const_fp}); + mlir::Operation *pad_const_fp_op = op_builder->create<mlir::tosa::ConstOp>( + loc, pad_const_fp_type, pad_const_fp_attr); + block->push_back(pad_const_fp_op); + pad_const_value = pad_const_fp_op->getResult(0); } + + mlir_op = op_builder->create<mlir::tosa::PadOp>(loc, output_type, input_val, + padding_val, pad_const_value); + block->push_back(mlir_op); return std::vector<mlir::Value>({mlir_op->getResult(0)}); } |