aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
author: Jerry Ge <jerry.ge@arm.com> 2022-09-19 21:07:38 -0700
committer: Tai Ly <tai.ly@arm.com> 2023-09-07 13:17:16 -0700
commit: edc04e0ef1a93d1bf4578308d620dd49c10e1ad4 (patch)
tree: 1044165c16528a17337f81aa15361c7b77aafe7f /src/TosaSerialize.cpp
parent: 7566d1235cb646e46531c2eb34757cb4b3efa933 (diff)
download: tosa_mlir_translator-edc04e0ef1a93d1bf4578308d620dd49c10e1ad4.tar.gz
[tosa_mlir_translator] Support Tosa StatefulOpsv0.90a0
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I6ce5a917cada436f6a80e6d85e670d6cd44e01e9
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r-- src/TosaSerialize.cpp 357
1 files changed, 328 insertions, 29 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index f74df1d..941a75e 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -21,9 +21,13 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
-#include "mlir/Pass/Pass.h" // from @llvm-project
-#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "mlir/Support/TypeID.h"
+#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tosa_serialization_handler.h"
+#include <algorithm>
#include <functional>
#include <map>
#include <unordered_map>
@@ -77,7 +81,7 @@ static DType Type2DType(mlir::Type element_type) {
} else if (element_type.isUnsignedInteger(8)) {
return DType_UINT8;
} else if (element_type.isInteger(4)) {
- return DType_INT8;
+ return DType_INT4;
} else if (element_type.isInteger(8)) {
return DType_INT8;
} else if (element_type.isInteger(16)) {
@@ -97,16 +101,27 @@ static DType Type2DType(mlir::Type element_type) {
return DType_UNKNOWN;
}
-static DType Type2AccumDType(mlir::Type element_type) {
- if (element_type.isF64() || element_type.isF32() || element_type.isF16() ||
- element_type.isBF16()) {
- return DType_FP32;
- } else if (element_type.isInteger(8)) {
- return DType_INT32;
- } else if (element_type.isInteger(16)) {
- return DType_INT48;
+// Returns the number of bits the TOSA flatbuffer stores for one element of
+// `dtype` in a tensor's raw data array.
+// Any dtype not listed in the switch logs a warning to llvm::errs() and
+// falls back to a size of 1 bit.
+uint64_t GetDTypeSize(DType dtype) {
+ switch (dtype) {
+ case DType_INT4:
+ return 4;
+ case DType_BOOL:
+ case DType_UINT8:
+ case DType_INT8:
+ return 8;
+ case DType_INT16:
+ return 16;
+ case DType_FP32:
+ case DType_INT32:
+ return 32;
+ case DType_INT48:
+ return 48;
+ default:
+ // NOTE(review): the name table is indexed directly by `dtype` with no
+ // bounds check — presumably fine for every valid DType; confirm for
+ // out-of-range values.
+ llvm::errs() << "WARNING: unsupported dtype " << EnumNamesDType()[dtype]
+ << "\n";
+ return 1;
+ }
- return DType_UNKNOWN;
}
static DType Type2PoolAccumDType(mlir::Type element_type) {
@@ -123,6 +138,26 @@ static DType Type2PoolAccumDType(mlir::Type element_type) {
class TosaSerializationBlockBuilder;
class TosaSerializationRegionBuilder;
+// Maps a generated flatbuffer tensor name ("Variable_<N>") to the MLIR
+// operation that was registered under it (filled by RegisterVariableOp).
+std::unordered_map<std::string, mlir::Operation *> variable_tensor_op_map;
+// Maps the string value of a registered op's "name" attribute to the
+// flatbuffer tensor name generated for it.
+std::unordered_map<std::string, std::string>
+ variable_tensor_flatbuffer_name_map;
+// Running counter used to generate the "Variable_<N>" names.
+static int variable_tensor_index = 0;
+
+namespace {
+
+// Records a variable op in the file-level lookup tables (a global registry
+// for now). The op is given the next "Variable_<N>" flatbuffer name; that
+// name is stored against the op pointer and against the string value of the
+// op's "name" attribute.
+void RegisterVariableOp(mlir::Operation &op) {
+  const std::string flatbuffer_name =
+      "Variable_" + std::to_string(variable_tensor_index++);
+  const mlir::StringAttr name_attr =
+      op.getAttr("name").cast<mlir::StringAttr>();
+  const std::string mlir_name = name_attr.getValue().str();
+
+  variable_tensor_flatbuffer_name_map[mlir_name] = flatbuffer_name;
+  variable_tensor_op_map[flatbuffer_name] = &op;
+}
+
+} // namespace
+
class TosaSerializationOperatorBuilder {
public:
TosaSerializationOperatorBuilder(
@@ -133,9 +168,12 @@ public:
TosaSerializationOperator *build(mlir::Operation &op) const;
TosaSerializationHandler *GetTsh() const;
TosaSerializationRegionBuilder *GetRegionBuilder() const;
+ mlir::LogicalResult GetDataFromAttribute(mlir::Attribute &attr, DType dtype,
+ std::vector<uint8_t> &u8_data) const;
private:
std::string GetTensorName(mlir::Value val) const;
+ std::string GetVariableTensorName(mlir::Operation *op) const;
TosaSerializationOperator *BuildPoolOpFromMlirOp(mlir::Operation &op,
Op opcode) const;
TosaSerializationOperator *BuildEwiseBinaryOpFromMlirOp(mlir::Operation &op,
@@ -171,6 +209,9 @@ private:
TosaSerializationOperator *BuildTosaSerializationOperator(
const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op);
TosaSerializationTensor *
+ BuildTosaSerializationVariableTensor(mlir::RankedTensorType tensor_type,
+ const std::string &name);
+ TosaSerializationTensor *
BuildTosaSerializationTensor(mlir::Value val, const std::string &name);
TosaSerializationBasicBlock *ser_block;
@@ -262,6 +303,124 @@ static std::vector<T> getDenseI8ArrayAttr(mlir::Attribute attr) {
return vec;
}
+// Returns the flatbuffer tensor name registered for the variable that `op`
+// refers to through its "name" attribute (read as a FlatSymbolRefAttr).
+// When no variable with that name has been registered, an error is logged
+// and the assertion fires; if assertions are compiled out, an empty string
+// is returned.
+std::string TosaSerializationOperatorBuilder::GetVariableTensorName(
+    mlir::Operation *op) const {
+  const std::string variable_tensor_mlir_name =
+      op->getAttr("name").cast<mlir::FlatSymbolRefAttr>().getValue().str();
+
+  // One lookup only. Unlike operator[], find() never inserts a placeholder
+  // entry into the global map when the key is missing.
+  const auto entry =
+      variable_tensor_flatbuffer_name_map.find(variable_tensor_mlir_name);
+  if (entry == variable_tensor_flatbuffer_name_map.end()) {
+    llvm::errs() << "ERROR: Failed to find key " << variable_tensor_mlir_name
+                 << " from variable_tensor_flatbuffer_name_map\n";
+    assert(0);
+    return std::string();
+  }
+  return entry->second;
+}
+
+// Collects the values carried by `attr` into `data`: every element when
+// `attr` is a DenseElementsAttr, otherwise the single value produced by
+// `read_scalar` when `attr` is a ScalarAttrT. Any other attribute kind is
+// reported on llvm::errs() and yields failure.
+template <typename ScalarAttrT, typename ElemT, typename ScalarReader>
+static mlir::LogicalResult GatherAttributeValues(mlir::Attribute attr,
+                                                 std::vector<ElemT> &data,
+                                                 ScalarReader read_scalar) {
+  mlir::DenseElementsAttr dense_attr =
+      attr.dyn_cast<mlir::DenseElementsAttr>();
+  ScalarAttrT val_attr = attr.dyn_cast<ScalarAttrT>();
+
+  if (dense_attr) {
+    for (auto val : dense_attr.getValues<ElemT>()) {
+      data.push_back(val);
+    }
+    return mlir::success();
+  }
+  if (val_attr) {
+    data.push_back(read_scalar(val_attr));
+    return mlir::success();
+  }
+  llvm::errs() << "Unknown const attribute\n";
+  return mlir::failure();
+}
+
+// Converts the constant held by `attr` into the byte layout used for tensor
+// data of element type `type`, writing the bytes through `u8_data`.
+// Handles FP32, INT8, INT16, INT32, INT48 and BOOL; any other element type,
+// or an attribute that is neither dense nor the matching scalar kind, is
+// logged and reported as failure.
+mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
+    mlir::Attribute &attr, DType type, std::vector<uint8_t> &u8_data) const {
+  const auto read_int = [](mlir::IntegerAttr a) { return a.getInt(); };
+
+  switch (type) {
+  case DType_FP32: {
+    std::vector<float> data;
+    const auto read_float = [](mlir::FloatAttr a) {
+      return static_cast<float>(a.getValueAsDouble());
+    };
+    if (GatherAttributeValues<mlir::FloatAttr>(attr, data, read_float)
+            .failed()) {
+      return mlir::failure();
+    }
+    TosaSerializationHandler::ConvertF32toU8(data, u8_data);
+    break;
+  }
+  case DType_INT8: {
+    std::vector<int8_t> data;
+    if (GatherAttributeValues<mlir::IntegerAttr>(attr, data, read_int)
+            .failed()) {
+      return mlir::failure();
+    }
+    TosaSerializationHandler::ConvertI8toU8(data, u8_data);
+    break;
+  }
+  case DType_INT16: {
+    std::vector<int16_t> data;
+    if (GatherAttributeValues<mlir::IntegerAttr>(attr, data, read_int)
+            .failed()) {
+      return mlir::failure();
+    }
+    TosaSerializationHandler::ConvertI16toU8(data, u8_data);
+    break;
+  }
+  case DType_INT32: {
+    std::vector<int32_t> data;
+    if (GatherAttributeValues<mlir::IntegerAttr>(attr, data, read_int)
+            .failed()) {
+      return mlir::failure();
+    }
+    TosaSerializationHandler::ConvertI32toU8(data, u8_data);
+    break;
+  }
+  case DType_INT48: {
+    // 48-bit values are gathered as 64-bit integers before conversion.
+    std::vector<int64_t> data;
+    if (GatherAttributeValues<mlir::IntegerAttr>(attr, data, read_int)
+            .failed()) {
+      return mlir::failure();
+    }
+    TosaSerializationHandler::ConvertI48toU8(data, u8_data);
+    break;
+  }
+  case DType_BOOL: {
+    std::vector<bool> data;
+    const auto read_bool = [](mlir::BoolAttr a) { return a.getValue(); };
+    if (GatherAttributeValues<mlir::BoolAttr>(attr, data, read_bool)
+            .failed()) {
+      return mlir::failure();
+    }
+    TosaSerializationHandler::ConvertBooltoU8(data, u8_data);
+    break;
+  }
+  default:
+    llvm::errs() << "Unknown element type of const attribute\n";
+    return mlir::failure();
+  }
+
+  return mlir::success();
+}
+
+
// Main template to catch unimplemented translation.
template <typename T>
TosaSerializationOperator *
@@ -473,7 +632,13 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
#endif
// Update tensor.data array with Const value attribute
+ mlir::Attribute value_attr = op.getAttr("value");
+ if (!value_attr) {
+ op.emitOpError("ERROR: tosa.const doesn't have value");
+ return nullptr;
+ }
std::vector<uint8_t> u8_data;
+
DType type = ts->GetDtype();
if (type == DType_FP32) {
std::vector<float> data;
@@ -590,8 +755,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
op.emitOpError("Unknown element type of const attribute");
return nullptr;
}
- ts->SetData(u8_data);
+ ts->SetData(u8_data);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_CONST, Attribute_NONE, nullptr, std::vector<std::string>{},
std::vector<std::string>{output_name});
@@ -1455,6 +1620,38 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::FFT2dOp>(
return tyop;
}
+// A VariableReadOp is emitted as an IDENTITY operator: its input is the
+// flatbuffer tensor of the variable being read and its output is the tensor
+// of the op's first result.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::VariableReadOp>(
+    mlir::Operation &op) const {
+  const std::string variable_name = GetVariableTensorName(&op);
+  const std::string result_name = GetTensorName(op.getResult(0));
+
+  return new TosaSerializationOperator(
+      Op_IDENTITY, Attribute_NONE, nullptr,
+      std::vector<std::string>{variable_name},
+      std::vector<std::string>{result_name});
+}
+
+// A VariableWriteOp is emitted as an IDENTITY operator: its input is the
+// tensor of the op's first operand and its output is the flatbuffer tensor
+// of the variable being written.
+template <>
+TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::VariableWriteOp>(
+    mlir::Operation &op) const {
+  const std::string source_name = GetTensorName(op.getOperand(0));
+  const std::string variable_name = GetVariableTensorName(&op);
+
+  return new TosaSerializationOperator(
+      Op_IDENTITY, Attribute_NONE, nullptr,
+      std::vector<std::string>{source_name},
+      std::vector<std::string>{variable_name});
+}
+
/* End translating TOSA operator */
mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion(
bool is_top, std::vector<mlir::Value> &return_values) {
@@ -1509,9 +1706,11 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
// Build tensor_map
for (auto &op : block->getOperations()) {
- if (!(llvm::isa<mlir::tosa::YieldOp>(op) ||
- llvm::isa<mlir::func::ReturnOp>(op) ||
- llvm::isa<mlir::tensor::CastOp>(op))) {
+ if (llvm::isa<mlir::tosa::VariableOp>(op)) {
+ RegisterVariableOp(op);
+ } else if (!(llvm::isa<mlir::tosa::YieldOp>(op) ||
+ llvm::isa<mlir::func::ReturnOp>(op) ||
+ llvm::isa<mlir::tensor::CastOp>(op))) {
for (uint32_t i = 0; i < op.getNumResults(); i++) {
std::string intermediate_tensor_name =
"layer_" + std::to_string(intermediate_tensor_index++);
@@ -1553,6 +1752,53 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
}
}
+ // Build variable tensors: one serialization tensor per entry of the
+ // file-level variable_tensor_op_map, keyed by its flatbuffer name.
+ // NOTE(review): the map is file-global, so every registered variable is
+ // visited each time this loop runs — confirm the behaviour when more than
+ // one block is serialized. `auto pair` also copies each entry.
+ for (auto pair : variable_tensor_op_map) {
+ mlir::Operation *op = pair.second;
+ // NOTE(review): `val` is not used anywhere in the visible loop body.
+ mlir::Value val = op->getResult(0);
+ // The tensor type is read from the op's "type" attribute.
+ mlir::RankedTensorType tensor_type = op->getAttr("type")
+ .cast<mlir::TypeAttr>()
+ .getValue()
+ .cast<mlir::RankedTensorType>();
+ ser_tensor = BuildTosaSerializationVariableTensor(
+ tensor_type /* tensor_type */, pair.first /* flatbuffer name */);
+ if (!ser_tensor) {
+ llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n";
+ return mlir::failure();
+ }
+ // Initialize from the "initial_value" attribute when it exists; only a
+ // DenseElementsAttr is accepted. Without the attribute the intent is to
+ // set the data to all zeros (see the review note in the else branch).
+ mlir::Attribute initial_value = op->getAttr("initial_value");
+ std::vector<uint8_t> u8_data;
+ DType element_type = Type2DType(tensor_type.getElementType());
+ if (initial_value) {
+ if (initial_value.isa<mlir::DenseElementsAttr>()) {
+ if (op_builder
+ .GetDataFromAttribute(initial_value, element_type, u8_data)
+ .failed()) {
+ llvm::errs() << "ERROR: GetDataFromAttribute() fails when building "
+ "initial_value of variable tensor\n";
+ return mlir::failure();
+ }
+ } else {
+ llvm::errs() << "ERROR: Unknown initial_value attribute type\n";
+ return mlir::failure();
+ }
+ } else {
+ // Byte count = ceil(element count * bits per element / 8).
+ uint64_t num_elements = 1;
+ for (int64_t dim : tensor_type.getShape()) {
+ num_elements *= dim;
+ }
+ uint64_t num_bits = num_elements * GetDTypeSize(element_type);
+ uint64_t num_bytes =
+ (num_bits % 8 == 0) ? (num_bits / 8) : (num_bits / 8) + 1;
+ // NOTE(review): with the fill below commented out, `num_bytes` is unused
+ // and `u8_data` is still empty when passed to ForceAlignTensorData, so no
+ // zero data of that size is written here — confirm this is intended.
+ // std::fill_n(u8_data.begin(), num_bytes, 0);
+ TosaSerializationHandler::ForceAlignTensorData(u8_data);
+ }
+ ser_tensor->SetData(u8_data);
+ ser_block->GetTensors().push_back(ser_tensor);
+ }
+
// Build tensor
// The tensor_map is sorted by hashed mlir::Value types.
@@ -1563,6 +1809,8 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
tensor_name_sort[pair.second] = pair.first;
for (auto pair : tensor_name_sort) {
+ mlir::RankedTensorType tensor_type =
+ pair.second.getType().cast<mlir::RankedTensorType>();
ser_tensor = BuildTosaSerializationTensor(pair.second /* val */,
pair.first /* name */);
if (!ser_tensor) {
@@ -1576,7 +1824,8 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
for (auto &op : block->getOperations()) {
if (llvm::isa<mlir::tosa::YieldOp>(op) ||
llvm::isa<mlir::func::ReturnOp>(op) ||
- llvm::isa<mlir::tensor::CastOp>(op))
+ llvm::isa<mlir::tensor::CastOp>(op) ||
+ llvm::isa<mlir::tosa::VariableOp>(op))
continue;
ser_operator = BuildTosaSerializationOperator(op_builder, op);
if (!ser_operator) {
@@ -1585,7 +1834,6 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
}
ser_block->GetOperators().push_back(ser_operator);
}
-
return mlir::success();
}
@@ -1594,7 +1842,10 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator(
const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op) {
TosaSerializationOperator *target_operator = nullptr;
- if (false) {
+ if (llvm::isa<mlir::tosa::VariableReadOp>(op)) {
+ target_operator = op_builder.build<mlir::tosa::VariableReadOp>(op);
+ } else if (llvm::isa<mlir::tosa::VariableWriteOp>(op)) {
+ target_operator = op_builder.build<mlir::tosa::VariableWriteOp>(op);
}
#define DEF_OPERATOR(MLIR_OP) \
else if (llvm::isa<mlir::tosa::MLIR_OP##Op>(op)) { \
@@ -1613,18 +1864,22 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator(
return nullptr;
}
+ if (llvm::isa<mlir::tosa::VariableReadOp>(op) ||
+ llvm::isa<mlir::tosa::VariableWriteOp>(op)) {
+ return target_operator;
+ }
+
// Sanity check the number of inputs/outputs of TOSA dialect matches the
// number of TOSA flatbuffer
if (op.getNumOperands() != target_operator->GetInputTensorNames().size()) {
- llvm::errs() << op << "\n";
- llvm::errs() << "WARNING. MLIR operator has " << op.getNumOperands()
+ llvm::errs() << "WARNING: MLIR operator has " << op.getNumOperands()
<< " input tensors != Flatbuffer "
"operator has "
<< target_operator->GetInputTensorNames().size()
<< " input tensors\n";
}
if (op.getNumResults() != target_operator->GetOutputTensorNames().size()) {
- llvm::errs() << "WARNING. MLIR operator has " << op.getNumResults()
+ llvm::errs() << "WARNING: MLIR operator has " << op.getNumResults()
<< " output tensors != Flatbuffer "
"operator has "
<< target_operator->GetOutputTensorNames().size()
@@ -1635,6 +1890,28 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator(
}
TosaSerializationTensor *
+TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor(
+    mlir::RankedTensorType tensor_type, const std::string &name) {
+  // Creates the serialization tensor for a variable named `name`, marked as
+  // a variable, ranked, and with no data yet. If this block already holds a
+  // tensor under `name`, nothing is built and nullptr is returned.
+  if (ser_block->GetTensorByName(name)) {
+    return nullptr;
+  }
+
+  const auto dims = tensor_type.getShape();
+  std::vector<int32_t> shape(dims.begin(), dims.end());
+  const DType element_dtype = Type2DType(tensor_type.getElementType());
+
+  return new TosaSerializationTensor(name, shape, element_dtype,
+                                     std::vector<uint8_t>(),
+                                     /* is_variable = */ true,
+                                     /* is_unranked = */ false);
+}
+
+TosaSerializationTensor *
TosaSerializationBlockBuilder::BuildTosaSerializationTensor(
mlir::Value val, const std::string &name) {
// If tensor already created before, use that tensor directly, create a new
@@ -1756,20 +2033,42 @@ mlir::LogicalResult dumpTosaJSON(mlir::func::FuncOp &func) {
return mlir::success();
}
-namespace mlir {
+#define GEN_PASS_DEF_TOSASERIALIZATIONPASS
+namespace mlir {
namespace tosa {
-
namespace {
class TosaSerialize : public TosaSerializationPassBase<TosaSerialize> {
public:
void runOnOperation() final {
- auto function = getOperation();
+    auto moduleOp = getOperation();
- if (dumpTosaFlatbuffer(function).failed()) {
- llvm::errs() << "Failed to generate TOSA flatbuffer...\n";
- return signalPassFailure();
+    // Walk the operations in the first block of the module's first region.
+    // Variable ops declared outside of any function are registered; each
+    // func.func is serialized with dumpTosaFlatbuffer. The first failing
+    // function fails the pass.
+    mlir::Block &module_body = moduleOp->getRegions().front().front();
+
+    for (mlir::Operation &op : module_body.getOperations()) {
+      if (llvm::isa<mlir::tosa::VariableOp>(op)) {
+        RegisterVariableOp(op);
+        continue;
+      }
+      // dyn_cast both tests for and yields the FuncOp in one step.
+      if (auto funcOp = llvm::dyn_cast<mlir::func::FuncOp>(op)) {
+        if (dumpTosaFlatbuffer(funcOp).failed()) {
+          llvm::errs() << "Failed to generate TOSA flatbuffer...\n";
+          return signalPassFailure();
+        }
+      }
}
}
};
@@ -1790,7 +2089,7 @@ public:
} // anonymous namespace
// Creates an instance of the TOSA flatbuffer generation pass
+// The returned pass is typed as an OperationPass over ModuleOp.
-std::unique_ptr<Pass> createTosaSerializePass() {
+std::unique_ptr<OperationPass<ModuleOp>> createTosaSerializePass() {
return std::make_unique<TosaSerialize>();
}