diff options
author | TatWai Chong <tatwai.chong@arm.com> | 2024-02-08 13:54:21 -0800 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-02-22 02:08:07 +0000 |
commit | 4bc57740f03704179d2611f5d41572612bc42e9a (patch) | |
tree | 39055cafd2e2fa2fb2aa1939da8741c1355d91a0 /src/TosaSerialize.cpp | |
parent | 86db8bc37237c68a30a917ff77cbcd7784879ae4 (diff) | |
download | tosa_mlir_translator-4bc57740f03704179d2611f5d41572612bc42e9a.tar.gz |
Change the shift of mul to the tensor type
Note that the shift value only apply to i32 data tensor.
Change-Id: I4368146d9e0cdc2243ff822f59489ef5b78148ec
Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r-- | src/TosaSerialize.cpp | 27 |
1 files changed, 18 insertions, 9 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index fc6655b..05c7812 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -1272,18 +1272,27 @@ template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::MulOp>( mlir::Operation &op) const { - std::string input0_name = GetTensorName(op.getOperand(0)); - std::string input1_name = GetTensorName(op.getOperand(1)); - std::string output_name = GetTensorName(op.getResult(0)); - int32_t shift = op.getAttr("shift").dyn_cast<mlir::IntegerAttr>().getInt(); + mlir::tosa::MulOp mul_op = mlir::cast<mlir::tosa::MulOp>(op); + std::string input0_name = GetTensorName(mul_op.getInput1()); + std::string input1_name = GetTensorName(mul_op.getInput2()); + std::string output_name = GetTensorName(mul_op.getOutput()); - TosaMulAttribute attribute(shift); + std::vector<std::string> operands; + if (mul_op.getOutput() + .getType() + .cast<mlir::TensorType>() + .getElementType() + .isInteger(32)) { + std::string shift_name = GetTensorName(mul_op.getShift()); + operands = {input0_name, input1_name, shift_name}; + } else { + operands = {input0_name, input1_name}; + } - TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_MUL, Attribute_MulAttribute, &attribute, - std::vector<std::string>{input0_name, input1_name}, - std::vector<std::string>{output_name}); + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_MUL, Attribute_NONE, nullptr, operands, + std::vector<std::string>{output_name}); return tyop; } |