aboutsummaryrefslogtreecommitdiff
path: root/src/common/utils
diff options
context:
space:
mode:
Diffstat (limited to 'src/common/utils')
-rw-r--r--src/common/utils/LegacySupport.cpp53
-rw-r--r--src/common/utils/LegacySupport.h7
-rw-r--r--src/common/utils/Utils.h4
3 files changed, 58 insertions, 6 deletions
diff --git a/src/common/utils/LegacySupport.cpp b/src/common/utils/LegacySupport.cpp
index 5981c657bd..569b2abd89 100644
--- a/src/common/utils/LegacySupport.cpp
+++ b/src/common/utils/LegacySupport.cpp
@@ -29,7 +29,7 @@ namespace detail
{
namespace
{
-DataType data_type_mapper(AclDataType data_type)
+DataType convert_to_legacy_data_type(AclDataType data_type)
{
switch(data_type)
{
@@ -41,11 +41,25 @@ DataType data_type_mapper(AclDataType data_type)
return DataType::BFLOAT16;
default:
return DataType::UNKNOWN;
- ;
}
}
-TensorShape tensor_shape_mapper(int32_t ndims, int32_t *shape)
+AclDataType convert_to_c_data_type(DataType data_type)
+{
+ switch(data_type)
+ {
+ case DataType::F32:
+ return AclDataType::AclFloat32;
+ case DataType::F16:
+ return AclDataType::AclFloat16;
+ case DataType::BFLOAT16:
+ return AclDataType::AclBFloat16;
+ default:
+ return AclDataType::AclDataTypeUnknown;
+ }
+}
+
+TensorShape create_legacy_tensor_shape(int32_t ndims, int32_t *shape)
{
TensorShape legacy_shape{};
for(int32_t d = 0; d < ndims; ++d)
@@ -54,13 +68,44 @@ TensorShape tensor_shape_mapper(int32_t ndims, int32_t *shape)
}
return legacy_shape;
}
+int32_t *create_tensor_shape_array(const TensorInfo &info)
+{
+ const auto num_dims = info.num_dimensions();
+ if(num_dims <= 0)
+ {
+ return nullptr;
+ }
+
+ int32_t *shape_array = new int32_t[num_dims];
+
+ for(size_t d = 0; d < num_dims; ++d)
+ {
+ shape_array[d] = info.tensor_shape()[d];
+ }
+
+ return shape_array;
+}
} // namespace
TensorInfo convert_to_legacy_tensor_info(const AclTensorDescriptor &desc)
{
TensorInfo legacy_desc;
- legacy_desc.init(tensor_shape_mapper(desc.ndims, desc.shape), 1, data_type_mapper(desc.data_type));
+ legacy_desc.init(create_legacy_tensor_shape(desc.ndims, desc.shape), 1, convert_to_legacy_data_type(desc.data_type));
return legacy_desc;
}
+
+AclTensorDescriptor convert_to_descriptor(const TensorInfo &info)
+{
+ const auto num_dims = info.num_dimensions();
+ AclTensorDescriptor desc
+ {
+ static_cast<int32_t>(num_dims),
+ create_tensor_shape_array(info),
+ convert_to_c_data_type(info.data_type()),
+ nullptr,
+ 0
+ };
+ return desc;
+}
} // namespace detail
} // namespace arm_compute
diff --git a/src/common/utils/LegacySupport.h b/src/common/utils/LegacySupport.h
index 37329b747c..c2cc1bc182 100644
--- a/src/common/utils/LegacySupport.h
+++ b/src/common/utils/LegacySupport.h
@@ -38,6 +38,13 @@ namespace detail
* @return Legacy tensor meta-data
*/
TensorInfo convert_to_legacy_tensor_info(const AclTensorDescriptor &desc);
+/** Convert a legacy tensor meta-data to a descriptor
+ *
+ * @param[in] info Legacy tensor meta-data
+ *
+ * @return A converted descriptor
+ */
+AclTensorDescriptor convert_to_descriptor(const TensorInfo &info);
} // namespace detail
} // namespace arm_compute
diff --git a/src/common/utils/Utils.h b/src/common/utils/Utils.h
index 87be9df509..79f4f39c47 100644
--- a/src/common/utils/Utils.h
+++ b/src/common/utils/Utils.h
@@ -40,7 +40,7 @@ namespace utils
* @return A corresponding plain old C enumeration
*/
template <typename E, typename SE>
-constexpr E as_cenum(SE v) noexcept
+constexpr E as_cenum(const SE v) noexcept
{
return static_cast<E>(static_cast<std::underlying_type_t<SE>>(v));
}
@@ -55,7 +55,7 @@ constexpr E as_cenum(SE v) noexcept
* @return A corresponding strongly typed enumeration
*/
template <typename SE, typename E>
-constexpr SE as_enum(E val) noexcept
+constexpr SE as_enum(const E val) noexcept
{
return static_cast<SE>(val);
}