aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp28
1 files changed, 21 insertions, 7 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 55a11fd..875303e 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -105,14 +105,16 @@ static DType Type2DType(mlir::Type element_type) {
return DType_UNKNOWN;
}
-static DType Type2PoolAccDType(mlir::Type element_type) {
- // def Tosa_AccType : AnyTypeOf<[I<32>, F16, F32]>;
+static DType Type2AccDType(mlir::Type element_type) {
+ // def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>;
if (element_type.isF32()) {
return DType_FP32;
} else if (element_type.isF16()) {
return DType_FP16;
} else if (element_type.isInteger(32)) {
return DType_INT32;
+ } else if (element_type.isInteger(48)) {
+ return DType_INT48;
}
return DType_UNKNOWN;
}
@@ -465,7 +467,7 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op,
// AvgPool has acc_type, MaxPool does not
if (op.hasAttr("acc_type")) {
auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
- acc_dtype = Type2PoolAccDType(acc_type);
+ acc_dtype = Type2AccDType(acc_type);
}
std::string input_name = GetTensorName(op.getOperand(0));
@@ -735,8 +737,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>(
? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
: false;
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ auto acc_dtype = Type2AccDType(acc_type);
+
TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp,
- local_bound);
+ local_bound, acc_dtype);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_CONV2D, Attribute_ConvAttribute, &attribute,
@@ -778,8 +783,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv3DOp>(
? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
: false;
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ auto acc_dtype = Type2AccDType(acc_type);
+
TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp,
- local_bound);
+ local_bound, acc_dtype);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_CONV3D, Attribute_ConvAttribute, &attribute,
@@ -821,8 +829,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>(
? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
: false;
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ auto acc_dtype = Type2AccDType(acc_type);
+
TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp,
- local_bound);
+ local_bound, acc_dtype);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute,
@@ -867,8 +878,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeConv2DOp>(
? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue()
: false;
+ auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue();
+ auto acc_dtype = Type2AccDType(acc_type);
+
TosaTransposeConvAttribute attribute(out_pad, stride, out_shape, input_zp,
- weight_zp, local_bound);
+ weight_zp, local_bound, acc_dtype);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute,