diff options
author | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2023-06-08 15:59:28 +0100 |
---|---|---|
committer | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2023-06-09 13:59:51 +0000 |
commit | 08dfba38b99abcf12db39d6650e4e3758f1bd0b4 (patch) | |
tree | 7152cf10c8a9e031dd3e7314ae5f1021d39a0ad7 /utils/Utils.h | |
parent | 54b41a30b2f05b4d7afe9cb21aea96341d2369ff (diff) | |
download | ComputeLibrary-08dfba38b99abcf12db39d6650e4e3758f1bd0b4.tar.gz |
Enable conversion of F32 NumPy files to F16
Resolves COMPMID-6277
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Change-Id: Ibbe5eb0869f701d4329782101ee7336948350269
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9747
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'utils/Utils.h')
-rw-r--r-- | utils/Utils.h | 31 |
1 files changed, 28 insertions, 3 deletions
diff --git a/utils/Utils.h b/utils/Utils.h index e3a5bb2c3c..d181022ffe 100644 --- a/utils/Utils.h +++ b/utils/Utils.h @@ -383,7 +383,22 @@ public: // Check if the typestring matches the given one std::string expect_typestr = get_typestring(tensor.info()->data_type()); - ARM_COMPUTE_ERROR_ON_MSG(_typestring != expect_typestr, "Typestrings mismatch"); + + bool enable_f32_to_f16_conversion = false; + if(_typestring != expect_typestr) + { + const std::string f32_typestring = "<f4"; + const std::string f16_typestring = "<f2"; + // if typestring does not match, check whether _typestring is F32 and can be downcasted to expect_typestr + if(_typestring == f32_typestring && expect_typestr == f16_typestring) + { + enable_f32_to_f16_conversion = true; + } + else + { + ARM_COMPUTE_ERROR("Typestrings mismatch"); + } + } bool are_layouts_different = (_file_layout != tensor.info()->data_layout()); // Correct dimensions (Needs to match TensorShape dimension corrections) @@ -427,7 +442,7 @@ public: case arm_compute::DataType::F16: { // Read data - if(!are_layouts_different && !_fortran_order && tensor.info()->padding().empty()) + if(!are_layouts_different && !_fortran_order && tensor.info()->padding().empty() && !enable_f32_to_f16_conversion) { // If tensor has no padding read directly from stream. _fs.read(reinterpret_cast<char *>(tensor.buffer()), tensor.info()->total_size()); @@ -466,7 +481,17 @@ public: { Coordinates dst(id); arm_compute::permute(dst, perm); - _fs.read(reinterpret_cast<char *>(tensor.ptr_to_element(dst)), tensor.info()->element_size()); + if(enable_f32_to_f16_conversion) + { + float f32_val = 0; + _fs.read(reinterpret_cast<char *>(&f32_val), 4u); + half f16_val = half_float::half_cast<half, std::round_to_nearest>(f32_val); + *(reinterpret_cast<half *>(tensor.ptr_to_element(dst))) = f16_val; + } + else + { + _fs.read(reinterpret_cast<char *>(tensor.ptr_to_element(dst)), tensor.info()->element_size()); + } }); } |