diff options
Diffstat (limited to 'src/common')
-rw-r--r-- | src/common/ITensorV2.cpp | 39 | ||||
-rw-r--r-- | src/common/ITensorV2.h (renamed from src/common/ITensor.h) | 14 | ||||
-rw-r--r-- | src/common/TensorPack.cpp | 2 | ||||
-rw-r--r-- | src/common/utils/LegacySupport.cpp | 53 | ||||
-rw-r--r-- | src/common/utils/LegacySupport.h | 7 | ||||
-rw-r--r-- | src/common/utils/Utils.h | 4 |
6 files changed, 111 insertions, 8 deletions
diff --git a/src/common/ITensorV2.cpp b/src/common/ITensorV2.cpp new file mode 100644 index 0000000000..39bf1c6fb3 --- /dev/null +++ b/src/common/ITensorV2.cpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/common/ITensorV2.h" +#include "arm_compute/core/TensorInfo.h" +#include "src/common/utils/LegacySupport.h" + +namespace arm_compute +{ +size_t ITensorV2::get_size() const +{ + return tensor()->info()->total_size(); +} + +AclTensorDescriptor ITensorV2::get_descriptor() const +{ + return detail::convert_to_descriptor(*tensor()->info()); +} +} // namespace arm_compute
\ No newline at end of file diff --git a/src/common/ITensor.h b/src/common/ITensorV2.h index ee7eac7688..965aacea23 100644 --- a/src/common/ITensor.h +++ b/src/common/ITensorV2.h @@ -92,7 +92,19 @@ public: * * @return The legacy underlying tensor object */ - virtual arm_compute::ITensor *tensor() = 0; + virtual arm_compute::ITensor *tensor() const = 0; + /** Get the size of the tensor in byte + * + * @note The size isn't based on allocated memory, but based on information in its descriptor (dimensions, data type, etc.). + * + * @return The size of the tensor in byte + */ + size_t get_size() const; + /** Get the descriptor of this tensor + * + * @return The descriptor describing the characteristics of this tensor + */ + AclTensorDescriptor get_descriptor() const; }; /** Extract internal representation of a Tensor diff --git a/src/common/TensorPack.cpp b/src/common/TensorPack.cpp index c582c7b106..6c2c7f9622 100644 --- a/src/common/TensorPack.cpp +++ b/src/common/TensorPack.cpp @@ -22,7 +22,7 @@ * SOFTWARE. */ #include "src/common/TensorPack.h" -#include "src/common/ITensor.h" +#include "src/common/ITensorV2.h" #include "src/common/utils/Validate.h" namespace arm_compute diff --git a/src/common/utils/LegacySupport.cpp b/src/common/utils/LegacySupport.cpp index 5981c657bd..569b2abd89 100644 --- a/src/common/utils/LegacySupport.cpp +++ b/src/common/utils/LegacySupport.cpp @@ -29,7 +29,7 @@ namespace detail { namespace { -DataType data_type_mapper(AclDataType data_type) +DataType convert_to_legacy_data_type(AclDataType data_type) { switch(data_type) { @@ -41,11 +41,25 @@ DataType data_type_mapper(AclDataType data_type) return DataType::BFLOAT16; default: return DataType::UNKNOWN; - ; } } -TensorShape tensor_shape_mapper(int32_t ndims, int32_t *shape) +AclDataType convert_to_c_data_type(DataType data_type) +{ + switch(data_type) + { + case DataType::F32: + return AclDataType::AclFloat32; + case DataType::F16: + return AclDataType::AclFloat16; + case DataType::BFLOAT16: + return AclDataType::AclBFloat16; + default: + return AclDataType::AclDataTypeUnknown; + } +} + +TensorShape create_legacy_tensor_shape(int32_t ndims, int32_t *shape) { TensorShape legacy_shape{}; for(int32_t d = 0; d < ndims; ++d) @@ -54,13 +68,44 @@ TensorShape tensor_shape_mapper(int32_t ndims, int32_t *shape) } return legacy_shape; } +int32_t *create_tensor_shape_array(const TensorInfo &info) +{ + const auto num_dims = info.num_dimensions(); + if(num_dims <= 0) + { + return nullptr; + } + + int32_t *shape_array = new int32_t[num_dims]; + + for(size_t d = 0; d < num_dims; ++d) + { + shape_array[d] = info.tensor_shape()[d]; + } + + return shape_array; +} } // namespace TensorInfo convert_to_legacy_tensor_info(const AclTensorDescriptor &desc) { TensorInfo legacy_desc; - legacy_desc.init(tensor_shape_mapper(desc.ndims, desc.shape), 1, data_type_mapper(desc.data_type)); + legacy_desc.init(create_legacy_tensor_shape(desc.ndims, desc.shape), 1, convert_to_legacy_data_type(desc.data_type)); return legacy_desc; } + +AclTensorDescriptor convert_to_descriptor(const TensorInfo &info) +{ + const auto num_dims = info.num_dimensions(); + AclTensorDescriptor desc + { + static_cast<int32_t>(num_dims), + create_tensor_shape_array(info), + convert_to_c_data_type(info.data_type()), + nullptr, + 0 + }; + return desc; +} } // namespace detail } // namespace arm_compute diff --git a/src/common/utils/LegacySupport.h b/src/common/utils/LegacySupport.h index 37329b747c..c2cc1bc182 100644 --- a/src/common/utils/LegacySupport.h +++ b/src/common/utils/LegacySupport.h @@ -38,6 +38,13 @@ namespace detail * @return Legacy tensor meta-data */ TensorInfo convert_to_legacy_tensor_info(const AclTensorDescriptor &desc); +/** Convert a legacy tensor meta-data to a descriptor + * + * @param[in] info Legacy tensor meta-data + * + * @return A converted descriptor + */ +AclTensorDescriptor convert_to_descriptor(const TensorInfo &info); } // namespace detail } // namespace arm_compute diff --git a/src/common/utils/Utils.h b/src/common/utils/Utils.h index 87be9df509..79f4f39c47 100644 --- a/src/common/utils/Utils.h +++ b/src/common/utils/Utils.h @@ -40,7 +40,7 @@ namespace utils * @return A corresponding plain old C enumeration */ template <typename E, typename SE> -constexpr E as_cenum(SE v) noexcept +constexpr E as_cenum(const SE v) noexcept { return static_cast<E>(static_cast<std::underlying_type_t<SE>>(v)); } @@ -55,7 +55,7 @@ constexpr E as_cenum(SE v) noexcept * @return A corresponding strongly typed enumeration */ template <typename SE, typename E> -constexpr SE as_enum(E val) noexcept +constexpr SE as_enum(const E val) noexcept { return static_cast<SE>(val); } |