aboutsummaryrefslogtreecommitdiff
path: root/src/common/utils/LegacySupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/common/utils/LegacySupport.cpp')
-rw-r--r--src/common/utils/LegacySupport.cpp53
1 files changed, 49 insertions, 4 deletions
diff --git a/src/common/utils/LegacySupport.cpp b/src/common/utils/LegacySupport.cpp
index 5981c657bd..569b2abd89 100644
--- a/src/common/utils/LegacySupport.cpp
+++ b/src/common/utils/LegacySupport.cpp
@@ -29,7 +29,7 @@ namespace detail
{
namespace
{
-DataType data_type_mapper(AclDataType data_type)
+DataType convert_to_legacy_data_type(AclDataType data_type)
{
switch(data_type)
{
@@ -41,11 +41,25 @@ DataType data_type_mapper(AclDataType data_type)
return DataType::BFLOAT16;
default:
return DataType::UNKNOWN;
- ;
}
}
-TensorShape tensor_shape_mapper(int32_t ndims, int32_t *shape)
+AclDataType convert_to_c_data_type(DataType data_type)
+{
+ switch(data_type)
+ {
+ case DataType::F32:
+ return AclDataType::AclFloat32;
+ case DataType::F16:
+ return AclDataType::AclFloat16;
+ case DataType::BFLOAT16:
+ return AclDataType::AclBFloat16;
+ default:
+ return AclDataType::AclDataTypeUnknown;
+ }
+}
+
+TensorShape create_legacy_tensor_shape(int32_t ndims, int32_t *shape)
{
TensorShape legacy_shape{};
for(int32_t d = 0; d < ndims; ++d)
@@ -54,13 +68,44 @@ TensorShape tensor_shape_mapper(int32_t ndims, int32_t *shape)
}
return legacy_shape;
}
+int32_t *create_tensor_shape_array(const TensorInfo &info)
+{
+ const auto num_dims = info.num_dimensions();
+ if(num_dims <= 0)
+ {
+ return nullptr;
+ }
+
+ int32_t *shape_array = new int32_t[num_dims];
+
+ for(size_t d = 0; d < num_dims; ++d)
+ {
+ shape_array[d] = info.tensor_shape()[d];
+ }
+
+ return shape_array;
+}
} // namespace
TensorInfo convert_to_legacy_tensor_info(const AclTensorDescriptor &desc)
{
TensorInfo legacy_desc;
- legacy_desc.init(tensor_shape_mapper(desc.ndims, desc.shape), 1, data_type_mapper(desc.data_type));
+ legacy_desc.init(create_legacy_tensor_shape(desc.ndims, desc.shape), 1, convert_to_legacy_data_type(desc.data_type));
return legacy_desc;
}
+
+AclTensorDescriptor convert_to_descriptor(const TensorInfo &info)
+{
+ const auto num_dims = info.num_dimensions();
+ AclTensorDescriptor desc
+ {
+ static_cast<int32_t>(num_dims),
+ create_tensor_shape_array(info),
+ convert_to_c_data_type(info.data_type()),
+ nullptr,
+ 0
+ };
+ return desc;
+}
} // namespace detail
} // namespace arm_compute