From bffbb131352491f2eaaaa7a6aca3f860a4a09f02 Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Tue, 27 Feb 2024 22:48:07 -0800 Subject: Cast integer pad value to boolean in boolean pad operation In the type conversion, zero integer is equal to `false`, and any non-zero integner evaluates to `true`. Change-Id: I78f395f8cfa17f77d72c4f18a65d9848c5811ddc Signed-off-by: TatWai Chong --- src/TosaDeserialize.cpp | 81 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 28 deletions(-) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 301d0da..5ad7e6e 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -1056,39 +1056,64 @@ TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { // todo: int input_zp = attr->pad_input_zp(); mlir::Operation *mlir_op; + mlir::Value pad_const_value; + const auto element_type = + input_val.getType().cast().getElementType(); + + bool isBoolType = element_type.isInteger(1); + // First handle boolean type. + if (isBoolType) { + mlir::Type boolType = op_builder->getIntegerType(1); + auto pad_const_type = mlir::RankedTensorType::get({}, boolType); + // Treat zero integer is `false`, and any non-zero integner evaluates to + // `true`. + bool pad_const = pad_const_int == 0 ? false : true; + auto pad_const_attr = + mlir::DenseElementsAttr::get(pad_const_type, {pad_const}); + mlir::Operation *pad_const_op = op_builder->create( + loc, pad_const_type, pad_const_attr); + + block->push_back(pad_const_op); + pad_const_value = pad_const_op->getResult(0); + mlir_op = op_builder->create( + loc, output_type, input_val, padding_val, pad_const_value); + + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); + } + + // Second handle the cases where no explicit pad_const input. if (pad_const_int == 0 && pad_const_fp == 0.0f) { - // no pad_const input mlir_op = op_builder->create(loc, output_type, input_val, padding_val); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); + } + + // Then handle explicit numerical pad_const cases. + if (pad_const_int != 0) { + assert(pad_const_fp == 0.0f && llvm::isa(element_type)); + auto pad_const_int_type = mlir::RankedTensorType::get({}, element_type); + auto pad_const_int_attr = + mlir::DenseElementsAttr::get(pad_const_int_type, {pad_const_int}); + mlir::Operation *pad_const_int_op = op_builder->create( + loc, pad_const_int_type, pad_const_int_attr); + block->push_back(pad_const_int_op); + pad_const_value = pad_const_int_op->getResult(0); } else { - // create a const value for pad_const input - const auto input_element_type = - input_val.getType().cast().getElementType(); - mlir::Value pad_const_value; - if (pad_const_int != 0) { - auto pad_const_int_type = - mlir::RankedTensorType::get({}, input_element_type); - auto pad_const_int_attr = - mlir::DenseElementsAttr::get(pad_const_int_type, {pad_const_int}); - mlir::Operation *pad_const_int_op = - op_builder->create(loc, pad_const_int_type, - pad_const_int_attr); - block->push_back(pad_const_int_op); - pad_const_value = pad_const_int_op->getResult(0); - } else if (pad_const_fp != 0) { - auto pad_const_fp_type = - mlir::RankedTensorType::get({}, input_element_type); - auto pad_const_fp_attr = - mlir::DenseElementsAttr::get(pad_const_fp_type, {pad_const_fp}); - mlir::Operation *pad_const_fp_op = - op_builder->create(loc, pad_const_fp_type, - pad_const_fp_attr); - block->push_back(pad_const_fp_op); - pad_const_value = pad_const_fp_op->getResult(0); - } - mlir_op = op_builder->create( - loc, output_type, input_val, padding_val, pad_const_value); + assert(pad_const_fp != 0 && llvm::isa(element_type)); + auto pad_const_fp_type = mlir::RankedTensorType::get({}, element_type); + auto pad_const_fp_attr = + mlir::DenseElementsAttr::get(pad_const_fp_type, {pad_const_fp}); + mlir::Operation *pad_const_fp_op = op_builder->create( + loc, pad_const_fp_type, pad_const_fp_attr); + block->push_back(pad_const_fp_op); + pad_const_value = pad_const_fp_op->getResult(0); } + + mlir_op = op_builder->create(loc, output_type, input_val, + padding_val, pad_const_value); + block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); } -- cgit v1.2.1