From 06cb91ba15e860adf72409341143f30613b336c1 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Thu, 25 Jan 2024 18:37:27 +0000 Subject: Change PadOp's padding to TOSA Shape PadOp's padding input is now a tosa.shape type. Changed serialization of PadOp to not store padding as attribute. Changed deserialization of PadOp to not restore padding from attribute. Signed-off-by: Tai Ly Change-Id: I8a622978ea184b8d2779d311adba629c1a0d1fbd --- src/TosaDeserialize.cpp | 14 +++----------- src/TosaSerialize.cpp | 21 ++++++--------------- 2 files changed, 9 insertions(+), 26 deletions(-) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 2421d79..5956bc8 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -1029,6 +1029,7 @@ template <> std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::Value padding_val = tensor_map->at(op->GetInputTensorNames()[1]); mlir::RankedTensorType input_type = tensor_type_map->at(op->GetInputTensorNames()[0]); mlir::RankedTensorType output_type = @@ -1038,15 +1039,6 @@ TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { Attribute_PadAttribute); // double check attribute type TosaPadAttribute *attr = static_cast(op->GetAttribute()); - // padding has shape {rank(input_type), 2} - auto padding_type = mlir::RankedTensorType::get({input_type.getRank(), 2}, - op_builder->getI32Type()); - auto padding_attr = - BuildDenseI32ElementsAttr(op_builder, padding_type, attr->padding()); - mlir::Operation *mlir_const_op = - op_builder->create(loc, padding_type, padding_attr); - block->push_back(mlir_const_op); - auto padding_value = mlir_const_op->getResult(0); auto pad_const_int = attr->pad_const_int(); auto pad_const_fp = attr->pad_const_fp(); // todo: int input_zp = attr->pad_input_zp(); @@ -1055,7 +1047,7 @@ TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { if (pad_const_int == 0 && pad_const_fp == 0.0f) { // no pad_const input mlir_op = op_builder->create(loc, output_type, input_val, - padding_value); + padding_val); } else { // create a const value for pad_const input const auto input_element_type = @@ -1083,7 +1075,7 @@ TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { pad_const_value = pad_const_fp_op->getResult(0); } mlir_op = op_builder->create( - loc, output_type, input_val, padding_value, pad_const_value); + loc, output_type, input_val, padding_val, pad_const_value); } block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 4f2c358..de301fe 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -1205,19 +1205,10 @@ TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { std::string input_name = GetTensorName(op.getOperand(0)); + std::string padding_name = GetTensorName(op.getOperand(1)); std::string output_name = GetTensorName(op.getResult(0)); auto pad_op = llvm::cast(op); - // Match padding tensor as compile-time constant attribute - mlir::ElementsAttr paddings_elems; - if (!matchPattern(pad_op.getPadding(), m_Constant(&paddings_elems))) - return nullptr; - - std::vector paddings; - for (int32_t val : paddings_elems.getValues()) { - paddings.push_back(val); - } - auto quant_info = pad_op.getQuantizationInfoAttr(); // pad_const includes the zero point if the tensor uses a zero point. int32_t pad_const_int = quant_info ? quant_info.getInputZp() : 0; @@ -1242,12 +1233,12 @@ TosaSerializationOperatorBuilder::build( } } - TosaPadAttribute attribute(paddings, pad_const_int, pad_const_fp); + TosaPadAttribute attribute({}, pad_const_int, pad_const_fp); - TosaSerializationOperator *tyop = - new TosaSerializationOperator(Op_PAD, Attribute_PadAttribute, &attribute, - std::vector{input_name}, - std::vector{output_name}); + TosaSerializationOperator *tyop = new TosaSerializationOperator( + Op_PAD, Attribute_PadAttribute, &attribute, + std::vector{input_name, padding_name}, + std::vector{output_name}); return tyop; } -- cgit v1.2.1