aboutsummaryrefslogtreecommitdiff
path: root/src/TosaDeserialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r--src/TosaDeserialize.cpp31
1 files changed, 27 insertions, 4 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index d55585e..db52d67 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -99,6 +99,18 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder,
}
template <class T>
+mlir::DenseElementsAttr BuildDenseI8ElementsAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ std::vector<int8_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ auto type =
+ mlir::RankedTensorType::get({vec.size()}, op_builder->getI8Type());
+ return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec));
+}
+
+template <class T>
mlir::DenseElementsAttr
BuildDenseI16ElementsAttr(mlir::OpBuilder *op_builder,
const std::vector<T> &values) {
@@ -997,10 +1009,21 @@ TosaMlirOperatorBuilder::build<Op_TABLE>(TosaSerializationOperator *op) const {
// create a const op for table value attribute
const auto table_values = attr->table();
- auto const_type = mlir::RankedTensorType::get({table_values.size()},
- op_builder->getI16Type());
- mlir::DenseElementsAttr const_attr =
- BuildDenseI16ElementsAttr(op_builder, table_values);
+ mlir::RankedTensorType const_type;
+ mlir::DenseElementsAttr const_attr;
+ const auto input_element_type =
+ input_val.getType().cast<mlir::ShapedType>().getElementType();
+ if (input_element_type.isInteger(8)) {
+ // table is signed 8 mode
+ const_type = mlir::RankedTensorType::get({table_values.size()},
+ op_builder->getI8Type());
+ const_attr = BuildDenseI8ElementsAttr(op_builder, table_values);
+ } else {
+ // table is signed 16 mode
+ const_type = mlir::RankedTensorType::get({table_values.size()},
+ op_builder->getI16Type());
+ const_attr = BuildDenseI16ElementsAttr(op_builder, table_values);
+ }
mlir::Operation *mlir_const_op =
op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr);
auto table_value = mlir_const_op->getResult(0);