aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-07-07 19:37:52 +0000
committerTai Ly <tai.ly@arm.com>2023-07-10 18:05:15 +0000
commit289fd33b43f5260dba2faab9c372e42b4b71f83d (patch)
tree47864a5332f79cec04fe0ca53914f33b959ee2b9
parent9121c479ef2b11b52d7e77dc1b4dccd8f55b0db0 (diff)
downloadtosa_mlir_translator-289fd33b43f5260dba2faab9c372e42b4b71f83d.tar.gz
Fix deserialization of Table const type
Fixed dDeserialization of Table const to have either I8 or I16 element types based on input element type eventhough table const is always serialized as int16_t values. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I5d9972ba6a97a709003d73741c91bab31be06cf7
-rw-r--r--src/TosaDeserialize.cpp31
1 files changed, 27 insertions, 4 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index d55585e..db52d67 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -99,6 +99,18 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder,
}
template <class T>
+mlir::DenseElementsAttr BuildDenseI8ElementsAttr(mlir::OpBuilder *op_builder,
+ const std::vector<T> &values) {
+ std::vector<int8_t> vec;
+ for (auto val : values) {
+ vec.push_back(val);
+ }
+ auto type =
+ mlir::RankedTensorType::get({vec.size()}, op_builder->getI8Type());
+ return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec));
+}
+
+template <class T>
mlir::DenseElementsAttr
BuildDenseI16ElementsAttr(mlir::OpBuilder *op_builder,
const std::vector<T> &values) {
@@ -997,10 +1009,21 @@ TosaMlirOperatorBuilder::build<Op_TABLE>(TosaSerializationOperator *op) const {
// create a const op for table value attribute
const auto table_values = attr->table();
- auto const_type = mlir::RankedTensorType::get({table_values.size()},
- op_builder->getI16Type());
- mlir::DenseElementsAttr const_attr =
- BuildDenseI16ElementsAttr(op_builder, table_values);
+ mlir::RankedTensorType const_type;
+ mlir::DenseElementsAttr const_attr;
+ const auto input_element_type =
+ input_val.getType().cast<mlir::ShapedType>().getElementType();
+ if (input_element_type.isInteger(8)) {
+ // table is signed 8 mode
+ const_type = mlir::RankedTensorType::get({table_values.size()},
+ op_builder->getI8Type());
+ const_attr = BuildDenseI8ElementsAttr(op_builder, table_values);
+ } else {
+ // table is signed 16 mode
+ const_type = mlir::RankedTensorType::get({table_values.size()},
+ op_builder->getI16Type());
+ const_attr = BuildDenseI16ElementsAttr(op_builder, table_values);
+ }
mlir::Operation *mlir_const_op =
op_builder->create<mlir::tosa::ConstOp>(loc, const_type, const_attr);
auto table_value = mlir_const_op->getResult(0);