diff options
Diffstat (limited to 'src/common/utils')
-rw-r--r-- | src/common/utils/LegacySupport.cpp | 53 | ||||
-rw-r--r-- | src/common/utils/LegacySupport.h | 7 | ||||
-rw-r--r-- | src/common/utils/Utils.h | 4 |
3 files changed, 58 insertions, 6 deletions
diff --git a/src/common/utils/LegacySupport.cpp b/src/common/utils/LegacySupport.cpp index 5981c657bd..569b2abd89 100644 --- a/src/common/utils/LegacySupport.cpp +++ b/src/common/utils/LegacySupport.cpp @@ -29,7 +29,7 @@ namespace detail { namespace { -DataType data_type_mapper(AclDataType data_type) +DataType convert_to_legacy_data_type(AclDataType data_type) { switch(data_type) { @@ -41,11 +41,25 @@ DataType data_type_mapper(AclDataType data_type) return DataType::BFLOAT16; default: return DataType::UNKNOWN; - ; } } -TensorShape tensor_shape_mapper(int32_t ndims, int32_t *shape) +AclDataType convert_to_c_data_type(DataType data_type) +{ + switch(data_type) + { + case DataType::F32: + return AclDataType::AclFloat32; + case DataType::F16: + return AclDataType::AclFloat16; + case DataType::BFLOAT16: + return AclDataType::AclBFloat16; + default: + return AclDataType::AclDataTypeUnknown; + } +} + +TensorShape create_legacy_tensor_shape(int32_t ndims, int32_t *shape) { TensorShape legacy_shape{}; for(int32_t d = 0; d < ndims; ++d) @@ -54,13 +68,44 @@ TensorShape tensor_shape_mapper(int32_t ndims, int32_t *shape) } return legacy_shape; } +int32_t *create_tensor_shape_array(const TensorInfo &info) +{ + const auto num_dims = info.num_dimensions(); + if(num_dims <= 0) + { + return nullptr; + } + + int32_t *shape_array = new int32_t[num_dims]; + + for(size_t d = 0; d < num_dims; ++d) + { + shape_array[d] = info.tensor_shape()[d]; + } + + return shape_array; +} } // namespace TensorInfo convert_to_legacy_tensor_info(const AclTensorDescriptor &desc) { TensorInfo legacy_desc; - legacy_desc.init(tensor_shape_mapper(desc.ndims, desc.shape), 1, data_type_mapper(desc.data_type)); + legacy_desc.init(create_legacy_tensor_shape(desc.ndims, desc.shape), 1, convert_to_legacy_data_type(desc.data_type)); return legacy_desc; } + +AclTensorDescriptor convert_to_descriptor(const TensorInfo &info) +{ + const auto num_dims = info.num_dimensions(); + AclTensorDescriptor desc + { + static_cast<int32_t>(num_dims), + create_tensor_shape_array(info), + convert_to_c_data_type(info.data_type()), + nullptr, + 0 + }; + return desc; +} } // namespace detail } // namespace arm_compute diff --git a/src/common/utils/LegacySupport.h b/src/common/utils/LegacySupport.h index 37329b747c..c2cc1bc182 100644 --- a/src/common/utils/LegacySupport.h +++ b/src/common/utils/LegacySupport.h @@ -38,6 +38,13 @@ namespace detail * @return Legacy tensor meta-data */ TensorInfo convert_to_legacy_tensor_info(const AclTensorDescriptor &desc); +/** Convert a legacy tensor meta-data to a descriptor + * + * @param[in] info Legacy tensor meta-data + * + * @return A converted descriptor + */ +AclTensorDescriptor convert_to_descriptor(const TensorInfo &info); } // namespace detail } // namespace arm_compute diff --git a/src/common/utils/Utils.h b/src/common/utils/Utils.h index 87be9df509..79f4f39c47 100644 --- a/src/common/utils/Utils.h +++ b/src/common/utils/Utils.h @@ -40,7 +40,7 @@ namespace utils * @return A corresponding plain old C enumeration */ template <typename E, typename SE> -constexpr E as_cenum(SE v) noexcept +constexpr E as_cenum(const SE v) noexcept { return static_cast<E>(static_cast<std::underlying_type_t<SE>>(v)); } @@ -55,7 +55,7 @@ constexpr E as_cenum(SE v) noexcept * @return A corresponding strongly typed enumeration */ template <typename SE, typename E> -constexpr SE as_enum(E val) noexcept +constexpr SE as_enum(const E val) noexcept { return static_cast<SE>(val); } |