diff options
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r-- | reference_model/src/tensor.cc | 63 |
1 files changed, 28 insertions, 35 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index c5f5e02..4982c99 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -30,11 +30,11 @@ TosaReference::Tensor::Tensor(const std::string tensorName_, , shape(shape_) , tensorDtype(ConvertDType(serializationDtype_)) { - producer = nullptr; - isValid = false; + producer = nullptr; + isValid = false; consumers.clear(); - isSubgraphInput = false; - isSubgraphOutput = false; + isSubgraphInput = false; + isSubgraphOutput = false; isParentGraphOutput = false; } @@ -94,16 +94,16 @@ int TosaReference::Tensor::dumpTensorParams(std::ostream& out) const int TosaReference::Tensor::readFromNpyFile(const char* filename) { - uint32_t elements = getElementCount(); + uint32_t elements = getElementCount(); double* f64databuf = nullptr; float* f32databuf = nullptr; half_float::half* f16databuf = nullptr; - int32_t* i32databuf = nullptr; - int64_t* i64databuf = nullptr; - bool* bdatabuf = nullptr; + int32_t* i32databuf = nullptr; + int64_t* i64databuf = nullptr; + bool* bdatabuf = nullptr; NumpyUtilities::NPError nperror; - TOSA_REF_TYPE dtype = getDtype(); - DType serialization_dtype = getSerializationDtype(); + TOSA_REF_TYPE dtype = getDtype(); + DType serialization_dtype = getSerializationDtype(); assert(dtype == ConvertDType(serialization_dtype)); // if dtype is FP64, serialization_dtype must be one of FP32, FP16, BF16 @@ -178,7 +178,8 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) // Convert from fp16 to fp32 so that fp16 values can be manipulated as float f32databuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(f32databuf); - for (uint32_t i=0; i < elements; i++) { + for (uint32_t i = 0; i < elements; i++) + { f32databuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]); } if (setTensorValueFloat(elements, f32databuf)) @@ -189,12 +190,9 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) } break; case TOSA_REF_TYPE_BF16: - for (uint32_t i=0; i < elements; i++) + for (uint32_t i = 0; i < elements; i++) { - ASSERT_MSG( - checkValidBFloat(f32databuf[i]), - "Input float value not a valid bfloat16 value." - ); + ASSERT_MSG(checkValidBFloat(f32databuf[i]), "Input float value not a valid bfloat16 value."); } if (setTensorValueFloat(elements, f32databuf)) { @@ -313,15 +311,15 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) int TosaReference::Tensor::writeToNpyFile(const char* filename) const { - float* f32databuf = nullptr; - double* f64databuf = nullptr; - half_float::half* f16databuf = nullptr; - int32_t* i32databuf = nullptr; - int64_t* i64databuf = nullptr; - bool* bdatabuf = nullptr; + float* f32databuf = nullptr; + double* f64databuf = nullptr; + half_float::half* f16databuf = nullptr; + int32_t* i32databuf = nullptr; + int64_t* i64databuf = nullptr; + bool* bdatabuf = nullptr; NumpyUtilities::NPError nperror = NumpyUtilities::NO_ERROR; - uint32_t elements = getElementCount(); - const TOSA_REF_TYPE dtype = getDtype(); + uint32_t elements = getElementCount(); + const TOSA_REF_TYPE dtype = getDtype(); switch (dtype) { @@ -352,7 +350,8 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const return 1; } // Convert fp32 to fp16 so that output file contains valid fp16 data - for (uint32_t i=0; i < elements; i++) { + for (uint32_t i = 0; i < elements; i++) + { f16databuf[i] = half_float::half_cast<half_float::half, float>(f32databuf[i]); } nperror = NumpyUtilities::writeToNpyFile(filename, shape, f16databuf); @@ -592,10 +591,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals) for (auto v : vals) { - ASSERT_MSG( - checkValidBFloat(v), - "Input float value not a valid bfloat16 value." - ); + ASSERT_MSG(checkValidBFloat(v), "Input float value not a valid bfloat16 value."); } setTensorValueFloat(elements, vals.data()); @@ -625,7 +621,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<half_float::half> val } // Convert from fp16 to fp32 - for (uint32_t i=0; i < elements; i++) + for (uint32_t i = 0; i < elements; i++) { tensor[i] = half_float::half_cast<float, half_float::half>(vals[i]); } @@ -772,10 +768,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<float> vals) for (auto v : vals) { - ASSERT_MSG( - checkValidBFloat(v), - "Float value not a valid bfloat16 value." - ); + ASSERT_MSG(checkValidBFloat(v), "Float value not a valid bfloat16 value."); } break; @@ -805,7 +798,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<half_float::half> vals) getTensorValueFloat(elements, tensor.data()); // Convert fp32 to fp16 - for (uint32_t i=0; i < elements; i++) + for (uint32_t i = 0; i < elements; i++) { vals[i] = half_float::half_cast<half_float::half, float>(tensor[i]); } |