aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2024-01-24 11:33:22 -0800
committerTatWai Chong <tatwai.chong@arm.com>2024-01-31 09:49:37 -0800
commitf5f00a9dd8bf6c4ad6256c2332e01662972b3fab (patch)
treefe710b7dbbe5eca621909d6052835196f95a58dd
parent06cb91ba15e860adf72409341143f30613b336c1 (diff)
downloadtosa_mlir_translator-f5f00a9dd8bf6c4ad6256c2332e01662972b3fab.tar.gz
Change the start and size of slice to tosa shape type
Change-Id: Ifcb33a238ae3acdab9c33e039fb05f45aeb4df1c Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
-rw-r--r--src/TosaDeserialize.cpp8
-rw-r--r--src/TosaSerialize.cpp11
2 files changed, 6 insertions, 13 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 5956bc8..b80e2cb 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -1139,13 +1139,9 @@ TosaMlirOperatorBuilder::build<Op_SLICE>(TosaSerializationOperator *op) const {
assert(op->GetAttributeType() ==
Attribute_SliceAttribute); // double check attribute type
- TosaSliceAttribute *attr =
- static_cast<TosaSliceAttribute *>(op->GetAttribute());
- mlir::DenseI64ArrayAttr start =
- BuildDenseI64ArrayAttr(op_builder, attr->start());
- mlir::DenseI64ArrayAttr size =
- BuildDenseI64ArrayAttr(op_builder, attr->size());
+ mlir::Value start = tensor_map->at(op->GetInputTensorNames()[1]);
+ mlir::Value size = tensor_map->at(op->GetInputTensorNames()[2]);
mlir::Operation *mlir_op = op_builder->create<mlir::tosa::SliceOp>(
loc, output_type, input_val, start, size);
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index de301fe..04709b7 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -1291,17 +1291,14 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::SliceOp>(
mlir::Operation &op) const {
- auto start = getDenseI64ArrayAttr<int>(op.getAttr("start"));
- auto size = getDenseI64ArrayAttr<int>(op.getAttr("size"));
-
- TosaSliceAttribute attribute(start, size);
-
std::string input_name = GetTensorName(op.getOperand(0));
+ std::string start_name = GetTensorName(op.getOperand(1));
+ std::string size_name = GetTensorName(op.getOperand(2));
std::string output_name = GetTensorName(op.getResult(0));
TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_SLICE, Attribute_SliceAttribute, &attribute,
- std::vector<std::string>{input_name},
+ Op_SLICE, Attribute_NONE, nullptr,
+ std::vector<std::string>{input_name, start_name, size_name},
std::vector<std::string>{output_name});
return tyop;