diff options
author | Tai Ly <tai.ly@arm.com> | 2023-07-07 19:37:52 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2023-07-10 18:05:15 +0000 |
commit | 289fd33b43f5260dba2faab9c372e42b4b71f83d (patch) | |
tree | 47864a5332f79cec04fe0ca53914f33b959ee2b9 | |
parent | 9121c479ef2b11b52d7e77dc1b4dccd8f55b0db0 (diff) | |
download | tosa_mlir_translator-289fd33b43f5260dba2faab9c372e42b4b71f83d.tar.gz |
Fix deserialization of Table const type
Fixed dDeserialization of Table const to have either I8
or I16 element types based on input element type
eventhough table const is always serialized as int16_t values.
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I5d9972ba6a97a709003d73741c91bab31be06cf7
-rw-r--r-- | src/TosaDeserialize.cpp | 31 |
1 files changed, 27 insertions, 4 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index d55585e..db52d67 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -99,6 +99,18 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder, } template <class T> +mlir::DenseElementsAttr BuildDenseI8ElementsAttr(mlir::OpBuilder *op_builder, + const std::vector<T> &values) { + std::vector<int8_t> vec; + for (auto val : values) { + vec.push_back(val); + } + auto type = + mlir::RankedTensorType::get({vec.size()}, op_builder->getI8Type()); + return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec)); +} + +template <class T> mlir::DenseElementsAttr BuildDenseI16ElementsAttr(mlir::OpBuilder *op_builder, const std::vector<T> &values) { @@ -997,10 +1009,21 @@ TosaMlirOperatorBuilder::build<Op_TABLE>(TosaSerializationOperator *op) const { // create a const op for table value attribute const auto table_values = attr->table(); - auto const_type = mlir::RankedTensorType::get({table_values.size()}, - op_builder->getI16Type()); - mlir::DenseElementsAttr const_attr = - BuildDenseI16ElementsAttr(op_builder, table_values); + mlir::RankedTensorType const_type; + mlir::DenseElementsAttr const_attr; + const auto input_element_type = + input_val.getType().cast<mlir::ShapedType>().getElementType(); + if (input_element_type.isInteger(8)) { + // table is signed 8 mode + const_type = mlir::RankedTensorType::get({table_values.size()}, + op_builder->getI8Type()); + const_attr = BuildDenseI8ElementsAttr(op_builder, table_values); + } else { + // table is signed 16 mode + const_type = mlir::RankedTensorType::get({table_values.size()}, + op_builder->getI16Type()); + const_attr = BuildDenseI16ElementsAttr(op_builder, table_values); + } mlir::Operation *mlir_const_op = op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr); auto table_value = mlir_const_op->getResult(0); |