diff options
author | TatWai Chong <tatwai.chong@arm.com> | 2024-05-12 02:39:22 -0700 |
---|---|---|
committer | TatWai Chong <tatwai.chong@arm.com> | 2024-05-28 09:29:35 -0700 |
commit | 2bdd25e453e35dcec105f6c7aea60fc2dbd52f1b (patch) | |
tree | 0aecbb6ff2d508faf19d9e8664a5b3c6d4d6aef7 /src | |
parent | 82d08b663230c79ebfefb608f8c7b88dc63b0ea6 (diff) | |
download | tosa_mlir_translator-main.tar.gz |
In deserialization, support backward compatible with the previous
versions in which the data of table operand is embedded in the
operation as attribute type.
Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
Change-Id: I016ac5890edc0d3dc742ebabf88cc92dee83610d
Diffstat (limited to 'src')
-rw-r--r-- | src/TosaDeserialize.cpp | 63 | ||||
-rw-r--r-- | src/TosaSerialize.cpp | 22 |
2 files changed, 45 insertions, 40 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 215d760..68989c2 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -1303,36 +1303,51 @@ TosaMlirOperatorBuilder::build<Op_TABLE>(TosaSerializationOperator *op) const { mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); - assert(op->GetAttributeType() == - Attribute_TableAttribute); // double check attribute type - TosaTableAttribute *attr = - static_cast<TosaTableAttribute *>(op->GetAttribute()); - - // create a const op for table value attribute - const auto table_values = attr->table(); - mlir::RankedTensorType const_type; - mlir::DenseElementsAttr const_attr; - const auto input_element_type = - input_val.getType().cast<mlir::ShapedType>().getElementType(); - if (input_element_type.isInteger(8)) { - // table is signed 8 mode - const_type = mlir::RankedTensorType::get( - {static_cast<int64_t>(table_values.size())}, op_builder->getI8Type()); - const_attr = BuildDenseI8ElementsAttr(op_builder, table_values); + mlir::Value table_value; + if (op->GetInputTensorNames().size() > 1 && + tensor_map->find(op->GetInputTensorNames()[1]) != tensor_map->end()) { + table_value = tensor_map->at(op->GetInputTensorNames()[1]); } else { - // table is signed 16 mode - const_type = mlir::RankedTensorType::get( - {static_cast<int64_t>(table_values.size())}, op_builder->getI16Type()); - const_attr = BuildDenseI16ElementsAttr(op_builder, table_values); + // Backward compatible support with the version <= 1.0 rc in which the data + // of table operand is embedded in the operation as attribute type. + // TODO drop this supporting code when the table attribute is no longer + // supported. + TosaTableAttribute *attr = + dynamic_cast<TosaTableAttribute *>(op->GetAttribute()); + + if (attr == nullptr) + llvm::errs() << "the table attribute is not found\n"; + + const auto table_values = attr->table(); + mlir::RankedTensorType const_type; + mlir::DenseElementsAttr const_attr; + const auto input_element_type = + input_val.getType().cast<mlir::ShapedType>().getElementType(); + if (input_element_type.isInteger(8)) { + // table is signed 8 mode + const_type = mlir::RankedTensorType::get( + {static_cast<int64_t>(table_values.size())}, op_builder->getI8Type()); + const_attr = BuildDenseI8ElementsAttr(op_builder, table_values); + } else { + // table is signed 16 mode + const_type = mlir::RankedTensorType::get( + {static_cast<int64_t>(table_values.size())}, + op_builder->getI16Type()); + const_attr = BuildDenseI16ElementsAttr(op_builder, table_values); + } + + // Create a const op for table value attribute. + mlir::Operation *mlir_const_op = + op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr); + block->push_back(mlir_const_op); + + table_value = mlir_const_op->getResult(0); } - mlir::Operation *mlir_const_op = - op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr); - auto table_value = mlir_const_op->getResult(0); mlir::Operation *mlir_op = op_builder->create<mlir::tosa::TableOp>( loc, output_type, input_val, table_value); - block->push_back(mlir_const_op); block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); } diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index c3a9878..90645c9 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -1402,24 +1402,14 @@ template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::TableOp>( mlir::Operation &op) const { - std::string input_name = GetTensorName(op.getOperand(0)); - std::string output_name = GetTensorName(op.getResult(0)); - - // Match table tensor as compile-time constant attribute - mlir::ElementsAttr table_elems; - if (!matchPattern(op.getOperand(1), m_Constant(&table_elems))) - return nullptr; - - std::vector<int16_t> table; - for (auto value : table_elems.getValues<mlir::IntegerAttr>()) { - table.push_back(value.getInt()); - } - - TosaTableAttribute attribute(table); + auto table_op = mlir::cast<mlir::tosa::TableOp>(op); + std::string input_name = GetTensorName(table_op.getInput()); + std::string output_name = GetTensorName(table_op.getOutput()); + std::string table_name = GetTensorName(table_op.getTable()); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_TABLE, Attribute_TableAttribute, &attribute, - std::vector<std::string>{input_name}, + Op_TABLE, Attribute_NONE, nullptr, + std::vector<std::string>{input_name, table_name}, std::vector<std::string>{output_name}); return tyop; |