aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/tensor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r--reference_model/src/tensor.cc46
1 files changed, 44 insertions, 2 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 7cbeb13..cbe12a9 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
#include "tensor.h"
#include "arith_util.h"
+#include "half.hpp"
using namespace TosaReference;
using namespace Eigen;
@@ -84,6 +85,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
{
uint32_t elements = getElementCount();
float* fdatabuf = nullptr;
+ half_float::half* f16databuf = nullptr;
int32_t* i32databuf = nullptr;
int64_t* i64databuf = nullptr;
bool* bdatabuf = nullptr;
@@ -97,6 +99,14 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
nperror = NumpyUtilities::readFromNpyFile(filename, elements, fdatabuf);
break;
+ case DType_FP16:
+ f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements);
+ ASSERT_MEM(f16databuf);
+ fdatabuf = (float*)calloc(sizeof(float), elements);
+ ASSERT_MEM(fdatabuf);
+
+ nperror = NumpyUtilities::readFromNpyFile(filename, elements, f16databuf);
+ break;
case DType_INT32:
case DType_UINT8:
case DType_INT4:
@@ -146,9 +156,17 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
switch (getDtype())
{
+ case DType_FP16:
+ // Convert from fp16 to fp32
+ for (uint32_t i=0; i < elements; i++) {
+ fdatabuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]);
+ }
+ // Fall through to DType_FLOAT case
case DType_FLOAT:
if (setTensorValueFloat(elements, fdatabuf))
{
+ if (f16databuf)
+ free(f16databuf);
free(fdatabuf);
return 1;
}
@@ -187,6 +205,8 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
if (fdatabuf)
free(fdatabuf);
+ if (f16databuf)
+ free(f16databuf);
if (i32databuf)
free(i32databuf);
if (i64databuf)
@@ -200,11 +220,12 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
int TosaReference::Tensor::writeToNpyFile(const char* filename) const
{
float* fdatabuf = nullptr;
+ half_float::half* f16databuf = nullptr;
int32_t* i32databuf = nullptr;
int64_t* i64databuf = nullptr;
bool* bdatabuf = nullptr;
NumpyUtilities::NPError nperror;
- int elements = getElementCount();
+ uint32_t elements = getElementCount();
switch (getDtype())
{
@@ -222,6 +243,27 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
free(fdatabuf);
break;
+ case DType_FP16:
+ fdatabuf = (float*)calloc(sizeof(float), elements);
+ ASSERT_MEM(fdatabuf);
+ f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements);
+ ASSERT_MEM(f16databuf);
+
+ if (getTensorValueFloat(elements, fdatabuf))
+ {
+ free(fdatabuf);
+ free(f16databuf);
+ return 1;
+ }
+ // Convert fp32 to fp16
+ for (uint32_t i=0; i < elements; i++) {
+ f16databuf[i] = half_float::half_cast<half_float::half, float>(fdatabuf[i]);
+ }
+ nperror = NumpyUtilities::writeToNpyFile(filename, shape, f16databuf);
+
+ free(fdatabuf);
+ free(f16databuf);
+ break;
case DType_INT32:
case DType_UINT8:
case DType_INT4: