From 5d580faec02bcef56164587accb5fd88a3e80d86 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Fri, 15 Dec 2023 20:34:51 +0000 Subject: [serialization_lib] Add tosa shape ops Added tosa shape ops to tosa.fbs also added convert I64 to and from U8 for storing const_shape data values Signed-off-by: Tai Ly Change-Id: I1e938dec7398fbcbe5be657dad65cdd61af5b597 --- .gitignore | 3 ++- include/tosa_generated.h | 28 ++++++++++++++++---- include/tosa_serialization_handler.h | 2 ++ python/tosa/Op.py | 6 +++++ schema/tosa.fbs | 8 +++++- src/tosa_serialization_handler.cpp | 50 +++++++++++++++++++++++++++++++++++- 6 files changed, 89 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index ba0430d..5034363 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -__pycache__/ \ No newline at end of file +__pycache__/ +build/ \ No newline at end of file diff --git a/include/tosa_generated.h b/include/tosa_generated.h index 0547316..2ecd35a 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -272,11 +272,17 @@ enum Op : uint32_t { Op_RFFT2D = 70, Op_ERF = 71, Op_DIM = 72, + Op_CONST_SHAPE = 73, + Op_CONCAT_SHAPE = 74, + Op_ADD_SHAPE = 75, + Op_SUB_SHAPE = 76, + Op_MUL_SHAPE = 77, + Op_DIV_SHAPE = 78, Op_MIN = Op_UNKNOWN, - Op_MAX = Op_DIM + Op_MAX = Op_DIV_SHAPE }; -inline const Op (&EnumValuesOp())[73] { +inline const Op (&EnumValuesOp())[79] { static const Op values[] = { Op_UNKNOWN, Op_ARGMAX, @@ -350,13 +356,19 @@ inline const Op (&EnumValuesOp())[73] { Op_FFT2D, Op_RFFT2D, Op_ERF, - Op_DIM + Op_DIM, + Op_CONST_SHAPE, + Op_CONCAT_SHAPE, + Op_ADD_SHAPE, + Op_SUB_SHAPE, + Op_MUL_SHAPE, + Op_DIV_SHAPE }; return values; } inline const char * const *EnumNamesOp() { - static const char * const names[74] = { + static const char * const names[80] = { "UNKNOWN", "ARGMAX", "AVG_POOL2D", @@ -430,13 +442,19 @@ inline const char * const *EnumNamesOp() { "RFFT2D", "ERF", "DIM", + "CONST_SHAPE", + "CONCAT_SHAPE", + "ADD_SHAPE", + "SUB_SHAPE", + "MUL_SHAPE", + "DIV_SHAPE", nullptr }; return names; } inline const char *EnumNameOp(Op e) { - if (::flatbuffers::IsOutRange(e, Op_UNKNOWN, Op_DIM)) return ""; + if (::flatbuffers::IsOutRange(e, Op_UNKNOWN, Op_DIV_SHAPE)) return ""; const size_t index = static_cast(e); return EnumNamesOp()[index]; } diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index 41032fc..35327e8 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -414,6 +414,7 @@ public: // data format conversion. little-endian. static tosa_err_t ConvertF16toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertF32toU8(const std::vector& in, std::vector& out); + static tosa_err_t ConvertI64toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertI48toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertI32toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertI16toU8(const std::vector& in, std::vector& out); @@ -423,6 +424,7 @@ public: static tosa_err_t ConvertU8toF16(const std::vector& in, uint32_t out_size, std::vector& out); static tosa_err_t ConvertU8toF32(const std::vector& in, uint32_t out_size, std::vector& out); + static tosa_err_t ConvertU8toI64(const std::vector& in, uint32_t out_size, std::vector& out); static tosa_err_t ConvertU8toI48(const std::vector& in, uint32_t out_size, std::vector& out); static tosa_err_t ConvertU8toI32(const std::vector& in, uint32_t out_size, std::vector& out); static tosa_err_t ConvertU8toI16(const std::vector& in, uint32_t out_size, std::vector& out); diff --git a/python/tosa/Op.py b/python/tosa/Op.py index d628479..c34f109 100644 --- a/python/tosa/Op.py +++ b/python/tosa/Op.py @@ -76,3 +76,9 @@ class Op(object): RFFT2D = 70 ERF = 71 DIM = 72 + CONST_SHAPE = 73 + CONCAT_SHAPE = 74 + ADD_SHAPE = 75 + SUB_SHAPE = 76 + MUL_SHAPE = 77 + DIV_SHAPE = 78 diff --git a/schema/tosa.fbs b/schema/tosa.fbs index 1975667..171818d 100644 --- a/schema/tosa.fbs +++ b/schema/tosa.fbs @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -119,6 +119,12 @@ enum Op:uint32 { RFFT2D, ERF, DIM, + CONST_SHAPE, + CONCAT_SHAPE, + ADD_SHAPE, + SUB_SHAPE, + MUL_SHAPE, + DIV_SHAPE, } union Attribute { diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index 015dda4..453670f 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -777,6 +777,25 @@ tosa_err_t TosaSerializationHandler::ConvertF32toU8(const std::vector& in return TOSA_OK; } +tosa_err_t TosaSerializationHandler::ConvertI64toU8(const std::vector& in, std::vector& out) +{ + out.clear(); + for (auto val : in) + { + uint64_t* val_u64 = reinterpret_cast(&val); + out.push_back(*val_u64 & 0xFF); + out.push_back((*val_u64 >> 8) & 0xFF); + out.push_back((*val_u64 >> 16) & 0xFF); + out.push_back((*val_u64 >> 24) & 0xFF); + out.push_back((*val_u64 >> 32) & 0xFF); + out.push_back((*val_u64 >> 40) & 0xFF); + out.push_back((*val_u64 >> 48) & 0xFF); + out.push_back((*val_u64 >> 56) & 0xFF); + } + ForceAlignTensorData(out); + return TOSA_OK; +} + tosa_err_t TosaSerializationHandler::ConvertI48toU8(const std::vector& in, std::vector& out) { out.clear(); @@ -926,6 +945,35 @@ tosa_err_t return TOSA_OK; } +tosa_err_t TosaSerializationHandler::ConvertU8toI64(const std::vector& in, + uint32_t out_size, + std::vector& out) +{ + out.clear(); + if (in.size() < out_size * sizeof(int64_t)) + { + printf("TosaSerializationHandler::ConvertU8toI64(): uint8 buffer size %ld must >= target size %ld\n", in.size(), + out_size * sizeof(int64_t)); + return TOSA_USER_ERROR; + } + for (uint32_t i = 0; i < out_size; i++) + { + uint64_t byte0 = in[i * sizeof(int64_t)]; + uint64_t byte1 = in[i * sizeof(int64_t) + 1]; + uint64_t byte2 = in[i * sizeof(int64_t) + 2]; + uint64_t byte3 = in[i * sizeof(int64_t) + 3]; + uint64_t byte4 = in[i * sizeof(int64_t) + 4]; + uint64_t byte5 = in[i * sizeof(int64_t) + 5]; + uint64_t byte6 = in[i * sizeof(int64_t) + 6]; + uint64_t byte7 = in[i * sizeof(int64_t) + 7]; + uint64_t val_u64 = byte0 + (byte1 << 8) + (byte2 << 16) + (byte3 << 24) + (byte4 << 32) + (byte5 << 40) + + (byte6 << 48) + (byte7 << 56); + int64_t* val_i64 = reinterpret_cast(&val_u64); + out.push_back(*val_i64); + } + return TOSA_OK; +} + tosa_err_t TosaSerializationHandler::ConvertU8toI48(const std::vector& in, uint32_t out_size, std::vector& out) -- cgit v1.2.1