aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/TosaDeserialize.cpp63
-rw-r--r--src/TosaSerialize.cpp22
m---------third_party/serialization_lib0
3 files changed, 45 insertions, 40 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 215d760..68989c2 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -1303,36 +1303,51 @@ TosaMlirOperatorBuilder::build<Op_TABLE>(TosaSerializationOperator *op) const {
mlir::RankedTensorType output_type =
tensor_type_map->at(op->GetOutputTensorNames()[0]);
- assert(op->GetAttributeType() ==
- Attribute_TableAttribute); // double check attribute type
- TosaTableAttribute *attr =
- static_cast<TosaTableAttribute *>(op->GetAttribute());
-
- // create a const op for table value attribute
- const auto table_values = attr->table();
- mlir::RankedTensorType const_type;
- mlir::DenseElementsAttr const_attr;
- const auto input_element_type =
- input_val.getType().cast<mlir::ShapedType>().getElementType();
- if (input_element_type.isInteger(8)) {
- // table is signed 8 mode
- const_type = mlir::RankedTensorType::get(
- {static_cast<int64_t>(table_values.size())}, op_builder->getI8Type());
- const_attr = BuildDenseI8ElementsAttr(op_builder, table_values);
+ mlir::Value table_value;
+ if (op->GetInputTensorNames().size() > 1 &&
+ tensor_map->find(op->GetInputTensorNames()[1]) != tensor_map->end()) {
+ table_value = tensor_map->at(op->GetInputTensorNames()[1]);
} else {
- // table is signed 16 mode
- const_type = mlir::RankedTensorType::get(
- {static_cast<int64_t>(table_values.size())}, op_builder->getI16Type());
- const_attr = BuildDenseI16ElementsAttr(op_builder, table_values);
+ // Backward compatible support with the version <= 1.0 rc in which the data
+ // of table operand is embedded in the operation as attribute type.
+ // TODO drop this supporting code when the table attribute is no longer
+ // supported.
+ TosaTableAttribute *attr =
+ dynamic_cast<TosaTableAttribute *>(op->GetAttribute());
+
+ if (attr == nullptr)
+ llvm::errs() << "the table attribute is not found\n";
+
+ const auto table_values = attr->table();
+ mlir::RankedTensorType const_type;
+ mlir::DenseElementsAttr const_attr;
+ const auto input_element_type =
+ input_val.getType().cast<mlir::ShapedType>().getElementType();
+ if (input_element_type.isInteger(8)) {
+ // table is signed 8 mode
+ const_type = mlir::RankedTensorType::get(
+ {static_cast<int64_t>(table_values.size())}, op_builder->getI8Type());
+ const_attr = BuildDenseI8ElementsAttr(op_builder, table_values);
+ } else {
+ // table is signed 16 mode
+ const_type = mlir::RankedTensorType::get(
+ {static_cast<int64_t>(table_values.size())},
+ op_builder->getI16Type());
+ const_attr = BuildDenseI16ElementsAttr(op_builder, table_values);
+ }
+
+ // Create a const op for table value attribute.
+ mlir::Operation *mlir_const_op =
+ op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr);
+ block->push_back(mlir_const_op);
+
+ table_value = mlir_const_op->getResult(0);
}
- mlir::Operation *mlir_const_op =
- op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr);
- auto table_value = mlir_const_op->getResult(0);
mlir::Operation *mlir_op = op_builder->create<mlir::tosa::TableOp>(
loc, output_type, input_val, table_value);
- block->push_back(mlir_const_op);
block->push_back(mlir_op);
+
return std::vector<mlir::Value>({mlir_op->getResult(0)});
}
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index c3a9878..90645c9 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -1402,24 +1402,14 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::TableOp>(
mlir::Operation &op) const {
- std::string input_name = GetTensorName(op.getOperand(0));
- std::string output_name = GetTensorName(op.getResult(0));
-
- // Match table tensor as compile-time constant attribute
- mlir::ElementsAttr table_elems;
- if (!matchPattern(op.getOperand(1), m_Constant(&table_elems)))
- return nullptr;
-
- std::vector<int16_t> table;
- for (auto value : table_elems.getValues<mlir::IntegerAttr>()) {
- table.push_back(value.getInt());
- }
-
- TosaTableAttribute attribute(table);
+ auto table_op = mlir::cast<mlir::tosa::TableOp>(op);
+ std::string input_name = GetTensorName(table_op.getInput());
+ std::string output_name = GetTensorName(table_op.getOutput());
+ std::string table_name = GetTensorName(table_op.getTable());
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_TABLE, Attribute_TableAttribute, &attribute,
- std::vector<std::string>{input_name},
+ Op_TABLE, Attribute_NONE, nullptr,
+ std::vector<std::string>{input_name, table_name},
std::vector<std::string>{output_name});
return tyop;
diff --git a/third_party/serialization_lib b/third_party/serialization_lib
-Subproject 36ced1df313cf80edc91efe41facb1ab3a81b22
+Subproject 3aebe2bd863d6e0cb82171984cd49e5ad516d0d