aboutsummaryrefslogtreecommitdiff
path: root/src/TosaDeserialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r--src/TosaDeserialize.cpp22
1 files changed, 22 insertions, 0 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 70a19a7..9e646c7 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -87,6 +87,9 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder,
case DType_UINT16:
element_type = op_builder->getIntegerType(16, false);
break;
+ case DType_SHAPE:
+ element_type = op_builder->getIntegerType(64);
+ break;
default:
llvm::errs() << "ERROR: unknown type " << EnumNamesDType()[ts->GetDtype()]
<< "\n";
@@ -860,6 +863,25 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
}
template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_DIM>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_AxisAttribute); // double check attribute type
+ TosaAxisAttribute *attr =
+ static_cast<TosaAxisAttribute *>(op->GetAttribute());
+ auto axis = op_builder->getI32IntegerAttr(attr->axis());
+
+ mlir::Operation *mlir_op =
+ op_builder->create<mlir::tosa::DimOp>(loc, output_type, input_val, axis);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE>(
TosaSerializationOperator *op) const {
mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);