From edc04e0ef1a93d1bf4578308d620dd49c10e1ad4 Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Mon, 19 Sep 2022 21:07:38 -0700 Subject: [tosa_mlir_translator] Support Tosa StatefulOps Signed-off-by: Jerry Ge Change-Id: I6ce5a917cada436f6a80e6d85e670d6cd44e01e9 --- src/TosaSerialize.cpp | 357 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 328 insertions(+), 29 deletions(-) (limited to 'src/TosaSerialize.cpp') diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index f74df1d..941a75e 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -21,9 +21,13 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tosa_serialization_handler.h" +#include #include #include #include @@ -77,7 +81,7 @@ static DType Type2DType(mlir::Type element_type) { } else if (element_type.isUnsignedInteger(8)) { return DType_UINT8; } else if (element_type.isInteger(4)) { - return DType_INT8; + return DType_INT4; } else if (element_type.isInteger(8)) { return DType_INT8; } else if (element_type.isInteger(16)) { @@ -97,16 +101,27 @@ static DType Type2DType(mlir::Type element_type) { return DType_UNKNOWN; } -static DType Type2AccumDType(mlir::Type element_type) { - if (element_type.isF64() || element_type.isF32() || element_type.isF16() || - element_type.isBF16()) { - return DType_FP32; - } else if (element_type.isInteger(8)) { - return DType_INT32; - } else if (element_type.isInteger(16)) { - return DType_INT48; +// Returns number of bits TOSA flatbuffer store in tensor raw data array +uint64_t GetDTypeSize(DType dtype) { + switch (dtype) { + case DType_INT4: + return 4; + case DType_BOOL: + case DType_UINT8: + case DType_INT8: + return 8; + case DType_INT16: + return 16; + case DType_FP32: + case DType_INT32: + return 32; + case DType_INT48: + return 48; + default: + llvm::errs() << "WARNING: unsupported dtype " << EnumNamesDType()[dtype] + << "\n"; + return 1; } - return DType_UNKNOWN; } static DType Type2PoolAccumDType(mlir::Type element_type) { @@ -123,6 +138,26 @@ static DType Type2PoolAccumDType(mlir::Type element_type) { class TosaSerializationBlockBuilder; class TosaSerializationRegionBuilder; +std::unordered_map variable_tensor_op_map; +std::unordered_map + variable_tensor_flatbuffer_name_map; +static int variable_tensor_index = 0; + +namespace { + +// for now, this is a global map of variables +void RegisterVariableOp(mlir::Operation &op) { + std::string variable_tensor_flatbuffer_name = + "Variable_" + std::to_string(variable_tensor_index++); + std::string variable_tensor_mlir_name = + op.getAttr("name").cast().getValue().str(); + variable_tensor_op_map[variable_tensor_flatbuffer_name] = &op; + variable_tensor_flatbuffer_name_map[variable_tensor_mlir_name] = + variable_tensor_flatbuffer_name; +} + +} // namespace + class TosaSerializationOperatorBuilder { public: TosaSerializationOperatorBuilder( @@ -133,9 +168,12 @@ public: TosaSerializationOperator *build(mlir::Operation &op) const; TosaSerializationHandler *GetTsh() const; TosaSerializationRegionBuilder *GetRegionBuilder() const; + mlir::LogicalResult GetDataFromAttribute(mlir::Attribute &attr, DType dtype, + std::vector &u8_data) const; private: std::string GetTensorName(mlir::Value val) const; + std::string GetVariableTensorName(mlir::Operation *op) const; TosaSerializationOperator *BuildPoolOpFromMlirOp(mlir::Operation &op, Op opcode) const; TosaSerializationOperator *BuildEwiseBinaryOpFromMlirOp(mlir::Operation &op, @@ -171,6 +209,9 @@ private: TosaSerializationOperator *BuildTosaSerializationOperator( const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op); TosaSerializationTensor * + BuildTosaSerializationVariableTensor(mlir::RankedTensorType tensor_type, + const std::string &name); + TosaSerializationTensor * BuildTosaSerializationTensor(mlir::Value val, const std::string &name); TosaSerializationBasicBlock *ser_block; @@ -262,6 +303,124 @@ static std::vector getDenseI8ArrayAttr(mlir::Attribute attr) { return vec; } +std::string TosaSerializationOperatorBuilder::GetVariableTensorName( + mlir::Operation *op) const { + mlir::Attribute variable_op_name_attr = op->getAttr("name"); + std::string variable_tensor_mlir_name = + op->getAttr("name").cast().getValue().str(); + + if (variable_tensor_flatbuffer_name_map.find(variable_tensor_mlir_name) == + variable_tensor_flatbuffer_name_map.end()) { + llvm::errs() << "ERROR: Failed to find key " << variable_tensor_mlir_name + << " from variable_tensor_flatbuffer_name_map\n"; + assert(0); + } + return variable_tensor_flatbuffer_name_map[variable_tensor_mlir_name]; +} + +mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute( + mlir::Attribute &attr, DType type, std::vector &u8_data) const { + auto dense_attr = attr.dyn_cast(); + if (type == DType_FP32) { + std::vector data; + auto val_attr = attr.dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back((float)val_attr.getValueAsDouble()); + } else { + llvm::errs() << "Unknown const attribute\n"; + return mlir::failure(); + } + TosaSerializationHandler::ConvertF32toU8(data, u8_data); + } else if (type == DType_INT8) { + std::vector data; + auto val_attr = attr.dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back(val_attr.getInt()); + } else { + llvm::errs() << "Unknown const attribute\n"; + return mlir::failure(); + } + TosaSerializationHandler::ConvertI8toU8(data, u8_data); + } else if (type == DType_INT16) { + std::vector data; + auto val_attr = attr.dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back(val_attr.getInt()); + } else { + llvm::errs() << "Unknown const attribute\n"; + return mlir::failure(); + } + TosaSerializationHandler::ConvertI16toU8(data, u8_data); + } else if (type == DType_INT32) { + std::vector data; + auto val_attr = attr.dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back(val_attr.getInt()); + } else { + llvm::errs() << "Unknown const attribute\n"; + return mlir::failure(); + } + TosaSerializationHandler::ConvertI32toU8(data, u8_data); + } else if (type == DType_INT48) { + std::vector data; + auto val_attr = attr.dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back(val_attr.getInt()); + } else { + llvm::errs() << "Unknown const attribute\n"; + return mlir::failure(); + } + TosaSerializationHandler::ConvertI48toU8(data, u8_data); + } else if (type == DType_BOOL) { + std::vector data; + + auto val_attr = attr.dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val); + } + } else if (val_attr) { + data.push_back(val_attr.getValue()); + } else { + llvm::errs() << "Unknown const attribute\n"; + return mlir::failure(); + } + + TosaSerializationHandler::ConvertBooltoU8(data, u8_data); + } else { + llvm::errs() << "Unknown element type of const attribute\n"; + return mlir::failure(); + } + + return mlir::success(); +} + // Main template to catch unimplemented translation. template TosaSerializationOperator * @@ -473,7 +632,13 @@ TosaSerializationOperatorBuilder::build( #endif // Update tensor.data array with Const value attribute + mlir::Attribute value_attr = op.getAttr("value"); + if (!value_attr) { + op.emitOpError("ERROR: tosa.const doesn't have value"); + return nullptr; + } std::vector u8_data; + DType type = ts->GetDtype(); if (type == DType_FP32) { std::vector data; @@ -590,8 +755,8 @@ TosaSerializationOperatorBuilder::build( op.emitOpError("Unknown element type of const attribute"); return nullptr; } - ts->SetData(u8_data); + ts->SetData(u8_data); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONST, Attribute_NONE, nullptr, std::vector{}, std::vector{output_name}); @@ -1455,6 +1620,38 @@ TosaSerializationOperatorBuilder::build( return tyop; } +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + + std::string input_name = GetVariableTensorName(&op); + std::string output_name = GetTensorName(op.getResult(0)); + + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_IDENTITY, Attribute_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetVariableTensorName(&op); + + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_IDENTITY, Attribute_NONE, nullptr, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + /* End translating TOSA operator */ mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion( bool is_top, std::vector &return_values) { @@ -1509,9 +1706,11 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( // Build tensor_map for (auto &op : block->getOperations()) { - if (!(llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op))) { + if (llvm::isa(op)) { + RegisterVariableOp(op); + } else if (!(llvm::isa(op) || + llvm::isa(op) || + llvm::isa(op))) { for (uint32_t i = 0; i < op.getNumResults(); i++) { std::string intermediate_tensor_name = "layer_" + std::to_string(intermediate_tensor_index++); @@ -1553,6 +1752,53 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( } } + // Build variable tensor + for (auto pair : variable_tensor_op_map) { + mlir::Operation *op = pair.second; + mlir::Value val = op->getResult(0); + mlir::RankedTensorType tensor_type = op->getAttr("type") + .cast() + .getValue() + .cast(); + ser_tensor = BuildTosaSerializationVariableTensor( + tensor_type /* tensor_type */, pair.first /* flatbuffer name */); + if (!ser_tensor) { + llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n"; + return mlir::failure(); + } + // Initialize if "initial_value" attribute exists. If not, set data to all + // zeros + mlir::Attribute initial_value = op->getAttr("initial_value"); + std::vector u8_data; + DType element_type = Type2DType(tensor_type.getElementType()); + if (initial_value) { + if (initial_value.isa()) { + if (op_builder + .GetDataFromAttribute(initial_value, element_type, u8_data) + .failed()) { + llvm::errs() << "ERROR: GetDataFromAttribute() fails when building " + "initial_value of variable tensor\n"; + return mlir::failure(); + } + } else { + llvm::errs() << "ERROR: Unknown initial_value attribute type\n"; + return mlir::failure(); + } + } else { + uint64_t num_elements = 1; + for (int64_t dim : tensor_type.getShape()) { + num_elements *= dim; + } + uint64_t num_bits = num_elements * GetDTypeSize(element_type); + uint64_t num_bytes = + (num_bits % 8 == 0) ? (num_bits / 8) : (num_bits / 8) + 1; + // std::fill_n(u8_data.begin(), num_bytes, 0); + TosaSerializationHandler::ForceAlignTensorData(u8_data); + } + ser_tensor->SetData(u8_data); + ser_block->GetTensors().push_back(ser_tensor); + } + // Build tensor // The tensor_map is sorted by hashed mlir::Value types. @@ -1563,6 +1809,8 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( tensor_name_sort[pair.second] = pair.first; for (auto pair : tensor_name_sort) { + mlir::RankedTensorType tensor_type = + pair.second.getType().cast(); ser_tensor = BuildTosaSerializationTensor(pair.second /* val */, pair.first /* name */); if (!ser_tensor) { @@ -1576,7 +1824,8 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( for (auto &op : block->getOperations()) { if (llvm::isa(op) || llvm::isa(op) || - llvm::isa(op)) + llvm::isa(op) || + llvm::isa(op)) continue; ser_operator = BuildTosaSerializationOperator(op_builder, op); if (!ser_operator) { @@ -1585,7 +1834,6 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( } ser_block->GetOperators().push_back(ser_operator); } - return mlir::success(); } @@ -1594,7 +1842,10 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator( const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op) { TosaSerializationOperator *target_operator = nullptr; - if (false) { + if (llvm::isa(op)) { + target_operator = op_builder.build(op); + } else if (llvm::isa(op)) { + target_operator = op_builder.build(op); } #define DEF_OPERATOR(MLIR_OP) \ else if (llvm::isa(op)) { \ @@ -1613,18 +1864,22 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator( return nullptr; } + if (llvm::isa(op) || + llvm::isa(op)) { + return target_operator; + } + // Sanity check the number of inputs/outputs of TOSA dialect matches the // number of TOSA flatbuffer if (op.getNumOperands() != target_operator->GetInputTensorNames().size()) { - llvm::errs() << op << "\n"; - llvm::errs() << "WARNING. MLIR operator has " << op.getNumOperands() + llvm::errs() << "WARNING: MLIR operator has " << op.getNumOperands() << " input tensors != Flatbuffer " "operator has " << target_operator->GetInputTensorNames().size() << " input tensors\n"; } if (op.getNumResults() != target_operator->GetOutputTensorNames().size()) { - llvm::errs() << "WARNING. MLIR operator has " << op.getNumResults() + llvm::errs() << "WARNING: MLIR operator has " << op.getNumResults() << " output tensors != Flatbuffer " "operator has " << target_operator->GetOutputTensorNames().size() @@ -1634,6 +1889,28 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator( return target_operator; } +TosaSerializationTensor * +TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor( + mlir::RankedTensorType tensor_type, const std::string &name) { + // If tensor already created before, use that tensor directly, create a new + // one otherwise + TosaSerializationTensor *ts = ser_block->GetTensorByName(name); + if (ts) { + return nullptr; + } + + std::vector shape(tensor_type.getShape().begin(), + tensor_type.getShape().end()); + + DType type = Type2DType(tensor_type.getElementType()); + + ts = new TosaSerializationTensor(name, shape, type, std::vector(), + /* is_variable = */ true, + /* is_unranked = */ false); + + return ts; +} + TosaSerializationTensor * TosaSerializationBlockBuilder::BuildTosaSerializationTensor( mlir::Value val, const std::string &name) { @@ -1756,20 +2033,42 @@ mlir::LogicalResult dumpTosaJSON(mlir::func::FuncOp &func) { return mlir::success(); } -namespace mlir { +#define GEN_PASS_DEF_TOSASERIALIZATIONPASS +namespace mlir { namespace tosa { - namespace { class TosaSerialize : public TosaSerializationPassBase { public: void runOnOperation() final { - auto function = getOperation(); + auto moduleOp = getOperation(); - if (dumpTosaFlatbuffer(function).failed()) { - llvm::errs() << "Failed to generate TOSA flatbuffer...\n"; - return signalPassFailure(); + // iterate through each op in the moduleOp, call dumpTosaFlatbuffer if + // that's a func.funcOp + + auto regions = moduleOp->getRegions(); + auto region_size = regions.size(); + + auto region_0 = regions.begin(); + auto block_size = region_0->getBlocks().size(); + + auto block_0 = region_0->getBlocks().begin(); + + auto op_size = block_0->getOperations().size(); + + for (auto it = block_0->getOperations().begin(); + it != block_0->getOperations().end(); ++it) { + // read variableOps that are declared outside of functionOps + if (llvm::isa(*it)) { + RegisterVariableOp(*it); + } else if (llvm::isa(*it)) { + auto funcOp = dyn_cast((*it)); + if (dumpTosaFlatbuffer(funcOp).failed()) { + llvm::errs() << "Failed to generate TOSA flatbuffer...\n"; + return signalPassFailure(); + } + } } } }; @@ -1790,7 +2089,7 @@ public: } // anonymous namespace // Creates an instance of the TOSA flatbuffer generation pass -std::unique_ptr createTosaSerializePass() { +std::unique_ptr> createTosaSerializePass() { return std::make_unique(); } -- cgit v1.2.1