diff options
-rw-r--r-- | src/TosaDeserialize.cpp | 63 | ||||
-rw-r--r-- | src/TosaSerialize.cpp | 22 | ||||
m--------- | third_party/serialization_lib | 0 |
3 files changed, 45 insertions, 40 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 215d760..68989c2 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -1303,36 +1303,51 @@ TosaMlirOperatorBuilder::build<Op_TABLE>(TosaSerializationOperator *op) const { mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); - assert(op->GetAttributeType() == - Attribute_TableAttribute); // double check attribute type - TosaTableAttribute *attr = - static_cast<TosaTableAttribute *>(op->GetAttribute()); - - // create a const op for table value attribute - const auto table_values = attr->table(); - mlir::RankedTensorType const_type; - mlir::DenseElementsAttr const_attr; - const auto input_element_type = - input_val.getType().cast<mlir::ShapedType>().getElementType(); - if (input_element_type.isInteger(8)) { - // table is signed 8 mode - const_type = mlir::RankedTensorType::get( - {static_cast<int64_t>(table_values.size())}, op_builder->getI8Type()); - const_attr = BuildDenseI8ElementsAttr(op_builder, table_values); + mlir::Value table_value; + if (op->GetInputTensorNames().size() > 1 && + tensor_map->find(op->GetInputTensorNames()[1]) != tensor_map->end()) { + table_value = tensor_map->at(op->GetInputTensorNames()[1]); } else { - // table is signed 16 mode - const_type = mlir::RankedTensorType::get( - {static_cast<int64_t>(table_values.size())}, op_builder->getI16Type()); - const_attr = BuildDenseI16ElementsAttr(op_builder, table_values); + // Backward compatible support with the version <= 1.0 rc in which the data + // of table operand is embedded in the operation as attribute type. + // TODO drop this supporting code when the table attribute is no longer + // supported. + TosaTableAttribute *attr = + dynamic_cast<TosaTableAttribute *>(op->GetAttribute()); + + if (attr == nullptr) + llvm::errs() << "the table attribute is not found\n"; + + const auto table_values = attr->table(); + mlir::RankedTensorType const_type; + mlir::DenseElementsAttr const_attr; + const auto input_element_type = + input_val.getType().cast<mlir::ShapedType>().getElementType(); + if (input_element_type.isInteger(8)) { + // table is signed 8 mode + const_type = mlir::RankedTensorType::get( + {static_cast<int64_t>(table_values.size())}, op_builder->getI8Type()); + const_attr = BuildDenseI8ElementsAttr(op_builder, table_values); + } else { + // table is signed 16 mode + const_type = mlir::RankedTensorType::get( + {static_cast<int64_t>(table_values.size())}, + op_builder->getI16Type()); + const_attr = BuildDenseI16ElementsAttr(op_builder, table_values); + } + + // Create a const op for table value attribute. + mlir::Operation *mlir_const_op = + op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr); + block->push_back(mlir_const_op); + + table_value = mlir_const_op->getResult(0); } - mlir::Operation *mlir_const_op = - op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr); - auto table_value = mlir_const_op->getResult(0); mlir::Operation *mlir_op = op_builder->create<mlir::tosa::TableOp>( loc, output_type, input_val, table_value); - block->push_back(mlir_const_op); block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); } diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index c3a9878..90645c9 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -1402,24 +1402,14 @@ template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::TableOp>( mlir::Operation &op) const { - std::string input_name = GetTensorName(op.getOperand(0)); - std::string output_name = GetTensorName(op.getResult(0)); - - // Match table tensor as compile-time constant attribute - mlir::ElementsAttr table_elems; - if (!matchPattern(op.getOperand(1), m_Constant(&table_elems))) - return nullptr; - - std::vector<int16_t> table; - for (auto value : table_elems.getValues<mlir::IntegerAttr>()) { - table.push_back(value.getInt()); - } - - TosaTableAttribute attribute(table); + auto table_op = mlir::cast<mlir::tosa::TableOp>(op); + std::string input_name = GetTensorName(table_op.getInput()); + std::string output_name = GetTensorName(table_op.getOutput()); + std::string table_name = GetTensorName(table_op.getTable()); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_TABLE, Attribute_TableAttribute, &attribute, - std::vector<std::string>{input_name}, + Op_TABLE, Attribute_NONE, nullptr, + std::vector<std::string>{input_name, table_name}, std::vector<std::string>{output_name}); return tyop; diff --git a/third_party/serialization_lib b/third_party/serialization_lib -Subproject 36ced1df313cf80edc91efe41facb1ab3a81b22 +Subproject 3aebe2bd863d6e0cb82171984cd49e5ad516d0d |