diff options
author | James Ward <james.ward@arm.com> | 2022-08-12 20:48:56 +0100 |
---|---|---|
committer | James Ward <james.ward@arm.com> | 2022-10-11 11:56:02 +0100 |
commit | 8b39043c70332e1e4c95ee6a9616aec40dd3baf1 (patch) | |
tree | fea519246b698eb944b9d58537fc90bc30481d11 /reference_model/src/tensor.cc | |
parent | ba5fad356a926d5e1c6e0fe6b546a310230cc5a8 (diff) | |
download | reference_model-8b39043c70332e1e4c95ee6a9616aec40dd3baf1.tar.gz |
Reference model changes for fp16 support
Change-Id: I72f21fcfa153046274969d327313e3349981dbe6
Signed-off-by: James Ward <james.ward@arm.com>
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r-- | reference_model/src/tensor.cc | 46 |
1 files changed, 44 insertions, 2 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 7cbeb13..cbe12a9 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ #include "tensor.h" #include "arith_util.h" +#include "half.hpp" using namespace TosaReference; using namespace Eigen; @@ -84,6 +85,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) { uint32_t elements = getElementCount(); float* fdatabuf = nullptr; + half_float::half* f16databuf = nullptr; int32_t* i32databuf = nullptr; int64_t* i64databuf = nullptr; bool* bdatabuf = nullptr; @@ -97,6 +99,14 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) nperror = NumpyUtilities::readFromNpyFile(filename, elements, fdatabuf); break; + case DType_FP16: + f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements); + ASSERT_MEM(f16databuf); + fdatabuf = (float*)calloc(sizeof(float), elements); + ASSERT_MEM(fdatabuf); + + nperror = NumpyUtilities::readFromNpyFile(filename, elements, f16databuf); + break; case DType_INT32: case DType_UINT8: case DType_INT4: @@ -146,9 +156,17 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) switch (getDtype()) { + case DType_FP16: + // Convert from fp16 to fp32 + for (uint32_t i=0; i < elements; i++) { + fdatabuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]); + } + // Fall through to DType_FLOAT case case DType_FLOAT: if (setTensorValueFloat(elements, fdatabuf)) { + if (f16databuf) + free(f16databuf); free(fdatabuf); return 1; } @@ -187,6 +205,8 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) if (fdatabuf) free(fdatabuf); + if (f16databuf) + free(f16databuf); if (i32databuf) free(i32databuf); if (i64databuf) @@ -200,11 +220,12 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) int TosaReference::Tensor::writeToNpyFile(const char* filename) const { float* fdatabuf = nullptr; + half_float::half* f16databuf = nullptr; int32_t* i32databuf = nullptr; int64_t* i64databuf = nullptr; bool* bdatabuf = nullptr; NumpyUtilities::NPError nperror; - int elements = getElementCount(); + uint32_t elements = getElementCount(); switch (getDtype()) { @@ -222,6 +243,27 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const free(fdatabuf); break; + case DType_FP16: + fdatabuf = (float*)calloc(sizeof(float), elements); + ASSERT_MEM(fdatabuf); + f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements); + ASSERT_MEM(f16databuf); + + if (getTensorValueFloat(elements, fdatabuf)) + { + free(fdatabuf); + free(f16databuf); + return 1; + } + // Convert fp32 to fp16 + for (uint32_t i=0; i < elements; i++) { + f16databuf[i] = half_float::half_cast<half_float::half, float>(fdatabuf[i]); + } + nperror = NumpyUtilities::writeToNpyFile(filename, shape, f16databuf); + + free(fdatabuf); + free(f16databuf); + break; case DType_INT32: case DType_UINT8: case DType_INT4: |