diff options
Diffstat (limited to 'reference_model/src/tensor.h')
-rw-r--r-- | reference_model/src/tensor.h | 209 |
1 files changed, 157 insertions, 52 deletions
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index d5f1de8..08ee8bf 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -17,9 +17,9 @@ #define TOSA_REFERENCE_TENSOR_H #include "array_proxy.h" +#include "dtype.h" #include "model_common.h" #include "ops/template_types.h" -#include "tosa_generated.h" #include "tosa_serialization_handler.h" #include <Eigen/CXX11/Tensor> #include <list> @@ -34,7 +34,7 @@ class GraphNode; class Tensor { public: - Tensor(std::string tensorName_, DType tensorDtype__, std::vector<int> shape_); + Tensor(const std::string tensorName_, const DType serializationDtype_, const std::vector<int> shape_); virtual ~Tensor(); @@ -212,19 +212,26 @@ public: return shape.size(); } - const DType getDtype() const + const TOSA_REF_TYPE getDtype() const { return tensorDtype; } + const DType getSerializationDtype() const + { + return serializationDtype; + } + virtual int dumpTensor(FILE* out) const = 0; virtual int dumpTensorParams(FILE* out) const; virtual int dumpTensorParams(std::ostream& out) const; + virtual int setTensorValueDouble(const size_t bufLen, const double* vals) = 0; virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0; virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0; virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0; virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0; + virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const = 0; virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0; virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0; virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0; @@ -234,12 +241,14 @@ public: virtual int writeToNpyFile(const char* filename) const; virtual int copyValueFrom(Tensor* tensor) = 0; + virtual int readfromVector(const ArrayProxy<double> vals); virtual int readfromVector(const ArrayProxy<float> vals); virtual int readfromVector(const ArrayProxy<half_float::half> vals); virtual int readfromVector(const ArrayProxy<int32_t> vals); virtual int readfromVector(const ArrayProxy<int64_t> vals); virtual int readfromVector(const ArrayProxy<unsigned char> vals); + virtual int writeToVector(ArrayProxy<double> vals); virtual int writeToVector(ArrayProxy<float> vals); virtual int writeToVector(ArrayProxy<half_float::half> vals); virtual int writeToVector(ArrayProxy<int32_t> vals); @@ -258,10 +267,11 @@ public: virtual bool is_allocated() = 0; protected: - std::string tensorName; - DType tensorDtype; + const std::string tensorName; + const DType serializationDtype; + const std::vector<int> shape; + const TOSA_REF_TYPE tensorDtype; int isValid; - std::vector<int> shape; int isSubgraphInput; int isSubgraphOutput; bool isAllocated; @@ -284,8 +294,8 @@ template <class T> class TensorTemplate : public Tensor { public: - TensorTemplate(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_) - : Tensor(tensorName_, tensorDtype_, shape_) + TensorTemplate(const std::string tensorName_, const DType dtype_, const std::vector<int> shape_) + : Tensor(tensorName_, dtype_, shape_) { tensor = nullptr; } @@ -330,10 +340,13 @@ public: virtual int dumpTensor(FILE* out) const; + virtual int setTensorValueDouble(const size_t bufLen, const double* vals); virtual int setTensorValueFloat(const size_t bufLen, const float* vals); virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals); virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals); virtual int setTensorValueBool(const size_t bufLen, const bool* vals); + + virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const; virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const; virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const; virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const; @@ -363,6 +376,21 @@ template <> int Tensor6<float>::allocate(); template <> +int Tensor0<double>::allocate(); +template <> +int Tensor1<double>::allocate(); +template <> +int Tensor2<double>::allocate(); +template <> +int Tensor3<double>::allocate(); +template <> +int Tensor4<double>::allocate(); +template <> +int Tensor5<double>::allocate(); +template <> +int Tensor6<double>::allocate(); + +template <> int Tensor0<int32_t>::allocate(); template <> int Tensor1<int32_t>::allocate(); @@ -423,6 +451,21 @@ template <> int Tensor6<float>::copyValueFrom(Tensor* src); template <> +int Tensor0<double>::copyValueFrom(Tensor* src); +template <> +int Tensor1<double>::copyValueFrom(Tensor* src); +template <> +int Tensor2<double>::copyValueFrom(Tensor* src); +template <> +int Tensor3<double>::copyValueFrom(Tensor* src); +template <> +int Tensor4<double>::copyValueFrom(Tensor* src); +template <> +int Tensor5<double>::copyValueFrom(Tensor* src); +template <> +int Tensor6<double>::copyValueFrom(Tensor* src); + +template <> int Tensor0<int32_t>::copyValueFrom(Tensor* src); template <> int Tensor1<int32_t>::copyValueFrom(Tensor* src); @@ -558,6 +601,36 @@ template <> int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const; template <> +int Tensor0<double>::setTensorValueDouble(const size_t bufLen, const double* vals); +template <> +int Tensor1<double>::setTensorValueDouble(const size_t bufLen, const double* vals); +template <> +int Tensor2<double>::setTensorValueDouble(const size_t bufLen, const double* vals); +template <> +int Tensor3<double>::setTensorValueDouble(const size_t bufLen, const double* vals); +template <> +int Tensor4<double>::setTensorValueDouble(const size_t bufLen, const double* vals); +template <> +int Tensor5<double>::setTensorValueDouble(const size_t bufLen, const double* vals); +template <> +int Tensor6<double>::setTensorValueDouble(const size_t bufLen, const double* vals); + +template <> +int Tensor0<double>::getTensorValueDouble(const size_t bufLen, double* vals) const; +template <> +int Tensor1<double>::getTensorValueDouble(const size_t bufLen, double* vals) const; +template <> +int Tensor2<double>::getTensorValueDouble(const size_t bufLen, double* vals) const; +template <> +int Tensor3<double>::getTensorValueDouble(const size_t bufLen, double* vals) const; +template <> +int Tensor4<double>::getTensorValueDouble(const size_t bufLen, double* vals) const; +template <> +int Tensor5<double>::getTensorValueDouble(const size_t bufLen, double* vals) const; +template <> +int Tensor6<double>::getTensorValueDouble(const size_t bufLen, double* vals) const; + +template <> int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals); template <> int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals); @@ -587,7 +660,6 @@ int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const; template <> int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const; -// assume we only dump float type tensor now template <> int Tensor0<float>::dumpTensor(FILE* out) const; template <> @@ -603,6 +675,20 @@ int Tensor5<float>::dumpTensor(FILE* out) const; template <> int Tensor6<float>::dumpTensor(FILE* out) const; template <> +int Tensor0<double>::dumpTensor(FILE* out) const; +template <> +int Tensor1<double>::dumpTensor(FILE* out) const; +template <> +int Tensor2<double>::dumpTensor(FILE* out) const; +template <> +int Tensor3<double>::dumpTensor(FILE* out) const; +template <> +int Tensor4<double>::dumpTensor(FILE* out) const; +template <> +int Tensor5<float>::dumpTensor(FILE* out) const; +template <> +int Tensor6<double>::dumpTensor(FILE* out) const; +template <> int Tensor0<int32_t>::dumpTensor(FILE* out) const; template <> int Tensor1<int32_t>::dumpTensor(FILE* out) const; @@ -648,100 +734,119 @@ int Tensor6<bool>::dumpTensor(FILE* out) const; class TensorFactory { public: - static Tensor* newTensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_, const uint32_t rank) + static Tensor* newTensor(std::string tensorName_, DType dtype_, std::vector<int> shape_, const uint32_t rank) { + TOSA_REF_TYPE tensorDtype_ = ConvertDType(dtype_); switch (tensorDtype_) { - case DType_FP32: - case DType_FP16: - case DType_BF16: + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + switch (rank) + { + case 0: + return new Tensor0<float>(tensorName_, dtype_, shape_); + case 1: + return new Tensor1<float>(tensorName_, dtype_, shape_); + case 2: + return new Tensor2<float>(tensorName_, dtype_, shape_); + case 3: + return new Tensor3<float>(tensorName_, dtype_, shape_); + case 4: + return new Tensor4<float>(tensorName_, dtype_, shape_); + case 5: + return new Tensor5<float>(tensorName_, dtype_, shape_); + case 6: + return new Tensor6<float>(tensorName_, dtype_, shape_); + } + break; + case TOSA_REF_TYPE_INT32: + case TOSA_REF_TYPE_UINT8: + case TOSA_REF_TYPE_INT4: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_UINT16: switch (rank) { case 0: - return new Tensor0<float>(tensorName_, tensorDtype_, shape_); + return new Tensor0<int32_t>(tensorName_, dtype_, shape_); case 1: - return new Tensor1<float>(tensorName_, tensorDtype_, shape_); + return new Tensor1<int32_t>(tensorName_, dtype_, shape_); case 2: - return new Tensor2<float>(tensorName_, tensorDtype_, shape_); + return new Tensor2<int32_t>(tensorName_, dtype_, shape_); case 3: - return new Tensor3<float>(tensorName_, tensorDtype_, shape_); + return new Tensor3<int32_t>(tensorName_, dtype_, shape_); case 4: - return new Tensor4<float>(tensorName_, tensorDtype_, shape_); + return new Tensor4<int32_t>(tensorName_, dtype_, shape_); case 5: - return new Tensor5<float>(tensorName_, tensorDtype_, shape_); + return new Tensor5<int32_t>(tensorName_, dtype_, shape_); case 6: - return new Tensor6<float>(tensorName_, tensorDtype_, shape_); + return new Tensor6<int32_t>(tensorName_, dtype_, shape_); } break; - case DType_INT32: - case DType_UINT8: - case DType_INT4: - case DType_INT8: - case DType_INT16: - case DType_UINT16: + case TOSA_REF_TYPE_INT48: switch (rank) { case 0: - return new Tensor0<int32_t>(tensorName_, tensorDtype_, shape_); + return new Tensor0<int64_t>(tensorName_, dtype_, shape_); case 1: - return new Tensor1<int32_t>(tensorName_, tensorDtype_, shape_); + return new Tensor1<int64_t>(tensorName_, dtype_, shape_); case 2: - return new Tensor2<int32_t>(tensorName_, tensorDtype_, shape_); + return new Tensor2<int64_t>(tensorName_, dtype_, shape_); case 3: - return new Tensor3<int32_t>(tensorName_, tensorDtype_, shape_); + return new Tensor3<int64_t>(tensorName_, dtype_, shape_); case 4: - return new Tensor4<int32_t>(tensorName_, tensorDtype_, shape_); + return new Tensor4<int64_t>(tensorName_, dtype_, shape_); case 5: - return new Tensor5<int32_t>(tensorName_, tensorDtype_, shape_); + return new Tensor5<int64_t>(tensorName_, dtype_, shape_); case 6: - return new Tensor6<int32_t>(tensorName_, tensorDtype_, shape_); + return new Tensor6<int64_t>(tensorName_, dtype_, shape_); } break; - case DType_INT48: + case TOSA_REF_TYPE_BOOL: switch (rank) { case 0: - return new Tensor0<int64_t>(tensorName_, tensorDtype_, shape_); + return new Tensor0<bool>(tensorName_, dtype_, shape_); case 1: - return new Tensor1<int64_t>(tensorName_, tensorDtype_, shape_); + return new Tensor1<bool>(tensorName_, dtype_, shape_); case 2: - return new Tensor2<int64_t>(tensorName_, tensorDtype_, shape_); + return new Tensor2<bool>(tensorName_, dtype_, shape_); case 3: - return new Tensor3<int64_t>(tensorName_, tensorDtype_, shape_); + return new Tensor3<bool>(tensorName_, dtype_, shape_); case 4: - return new Tensor4<int64_t>(tensorName_, tensorDtype_, shape_); + return new Tensor4<bool>(tensorName_, dtype_, shape_); case 5: - return new Tensor5<int64_t>(tensorName_, tensorDtype_, shape_); + return new Tensor5<bool>(tensorName_, dtype_, shape_); case 6: - return new Tensor6<int64_t>(tensorName_, tensorDtype_, shape_); + return new Tensor6<bool>(tensorName_, dtype_, shape_); } break; - case DType_BOOL: + case TOSA_REF_TYPE_FP64: switch (rank) { case 0: - return new Tensor0<bool>(tensorName_, tensorDtype_, shape_); + return new Tensor0<double>(tensorName_, dtype_, shape_); case 1: - return new Tensor1<bool>(tensorName_, tensorDtype_, shape_); + return new Tensor1<double>(tensorName_, dtype_, shape_); case 2: - return new Tensor2<bool>(tensorName_, tensorDtype_, shape_); + return new Tensor2<double>(tensorName_, dtype_, shape_); case 3: - return new Tensor3<bool>(tensorName_, tensorDtype_, shape_); + return new Tensor3<double>(tensorName_, dtype_, shape_); case 4: - return new Tensor4<bool>(tensorName_, tensorDtype_, shape_); + return new Tensor4<double>(tensorName_, dtype_, shape_); case 5: - return new Tensor5<bool>(tensorName_, tensorDtype_, shape_); + return new Tensor5<double>(tensorName_, dtype_, shape_); case 6: - return new Tensor6<bool>(tensorName_, tensorDtype_, shape_); + return new Tensor6<double>(tensorName_, dtype_, shape_); } break; - default: + case TOSA_REF_TYPE_UNKNOWN: + assert(0); // tensorDtype_ is uninitialized break; } return nullptr; } - - static Tensor* newTensor(DType type, const std::vector<int> shape); }; }; // namespace TosaReference |