aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/tensor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r--reference_model/src/tensor.cc63
1 files changed, 28 insertions, 35 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index c5f5e02..4982c99 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -30,11 +30,11 @@ TosaReference::Tensor::Tensor(const std::string tensorName_,
, shape(shape_)
, tensorDtype(ConvertDType(serializationDtype_))
{
- producer = nullptr;
- isValid = false;
+ producer = nullptr;
+ isValid = false;
consumers.clear();
- isSubgraphInput = false;
- isSubgraphOutput = false;
+ isSubgraphInput = false;
+ isSubgraphOutput = false;
isParentGraphOutput = false;
}
@@ -94,16 +94,16 @@ int TosaReference::Tensor::dumpTensorParams(std::ostream& out) const
int TosaReference::Tensor::readFromNpyFile(const char* filename)
{
- uint32_t elements = getElementCount();
+ uint32_t elements = getElementCount();
double* f64databuf = nullptr;
float* f32databuf = nullptr;
half_float::half* f16databuf = nullptr;
- int32_t* i32databuf = nullptr;
- int64_t* i64databuf = nullptr;
- bool* bdatabuf = nullptr;
+ int32_t* i32databuf = nullptr;
+ int64_t* i64databuf = nullptr;
+ bool* bdatabuf = nullptr;
NumpyUtilities::NPError nperror;
- TOSA_REF_TYPE dtype = getDtype();
- DType serialization_dtype = getSerializationDtype();
+ TOSA_REF_TYPE dtype = getDtype();
+ DType serialization_dtype = getSerializationDtype();
assert(dtype == ConvertDType(serialization_dtype));
// if dtype is FP64, serialization_dtype must be one of FP32, FP16, BF16
@@ -178,7 +178,8 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
// Convert from fp16 to fp32 so that fp16 values can be manipulated as float
f32databuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(f32databuf);
- for (uint32_t i=0; i < elements; i++) {
+ for (uint32_t i = 0; i < elements; i++)
+ {
f32databuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]);
}
if (setTensorValueFloat(elements, f32databuf))
@@ -189,12 +190,9 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
}
break;
case TOSA_REF_TYPE_BF16:
- for (uint32_t i=0; i < elements; i++)
+ for (uint32_t i = 0; i < elements; i++)
{
- ASSERT_MSG(
- checkValidBFloat(f32databuf[i]),
- "Input float value not a valid bfloat16 value."
- );
+ ASSERT_MSG(checkValidBFloat(f32databuf[i]), "Input float value not a valid bfloat16 value.");
}
if (setTensorValueFloat(elements, f32databuf))
{
@@ -313,15 +311,15 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
int TosaReference::Tensor::writeToNpyFile(const char* filename) const
{
- float* f32databuf = nullptr;
- double* f64databuf = nullptr;
- half_float::half* f16databuf = nullptr;
- int32_t* i32databuf = nullptr;
- int64_t* i64databuf = nullptr;
- bool* bdatabuf = nullptr;
+ float* f32databuf = nullptr;
+ double* f64databuf = nullptr;
+ half_float::half* f16databuf = nullptr;
+ int32_t* i32databuf = nullptr;
+ int64_t* i64databuf = nullptr;
+ bool* bdatabuf = nullptr;
NumpyUtilities::NPError nperror = NumpyUtilities::NO_ERROR;
- uint32_t elements = getElementCount();
- const TOSA_REF_TYPE dtype = getDtype();
+ uint32_t elements = getElementCount();
+ const TOSA_REF_TYPE dtype = getDtype();
switch (dtype)
{
@@ -352,7 +350,8 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
return 1;
}
// Convert fp32 to fp16 so that output file contains valid fp16 data
- for (uint32_t i=0; i < elements; i++) {
+ for (uint32_t i = 0; i < elements; i++)
+ {
f16databuf[i] = half_float::half_cast<half_float::half, float>(f32databuf[i]);
}
nperror = NumpyUtilities::writeToNpyFile(filename, shape, f16databuf);
@@ -592,10 +591,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals)
for (auto v : vals)
{
- ASSERT_MSG(
- checkValidBFloat(v),
- "Input float value not a valid bfloat16 value."
- );
+ ASSERT_MSG(checkValidBFloat(v), "Input float value not a valid bfloat16 value.");
}
setTensorValueFloat(elements, vals.data());
@@ -625,7 +621,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<half_float::half> val
}
// Convert from fp16 to fp32
- for (uint32_t i=0; i < elements; i++)
+ for (uint32_t i = 0; i < elements; i++)
{
tensor[i] = half_float::half_cast<float, half_float::half>(vals[i]);
}
@@ -772,10 +768,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<float> vals)
for (auto v : vals)
{
- ASSERT_MSG(
- checkValidBFloat(v),
- "Float value not a valid bfloat16 value."
- );
+ ASSERT_MSG(checkValidBFloat(v), "Float value not a valid bfloat16 value.");
}
break;
@@ -805,7 +798,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<half_float::half> vals)
getTensorValueFloat(elements, tensor.data());
// Convert fp32 to fp16
- for (uint32_t i=0; i < elements; i++)
+ for (uint32_t i = 0; i < elements; i++)
{
vals[i] = half_float::half_cast<half_float::half, float>(tensor[i]);
}