aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2024-05-12 02:39:22 -0700
committerTatWai Chong <tatwai.chong@arm.com>2024-05-28 09:29:35 -0700
commit2bdd25e453e35dcec105f6c7aea60fc2dbd52f1b (patch)
tree0aecbb6ff2d508faf19d9e8664a5b3c6d4d6aef7 /src
parent82d08b663230c79ebfefb608f8c7b88dc63b0ea6 (diff)
downloadtosa_mlir_translator-main.tar.gz
Change the table parameter from attribute to tensor typeHEADmain
In deserialization, support backward compatible with the previous versions in which the data of table operand is embedded in the operation as attribute type. Signed-off-by: TatWai Chong <tatwai.chong@arm.com> Change-Id: I016ac5890edc0d3dc742ebabf88cc92dee83610d
Diffstat (limited to 'src')
-rw-r--r--src/TosaDeserialize.cpp63
-rw-r--r--src/TosaSerialize.cpp22
2 files changed, 45 insertions, 40 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 215d760..68989c2 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -1303,36 +1303,51 @@ TosaMlirOperatorBuilder::build<Op_TABLE>(TosaSerializationOperator *op) const {
mlir::RankedTensorType output_type =
tensor_type_map->at(op->GetOutputTensorNames()[0]);
- assert(op->GetAttributeType() ==
- Attribute_TableAttribute); // double check attribute type
- TosaTableAttribute *attr =
- static_cast<TosaTableAttribute *>(op->GetAttribute());
-
- // create a const op for table value attribute
- const auto table_values = attr->table();
- mlir::RankedTensorType const_type;
- mlir::DenseElementsAttr const_attr;
- const auto input_element_type =
- input_val.getType().cast<mlir::ShapedType>().getElementType();
- if (input_element_type.isInteger(8)) {
- // table is signed 8 mode
- const_type = mlir::RankedTensorType::get(
- {static_cast<int64_t>(table_values.size())}, op_builder->getI8Type());
- const_attr = BuildDenseI8ElementsAttr(op_builder, table_values);
+ mlir::Value table_value;
+ if (op->GetInputTensorNames().size() > 1 &&
+ tensor_map->find(op->GetInputTensorNames()[1]) != tensor_map->end()) {
+ table_value = tensor_map->at(op->GetInputTensorNames()[1]);
} else {
- // table is signed 16 mode
- const_type = mlir::RankedTensorType::get(
- {static_cast<int64_t>(table_values.size())}, op_builder->getI16Type());
- const_attr = BuildDenseI16ElementsAttr(op_builder, table_values);
+ // Backward compatible support with the version <= 1.0 rc in which the data
+ // of table operand is embedded in the operation as attribute type.
+ // TODO drop this supporting code when the table attribute is no longer
+ // supported.
+ TosaTableAttribute *attr =
+ dynamic_cast<TosaTableAttribute *>(op->GetAttribute());
+
+ if (attr == nullptr)
+ llvm::errs() << "the table attribute is not found\n";
+
+ const auto table_values = attr->table();
+ mlir::RankedTensorType const_type;
+ mlir::DenseElementsAttr const_attr;
+ const auto input_element_type =
+ input_val.getType().cast<mlir::ShapedType>().getElementType();
+ if (input_element_type.isInteger(8)) {
+ // table is signed 8 mode
+ const_type = mlir::RankedTensorType::get(
+ {static_cast<int64_t>(table_values.size())}, op_builder->getI8Type());
+ const_attr = BuildDenseI8ElementsAttr(op_builder, table_values);
+ } else {
+ // table is signed 16 mode
+ const_type = mlir::RankedTensorType::get(
+ {static_cast<int64_t>(table_values.size())},
+ op_builder->getI16Type());
+ const_attr = BuildDenseI16ElementsAttr(op_builder, table_values);
+ }
+
+ // Create a const op for table value attribute.
+ mlir::Operation *mlir_const_op =
+ op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr);
+ block->push_back(mlir_const_op);
+
+ table_value = mlir_const_op->getResult(0);
}
- mlir::Operation *mlir_const_op =
- op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr);
- auto table_value = mlir_const_op->getResult(0);
mlir::Operation *mlir_op = op_builder->create<mlir::tosa::TableOp>(
loc, output_type, input_val, table_value);
- block->push_back(mlir_const_op);
block->push_back(mlir_op);
+
return std::vector<mlir::Value>({mlir_op->getResult(0)});
}
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index c3a9878..90645c9 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -1402,24 +1402,14 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::TableOp>(
mlir::Operation &op) const {
- std::string input_name = GetTensorName(op.getOperand(0));
- std::string output_name = GetTensorName(op.getResult(0));
-
- // Match table tensor as compile-time constant attribute
- mlir::ElementsAttr table_elems;
- if (!matchPattern(op.getOperand(1), m_Constant(&table_elems)))
- return nullptr;
-
- std::vector<int16_t> table;
- for (auto value : table_elems.getValues<mlir::IntegerAttr>()) {
- table.push_back(value.getInt());
- }
-
- TosaTableAttribute attribute(table);
+ auto table_op = mlir::cast<mlir::tosa::TableOp>(op);
+ std::string input_name = GetTensorName(table_op.getInput());
+ std::string output_name = GetTensorName(table_op.getOutput());
+ std::string table_name = GetTensorName(table_op.getTable());
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_TABLE, Attribute_TableAttribute, &attribute,
- std::vector<std::string>{input_name},
+ Op_TABLE, Attribute_NONE, nullptr,
+ std::vector<std::string>{input_name, table_name},
std::vector<std::string>{output_name});
return tyop;