diff options
-rw-r--r-- | src/TosaDeserialize.cpp | 14 | ||||
-rw-r--r-- | src/TosaSerialize.cpp | 27 |
2 files changed, 28 insertions, 13 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 87c363f..5704b04 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -1234,12 +1234,18 @@ TosaMlirOperatorBuilder::build<Op_MUL>(TosaSerializationOperator *op) const { assert(op->GetAttributeType() == Attribute_MulAttribute); // double check attribute type - TosaMulAttribute *attr = static_cast<TosaMulAttribute *>(op->GetAttribute()); - auto shift = op_builder->getI8IntegerAttr(attr->shift()); + mlir::ValueRange operands; + if (output_type.getElementType().isInteger(32)) { + // Integer multiply carries shift argument. + mlir::Value shift_val = tensor_map->at(op->GetInputTensorNames()[2]); + operands = {input0_val, input1_val, shift_val}; + } else { + operands = {input0_val, input1_val}; + } - mlir::Operation *mlir_op = op_builder->create<mlir::tosa::MulOp>( - loc, output_type, input0_val, input1_val, shift); + mlir::Operation *mlir_op = + op_builder->create<mlir::tosa::MulOp>(loc, output_type, operands); block->push_back(mlir_op); return std::vector<mlir::Value>({mlir_op->getResult(0)}); } diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index fc6655b..05c7812 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -1272,18 +1272,27 @@ template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::MulOp>( mlir::Operation &op) const { - std::string input0_name = GetTensorName(op.getOperand(0)); - std::string input1_name = GetTensorName(op.getOperand(1)); - std::string output_name = GetTensorName(op.getResult(0)); - int32_t shift = op.getAttr("shift").dyn_cast<mlir::IntegerAttr>().getInt(); + mlir::tosa::MulOp mul_op = mlir::cast<mlir::tosa::MulOp>(op); + std::string input0_name = GetTensorName(mul_op.getInput1()); + std::string input1_name = GetTensorName(mul_op.getInput2()); + std::string output_name = GetTensorName(mul_op.getOutput()); - TosaMulAttribute attribute(shift); + std::vector<std::string> operands; + if (mul_op.getOutput() + .getType() + .cast<mlir::TensorType>() + .getElementType() + .isInteger(32)) { + std::string shift_name = GetTensorName(mul_op.getShift()); + operands = {input0_name, input1_name, shift_name}; + } else { + operands = {input0_name, input1_name}; + } - TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_MUL, Attribute_MulAttribute, &attribute, - std::vector<std::string>{input0_name, input1_name}, - std::vector<std::string>{output_name}); + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_MUL, Attribute_NONE, nullptr, operands, + std::vector<std::string>{output_name}); return tyop; } |