aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp27
1 files changed, 18 insertions, 9 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index fc6655b..05c7812 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -1272,18 +1272,27 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::MulOp>(
mlir::Operation &op) const {
- std::string input0_name = GetTensorName(op.getOperand(0));
- std::string input1_name = GetTensorName(op.getOperand(1));
- std::string output_name = GetTensorName(op.getResult(0));
- int32_t shift = op.getAttr("shift").dyn_cast<mlir::IntegerAttr>().getInt();
+ mlir::tosa::MulOp mul_op = mlir::cast<mlir::tosa::MulOp>(op);
+ std::string input0_name = GetTensorName(mul_op.getInput1());
+ std::string input1_name = GetTensorName(mul_op.getInput2());
+ std::string output_name = GetTensorName(mul_op.getOutput());
- TosaMulAttribute attribute(shift);
+ std::vector<std::string> operands;
+ if (mul_op.getOutput()
+ .getType()
+ .cast<mlir::TensorType>()
+ .getElementType()
+ .isInteger(32)) {
+ std::string shift_name = GetTensorName(mul_op.getShift());
+ operands = {input0_name, input1_name, shift_name};
+ } else {
+ operands = {input0_name, input1_name};
+ }
- TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_MUL, Attribute_MulAttribute, &attribute,
- std::vector<std::string>{input0_name, input1_name},
- std::vector<std::string>{output_name});
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_MUL, Attribute_NONE, nullptr, operands,
+ std::vector<std::string>{output_name});
return tyop;
}