diff options
author | Sang-Hoon Park <sang-hoon.park@arm.com> | 2021-03-31 15:18:16 +0100 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2021-04-19 07:35:54 +0000 |
commit | c6fcfb4adc37a6cf09472168dc177234d4fabdfa (patch) | |
tree | b67afd4c8d1594053395394b24406334e66e0791 /src/common/utils | |
parent | fe56edb4fd7a620fea4b6002d87a9763bdf8791a (diff) | |
download | ComputeLibrary-c6fcfb4adc37a6cf09472168dc177234d4fabdfa.tar.gz |
Add Tensor related utilities to the new API
A couple of utility functions to get the information
about tensors are added. Those functions are placed
at an additional header file for better grouping.
Related test cases are also added.
Resolves: COMPMID-4376
Change-Id: I6bd09cbf60fddcf4fe651906982397afb0451392
Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5405
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/common/utils')
-rw-r--r-- | src/common/utils/LegacySupport.cpp | 53 | ||||
-rw-r--r-- | src/common/utils/LegacySupport.h | 7 | ||||
-rw-r--r-- | src/common/utils/Utils.h | 4 |
3 files changed, 58 insertions, 6 deletions
diff --git a/src/common/utils/LegacySupport.cpp b/src/common/utils/LegacySupport.cpp index 5981c657bd..569b2abd89 100644 --- a/src/common/utils/LegacySupport.cpp +++ b/src/common/utils/LegacySupport.cpp @@ -29,7 +29,7 @@ namespace detail { namespace { -DataType data_type_mapper(AclDataType data_type) +DataType convert_to_legacy_data_type(AclDataType data_type) { switch(data_type) { @@ -41,11 +41,25 @@ DataType data_type_mapper(AclDataType data_type) return DataType::BFLOAT16; default: return DataType::UNKNOWN; - ; } } -TensorShape tensor_shape_mapper(int32_t ndims, int32_t *shape) +AclDataType convert_to_c_data_type(DataType data_type) +{ + switch(data_type) + { + case DataType::F32: + return AclDataType::AclFloat32; + case DataType::F16: + return AclDataType::AclFloat16; + case DataType::BFLOAT16: + return AclDataType::AclBFloat16; + default: + return AclDataType::AclDataTypeUnknown; + } +} + +TensorShape create_legacy_tensor_shape(int32_t ndims, int32_t *shape) { TensorShape legacy_shape{}; for(int32_t d = 0; d < ndims; ++d) @@ -54,13 +68,44 @@ TensorShape tensor_shape_mapper(int32_t ndims, int32_t *shape) } return legacy_shape; } +int32_t *create_tensor_shape_array(const TensorInfo &info) +{ + const auto num_dims = info.num_dimensions(); + if(num_dims <= 0) + { + return nullptr; + } + + int32_t *shape_array = new int32_t[num_dims]; + + for(size_t d = 0; d < num_dims; ++d) + { + shape_array[d] = info.tensor_shape()[d]; + } + + return shape_array; +} } // namespace TensorInfo convert_to_legacy_tensor_info(const AclTensorDescriptor &desc) { TensorInfo legacy_desc; - legacy_desc.init(tensor_shape_mapper(desc.ndims, desc.shape), 1, data_type_mapper(desc.data_type)); + legacy_desc.init(create_legacy_tensor_shape(desc.ndims, desc.shape), 1, convert_to_legacy_data_type(desc.data_type)); return legacy_desc; } + +AclTensorDescriptor convert_to_descriptor(const TensorInfo &info) +{ + const auto num_dims = info.num_dimensions(); + AclTensorDescriptor desc + { + static_cast<int32_t>(num_dims), + create_tensor_shape_array(info), + convert_to_c_data_type(info.data_type()), + nullptr, + 0 + }; + return desc; +} } // namespace detail } // namespace arm_compute diff --git a/src/common/utils/LegacySupport.h b/src/common/utils/LegacySupport.h index 37329b747c..c2cc1bc182 100644 --- a/src/common/utils/LegacySupport.h +++ b/src/common/utils/LegacySupport.h @@ -38,6 +38,13 @@ namespace detail * @return Legacy tensor meta-data */ TensorInfo convert_to_legacy_tensor_info(const AclTensorDescriptor &desc); +/** Convert a legacy tensor meta-data to a descriptor + * + * @param[in] info Legacy tensor meta-data + * + * @return A converted descriptor + */ +AclTensorDescriptor convert_to_descriptor(const TensorInfo &info); } // namespace detail } // namespace arm_compute diff --git a/src/common/utils/Utils.h b/src/common/utils/Utils.h index 87be9df509..79f4f39c47 100644 --- a/src/common/utils/Utils.h +++ b/src/common/utils/Utils.h @@ -40,7 +40,7 @@ namespace utils * @return A corresponding plain old C enumeration */ template <typename E, typename SE> -constexpr E as_cenum(SE v) noexcept +constexpr E as_cenum(const SE v) noexcept { return static_cast<E>(static_cast<std::underlying_type_t<SE>>(v)); } @@ -55,7 +55,7 @@ constexpr E as_cenum(SE v) noexcept * @return A corresponding strongly typed enumeration */ template <typename SE, typename E> -constexpr SE as_enum(E val) noexcept +constexpr SE as_enum(const E val) noexcept { return static_cast<SE>(val); } |