diff options
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r-- | src/TosaDeserialize.cpp | 31 |
1 files changed, 27 insertions, 4 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index d55585e..db52d67 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -99,6 +99,18 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder, } template <class T> +mlir::DenseElementsAttr BuildDenseI8ElementsAttr(mlir::OpBuilder *op_builder, + const std::vector<T> &values) { + std::vector<int8_t> vec; + for (auto val : values) { + vec.push_back(val); + } + auto type = + mlir::RankedTensorType::get({vec.size()}, op_builder->getI8Type()); + return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec)); +} + +template <class T> mlir::DenseElementsAttr BuildDenseI16ElementsAttr(mlir::OpBuilder *op_builder, const std::vector<T> &values) { @@ -997,10 +1009,21 @@ TosaMlirOperatorBuilder::build<Op_TABLE>(TosaSerializationOperator *op) const { // create a const op for table value attribute const auto table_values = attr->table(); - auto const_type = mlir::RankedTensorType::get({table_values.size()}, - op_builder->getI16Type()); - mlir::DenseElementsAttr const_attr = - BuildDenseI16ElementsAttr(op_builder, table_values); + mlir::RankedTensorType const_type; + mlir::DenseElementsAttr const_attr; + const auto input_element_type = + input_val.getType().cast<mlir::ShapedType>().getElementType(); + if (input_element_type.isInteger(8)) { + // table is signed 8 mode + const_type = mlir::RankedTensorType::get({table_values.size()}, + op_builder->getI8Type()); + const_attr = BuildDenseI8ElementsAttr(op_builder, table_values); + } else { + // table is signed 16 mode + const_type = mlir::RankedTensorType::get({table_values.size()}, + op_builder->getI16Type()); + const_attr = BuildDenseI16ElementsAttr(op_builder, table_values); + } mlir::Operation *mlir_const_op = op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr); auto table_value = mlir_const_op->getResult(0); |