aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2024-02-08 13:54:21 -0800
committerEric Kunze <eric.kunze@arm.com>2024-02-22 02:08:07 +0000
commit4bc57740f03704179d2611f5d41572612bc42e9a (patch)
tree39055cafd2e2fa2fb2aa1939da8741c1355d91a0 /src/TosaSerialize.cpp
parent86db8bc37237c68a30a917ff77cbcd7784879ae4 (diff)
downloadtosa_mlir_translator-4bc57740f03704179d2611f5d41572612bc42e9a.tar.gz
Change the shift of mul to the tensor type
Note that the shift value only apply to i32 data tensor. Change-Id: I4368146d9e0cdc2243ff822f59489ef5b78148ec Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp27
1 files changed, 18 insertions, 9 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index fc6655b..05c7812 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -1272,18 +1272,27 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::MulOp>(
mlir::Operation &op) const {
- std::string input0_name = GetTensorName(op.getOperand(0));
- std::string input1_name = GetTensorName(op.getOperand(1));
- std::string output_name = GetTensorName(op.getResult(0));
- int32_t shift = op.getAttr("shift").dyn_cast<mlir::IntegerAttr>().getInt();
+ mlir::tosa::MulOp mul_op = mlir::cast<mlir::tosa::MulOp>(op);
+ std::string input0_name = GetTensorName(mul_op.getInput1());
+ std::string input1_name = GetTensorName(mul_op.getInput2());
+ std::string output_name = GetTensorName(mul_op.getOutput());
- TosaMulAttribute attribute(shift);
+ std::vector<std::string> operands;
+ if (mul_op.getOutput()
+ .getType()
+ .cast<mlir::TensorType>()
+ .getElementType()
+ .isInteger(32)) {
+ std::string shift_name = GetTensorName(mul_op.getShift());
+ operands = {input0_name, input1_name, shift_name};
+ } else {
+ operands = {input0_name, input1_name};
+ }
- TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_MUL, Attribute_MulAttribute, &attribute,
- std::vector<std::string>{input0_name, input1_name},
- std::vector<std::string>{output_name});
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_MUL, Attribute_NONE, nullptr, operands,
+ std::vector<std::string>{output_name});
return tyop;
}