diff options
author | Fabrizio Indirli <Fabrizio.Indirli@arm.com> | 2023-12-11 11:15:32 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-01-11 19:09:55 +0000 |
commit | 7203835d65afb8a06a5cf98072718d93a9b71567 (patch) | |
tree | 7e625e5489eee34e5b6ec2d24e00102799e57344 /reference_model/src/tensor.cc | |
parent | 48ed6cf8eb0d6beab4cb97f08fc41e037bfd182e (diff) | |
download | reference_model-7203835d65afb8a06a5cf98072718d93a9b71567.tar.gz |
Add support for precise mode in eager runner
Add support for Fp64 tensors in the eager runner's
helper functions, when precise mode is enabled.
Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli@arm.com>
Change-Id: Ib737c0d18fb1c7ac40ce6ea03a4fbcefae88ba5c
Diffstat (limited to 'reference_model/src/tensor.cc')
-rw-r--r-- | reference_model/src/tensor.cc | 164 |
1 files changed, 162 insertions, 2 deletions
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 645b55f..e84507b 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -580,6 +580,14 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals) uint32_t elements = getElementCount(); switch (getDtype()) { + case TOSA_REF_TYPE_FP64: + if (!g_func_config.precise_mode) + { + WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).", + EnumNameTOSAREFTYPE(getDtype())); + return -2; + } + // continue with setting float vals in the tensor case TOSA_REF_TYPE_FP16: case TOSA_REF_TYPE_FP32: if (vals.size() != elements) @@ -622,6 +630,14 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<half_float::half> val switch (getDtype()) { + case TOSA_REF_TYPE_FP64: + if (!g_func_config.precise_mode) + { + WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).", + EnumNameTOSAREFTYPE(getDtype())); + return -2; + } + // continue with setting float vals in the tensor case TOSA_REF_TYPE_FP16: if (vals.size() != elements) { @@ -953,7 +969,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<unsigned char> vals) template <class T> int TosaReference::TensorTemplate<T>::setTensorValueDouble(const size_t buflen, const double* vals) { - FATAL_ERROR("TensorTemplate<T>::setTensorValueFloat should not be called. " + FATAL_ERROR("TensorTemplate<T>::setTensorValueDouble should not be called. " "Implement template specialization version."); return 0; } @@ -1254,6 +1270,150 @@ int TosaReference::Tensor6<float>::setTensorValueFloat(const size_t bufLen, cons return 0; } +template <> +int TosaReference::Tensor0<double>::setTensorValueFloat(const size_t bufLen, const float* vals) +{ + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + (*tensor)(0) = vals[0]; + + return 0; +} + +template <> +int TosaReference::Tensor1<double>::setTensorValueFloat(const size_t bufLen, const float* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + (*tensor)(i0) = vals[idx++]; + } + + return 0; +} + +template <> +int TosaReference::Tensor2<double>::setTensorValueFloat(const size_t bufLen, const float* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + (*tensor)(i0, i1) = vals[idx++]; + } + } + + return 0; +} + +template <> +int TosaReference::Tensor3<double>::setTensorValueFloat(const size_t bufLen, const float* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + (*tensor)(i0, i1, i2) = vals[idx++]; + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor4<double>::setTensorValueFloat(const size_t bufLen, const float* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + (*tensor)(i0, i1, i2, i3) = vals[idx++]; + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor5<double>::setTensorValueFloat(const size_t bufLen, const float* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + (*tensor)(i0, i1, i2, i3, i4) = vals[idx++]; + } + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor6<double>::setTensorValueFloat(const size_t bufLen, const float* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + for (int i5 = 0; i5 < shape[5]; i5++) + { + (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++]; + } + } + } + } + } + } + return 0; +} + template <class T> int TosaReference::TensorTemplate<T>::setTensorValueInt16(const size_t bufLen, const int16_t* vals) { |