From 4bc57740f03704179d2611f5d41572612bc42e9a Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Thu, 8 Feb 2024 13:54:21 -0800 Subject: Change the shift of mul to the tensor type Note that the shift value only apply to i32 data tensor. Change-Id: I4368146d9e0cdc2243ff822f59489ef5b78148ec Signed-off-by: TatWai Chong --- src/TosaSerialize.cpp | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) (limited to 'src/TosaSerialize.cpp') diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index fc6655b..05c7812 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -1272,18 +1272,27 @@ template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { - std::string input0_name = GetTensorName(op.getOperand(0)); - std::string input1_name = GetTensorName(op.getOperand(1)); - std::string output_name = GetTensorName(op.getResult(0)); - int32_t shift = op.getAttr("shift").dyn_cast().getInt(); + mlir::tosa::MulOp mul_op = mlir::cast(op); + std::string input0_name = GetTensorName(mul_op.getInput1()); + std::string input1_name = GetTensorName(mul_op.getInput2()); + std::string output_name = GetTensorName(mul_op.getOutput()); - TosaMulAttribute attribute(shift); + std::vector operands; + if (mul_op.getOutput() + .getType() + .cast() + .getElementType() + .isInteger(32)) { + std::string shift_name = GetTensorName(mul_op.getShift()); + operands = {input0_name, input1_name, shift_name}; + } else { + operands = {input0_name, input1_name}; + } - TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_MUL, Attribute_MulAttribute, &attribute, - std::vector{input0_name, input1_name}, - std::vector{output_name}); + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_MUL, Attribute_NONE, nullptr, operands, + std::vector{output_name}); return tyop; } -- cgit v1.2.1