aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-12-15 20:34:51 +0000
committerTai Ly <tai.ly@arm.com>2024-01-12 21:39:59 +0000
commit5d580faec02bcef56164587accb5fd88a3e80d86 (patch)
tree0bf550d054f37d2831ac80e1e901265e905d1f1c
parentcc426df2a6762cb09c6a25c911039ae34660570c (diff)
downloadserialization_lib-5d580faec02bcef56164587accb5fd88a3e80d86.tar.gz
[serialization_lib] Add tosa shape ops
Added tosa shape ops to tosa.fbs also added convert I64 to and from U8 for storing const_shape data values Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I1e938dec7398fbcbe5be657dad65cdd61af5b597
-rw-r--r--.gitignore3
-rw-r--r--include/tosa_generated.h28
-rw-r--r--include/tosa_serialization_handler.h2
-rw-r--r--python/tosa/Op.py6
-rw-r--r--schema/tosa.fbs8
-rw-r--r--src/tosa_serialization_handler.cpp50
6 files changed, 89 insertions, 8 deletions
diff --git a/.gitignore b/.gitignore
index ba0430d..5034363 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
-__pycache__/ \ No newline at end of file
+__pycache__/
+build/ \ No newline at end of file
diff --git a/include/tosa_generated.h b/include/tosa_generated.h
index 0547316..2ecd35a 100644
--- a/include/tosa_generated.h
+++ b/include/tosa_generated.h
@@ -272,11 +272,17 @@ enum Op : uint32_t {
Op_RFFT2D = 70,
Op_ERF = 71,
Op_DIM = 72,
+ Op_CONST_SHAPE = 73,
+ Op_CONCAT_SHAPE = 74,
+ Op_ADD_SHAPE = 75,
+ Op_SUB_SHAPE = 76,
+ Op_MUL_SHAPE = 77,
+ Op_DIV_SHAPE = 78,
Op_MIN = Op_UNKNOWN,
- Op_MAX = Op_DIM
+ Op_MAX = Op_DIV_SHAPE
};
-inline const Op (&EnumValuesOp())[73] {
+inline const Op (&EnumValuesOp())[79] {
static const Op values[] = {
Op_UNKNOWN,
Op_ARGMAX,
@@ -350,13 +356,19 @@ inline const Op (&EnumValuesOp())[73] {
Op_FFT2D,
Op_RFFT2D,
Op_ERF,
- Op_DIM
+ Op_DIM,
+ Op_CONST_SHAPE,
+ Op_CONCAT_SHAPE,
+ Op_ADD_SHAPE,
+ Op_SUB_SHAPE,
+ Op_MUL_SHAPE,
+ Op_DIV_SHAPE
};
return values;
}
inline const char * const *EnumNamesOp() {
- static const char * const names[74] = {
+ static const char * const names[80] = {
"UNKNOWN",
"ARGMAX",
"AVG_POOL2D",
@@ -430,13 +442,19 @@ inline const char * const *EnumNamesOp() {
"RFFT2D",
"ERF",
"DIM",
+ "CONST_SHAPE",
+ "CONCAT_SHAPE",
+ "ADD_SHAPE",
+ "SUB_SHAPE",
+ "MUL_SHAPE",
+ "DIV_SHAPE",
nullptr
};
return names;
}
inline const char *EnumNameOp(Op e) {
- if (::flatbuffers::IsOutRange(e, Op_UNKNOWN, Op_DIM)) return "";
+ if (::flatbuffers::IsOutRange(e, Op_UNKNOWN, Op_DIV_SHAPE)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesOp()[index];
}
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index 41032fc..35327e8 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -414,6 +414,7 @@ public:
// data format conversion. little-endian.
static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertI48toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertI32toU8(const std::vector<int32_t>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertI16toU8(const std::vector<int16_t>& in, std::vector<uint8_t>& out);
@@ -423,6 +424,7 @@ public:
static tosa_err_t ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
+ static tosa_err_t ConvertU8toI64(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out);
static tosa_err_t ConvertU8toI48(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out);
static tosa_err_t ConvertU8toI32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int32_t>& out);
static tosa_err_t ConvertU8toI16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int16_t>& out);
diff --git a/python/tosa/Op.py b/python/tosa/Op.py
index d628479..c34f109 100644
--- a/python/tosa/Op.py
+++ b/python/tosa/Op.py
@@ -76,3 +76,9 @@ class Op(object):
RFFT2D = 70
ERF = 71
DIM = 72
+ CONST_SHAPE = 73
+ CONCAT_SHAPE = 74
+ ADD_SHAPE = 75
+ SUB_SHAPE = 76
+ MUL_SHAPE = 77
+ DIV_SHAPE = 78
diff --git a/schema/tosa.fbs b/schema/tosa.fbs
index 1975667..171818d 100644
--- a/schema/tosa.fbs
+++ b/schema/tosa.fbs
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -119,6 +119,12 @@ enum Op:uint32 {
RFFT2D,
ERF,
DIM,
+ CONST_SHAPE,
+ CONCAT_SHAPE,
+ ADD_SHAPE,
+ SUB_SHAPE,
+ MUL_SHAPE,
+ DIV_SHAPE,
}
union Attribute {
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index 015dda4..453670f 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -777,6 +777,25 @@ tosa_err_t TosaSerializationHandler::ConvertF32toU8(const std::vector<float>& in
return TOSA_OK;
}
+tosa_err_t TosaSerializationHandler::ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out)
+{
+ out.clear();
+ for (auto val : in)
+ {
+ uint64_t* val_u64 = reinterpret_cast<uint64_t*>(&val);
+ out.push_back(*val_u64 & 0xFF);
+ out.push_back((*val_u64 >> 8) & 0xFF);
+ out.push_back((*val_u64 >> 16) & 0xFF);
+ out.push_back((*val_u64 >> 24) & 0xFF);
+ out.push_back((*val_u64 >> 32) & 0xFF);
+ out.push_back((*val_u64 >> 40) & 0xFF);
+ out.push_back((*val_u64 >> 48) & 0xFF);
+ out.push_back((*val_u64 >> 56) & 0xFF);
+ }
+ ForceAlignTensorData(out);
+ return TOSA_OK;
+}
+
tosa_err_t TosaSerializationHandler::ConvertI48toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out)
{
out.clear();
@@ -926,6 +945,35 @@ tosa_err_t
return TOSA_OK;
}
+tosa_err_t TosaSerializationHandler::ConvertU8toI64(const std::vector<uint8_t>& in,
+ uint32_t out_size,
+ std::vector<int64_t>& out)
+{
+ out.clear();
+ if (in.size() < out_size * sizeof(int64_t))
+ {
+ printf("TosaSerializationHandler::ConvertU8toI64(): uint8 buffer size %ld must >= target size %ld\n", in.size(),
+ out_size * sizeof(int64_t));
+ return TOSA_USER_ERROR;
+ }
+ for (uint32_t i = 0; i < out_size; i++)
+ {
+ uint64_t byte0 = in[i * sizeof(int64_t)];
+ uint64_t byte1 = in[i * sizeof(int64_t) + 1];
+ uint64_t byte2 = in[i * sizeof(int64_t) + 2];
+ uint64_t byte3 = in[i * sizeof(int64_t) + 3];
+ uint64_t byte4 = in[i * sizeof(int64_t) + 4];
+ uint64_t byte5 = in[i * sizeof(int64_t) + 5];
+ uint64_t byte6 = in[i * sizeof(int64_t) + 6];
+ uint64_t byte7 = in[i * sizeof(int64_t) + 7];
+ uint64_t val_u64 = byte0 + (byte1 << 8) + (byte2 << 16) + (byte3 << 24) + (byte4 << 32) + (byte5 << 40) +
+ (byte6 << 48) + (byte7 << 56);
+ int64_t* val_i64 = reinterpret_cast<int64_t*>(&val_u64);
+ out.push_back(*val_i64);
+ }
+ return TOSA_OK;
+}
+
tosa_err_t TosaSerializationHandler::ConvertU8toI48(const std::vector<uint8_t>& in,
uint32_t out_size,
std::vector<int64_t>& out)