// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef TOSA_REFERENCE_TENSOR_H #define TOSA_REFERENCE_TENSOR_H #include "array_proxy.h" #include "dtype.h" #include "model_common.h" #include "ops/template_types.h" #include "tosa_serialization_handler.h" #include #include #include using namespace tosa; namespace TosaReference { class GraphNode; class Tensor { public: Tensor(const std::string tensorName_, const DType serializationDtype_, const std::vector shape_); virtual ~Tensor(); int setIsSubgraphInput(); int setIsSubgraphOutput(); int setIsParentGraphOutput(); bool getIsParentGraphOutput() const { return isParentGraphOutput; } int setIsVariable(); bool getIsSubgraphInput() const { return isSubgraphInput; } bool getIsSubgraphOutput() const { return isSubgraphOutput; } bool getIsVariable() const { return isVariable; } int setProducer(GraphNode* node); int addConsumer(GraphNode* node); int setIsValid() { isValid = 1; return 0; } int clearIsValid() { isValid = 0; return 0; } int getIsValid() const { return isValid; } GraphNode* getProducer() { return producer; } std::vector& getConsumers() { return consumers; } const std::string& getName() const { return tensorName; } const std::vector& getShape() const { return shape; } void setDimSize(size_t dim, uint32_t new_size) { this->shape[dim] = new_size; return; } void setShapeValue(std::vector& shapeValue) { for (auto dim : shapeValue) { this->shapeValue.push_back(dim); } return; } int 
getShapeValueSize() const { return this->shapeValue.size(); } std::string getShapeValueAsString() const { std::string shape_str("["); for (auto& dim : shapeValue) { shape_str += (std::to_string(dim) + ", "); } shape_str.append("]"); return shape_str; } std::string getShapeAsString() const { std::string shape_str("["); for (auto& dim : shape) { shape_str += (std::to_string(dim) + ", "); } shape_str.append("]"); return shape_str; } const uint32_t getElementCount() const { uint32_t elements = 1; for (size_t i = 0; i < shape.size(); i++) elements *= shape[i]; return elements; } // Comparison of rank and type with other tensors const int matchRank(const Tensor& ref) const { return (ref.shape.size() == shape.size()) ? 0 : 1; } const int matchType(const Tensor& ref) const { return (ref.tensorDtype == tensorDtype) ? 0 : 1; } const int matchRankType(const Tensor& ref) const { return (matchType(ref) || matchRank(ref)); } const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const { if (matchRankType(ref)) return 1; for (size_t i = 0; i < shape.size(); i++) { if (shape[i] != ref.shape[i]) { if (!broadcastOk || // For broadcasts, the order of *this and ref matters. // *this should be the source tensor. // ref should be the target tensor. In most of the case, ref is expected to be the output tensor. // this->shape must have size 1 if they don't match (broadcastOk && (shape[i] != 1))) { return 1; } } } return 0; } const int matchRankShape(const Tensor& ref, const bool broadcastOk = false) const { if (matchRank(ref)) return 1; for (size_t i = 0; i < shape.size(); i++) { if (shape[i] != ref.shape[i]) { if (!broadcastOk || // For broadcasts, the order of *this and ref matters. // *this should be the source tensor. // ref should be the target tensor. In most of the case, ref is expected to be the output tensor. 
// this->shape must have size 1 if they don't match (broadcastOk && (shape[i] != 1))) { return 1; } } } return 0; } // Sometimes we might want to match several semi-compatible types, // so just check rank and size here const int matchRankSize(const Tensor& ref) const { if (matchRank(ref)) return 1; for (size_t i = 0; i < shape.size(); i++) { if (shape[i] != ref.shape[i]) return 1; } return 0; } // Unary check to make sure rank matches const int checkRequiredRank(const int minRank) const { return (shape.size() >= (size_t)minRank) ? 0 : 1; } const int checkRequiredRank(const int minRank, const int maxRank) const { return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 0 : 1; } const int getRank() const { return shape.size(); } const TOSA_REF_TYPE getDtype() const { return tensorDtype; } const DType getSerializationDtype() const { return serializationDtype; } virtual int dumpTensor(FILE* out) const = 0; virtual int dumpTensorParams(FILE* out) const; virtual int dumpTensorParams(std::ostream& out) const; virtual int setTensorValueDouble(const size_t bufLen, const double* vals) = 0; virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0; virtual int setTensorValueUInt8(const size_t bufLen, const uint8_t* vals) = 0; virtual int setTensorValueInt8(const size_t bufLen, const int8_t* vals) = 0; virtual int setTensorValueUInt16(const size_t bufLen, const uint16_t* vals) = 0; virtual int setTensorValueInt16(const size_t bufLen, const int16_t* vals) = 0; virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0; virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0; virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0; virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const = 0; virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0; virtual int getTensorValueUInt8(const size_t bufLen, uint8_t* ibuf) const = 0; virtual int 
getTensorValueInt8(const size_t bufLen, int8_t* ibuf) const = 0; virtual int getTensorValueUInt16(const size_t bufLen, uint16_t* ibuf) const = 0; virtual int getTensorValueInt16(const size_t bufLen, int16_t* ibuf) const = 0; virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0; virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0; virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0; virtual int readFromNpyFile(const char* filename); virtual int writeToNpyFile(const char* filename) const; virtual int copyValueFrom(Tensor* tensor) = 0; virtual int readfromVector(const ArrayProxy vals); virtual int readfromVector(const ArrayProxy vals); virtual int readfromVector(const ArrayProxy vals); virtual int readfromVector(const ArrayProxy vals); virtual int readfromVector(const ArrayProxy vals); virtual int readfromVector(const ArrayProxy vals); virtual int readfromVector(const ArrayProxy vals); virtual int readfromVector(const ArrayProxy vals); virtual int readfromVector(const ArrayProxy vals); virtual int writeToVector(ArrayProxy vals); virtual int writeToVector(ArrayProxy vals); virtual int writeToVector(ArrayProxy vals); virtual int writeToVector(ArrayProxy vals); virtual int writeToVector(ArrayProxy vals); virtual int writeToVector(ArrayProxy vals); virtual int writeToVector(ArrayProxy vals); virtual int writeToVector(ArrayProxy vals); virtual int writeToVector(ArrayProxy vals); const char* bool_to_str(bool in) const { static const char* true_str = "true"; static const char* false_str = "false"; return in ? 
true_str : false_str; } virtual int allocate() = 0; virtual int deallocate() = 0; virtual bool is_allocated() const = 0; protected: const std::string tensorName; const DType serializationDtype; std::vector shape; std::vector shapeValue; const TOSA_REF_TYPE tensorDtype; bool isValid; bool isSubgraphInput; bool isSubgraphOutput; bool isVariable; bool isAllocated; bool isParentGraphOutput; GraphNode* producer; std::vector consumers; // Note: the Eigen::Tensor is not declared in Tensor // Instead, the TensorTemplate class keeps the templated tensor // declaration so that the graph manipulation tools are isolated // from the templated tensor type. // // Operators need to be aware of the TensorTemplate> type // so that they can operate on the right types. }; template class TensorTemplate : public Tensor { public: TensorTemplate(const std::string tensorName_, const DType dtype_, const std::vector shape_) : Tensor(tensorName_, dtype_, shape_) { tensor = nullptr; } virtual ~TensorTemplate() { deallocate(); } virtual int allocate() { tensor = new T(); if (tensor) return 0; else return 1; } virtual int deallocate() { if (tensor) { DEBUG_INFO(GT, "Deallocating tensor %s", tensorName.c_str()); delete tensor; } tensor = nullptr; return 0; } virtual bool is_allocated() const { if (tensor) { return true; } return false; } T& getTensor() { return *tensor; } virtual int dumpTensor(FILE* out) const; virtual int setTensorValueDouble(const size_t bufLen, const double* vals); virtual int setTensorValueFloat(const size_t bufLen, const float* vals); virtual int setTensorValueUInt8(const size_t bufLen, const uint8_t* vals); virtual int setTensorValueInt8(const size_t bufLen, const int8_t* vals); virtual int setTensorValueUInt16(const size_t bufLen, const uint16_t* vals); virtual int setTensorValueInt16(const size_t bufLen, const int16_t* vals); virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals); virtual int setTensorValueInt64(const size_t bufLen, const int64_t* 
vals); virtual int setTensorValueBool(const size_t bufLen, const bool* vals); virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const; virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const; virtual int getTensorValueUInt8(const size_t bufLen, uint8_t* ibuf) const; virtual int getTensorValueInt8(const size_t bufLen, int8_t* ibuf) const; virtual int getTensorValueUInt16(const size_t bufLen, uint16_t* ibuf) const; virtual int getTensorValueInt16(const size_t bufLen, int16_t* ibuf) const; virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const; virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const; virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const; virtual int copyValueFrom(Tensor* tensor); protected: T* tensor; }; // allocate() template specializations to allocate the different tensor sizes // Let the compiler know here before the factory uses them, but define them in the .cc file. template <> int Tensor0::allocate(); template <> int Tensor1::allocate(); template <> int Tensor2::allocate(); template <> int Tensor3::allocate(); template <> int Tensor4::allocate(); template <> int Tensor5::allocate(); template <> int Tensor6::allocate(); template <> int Tensor0::allocate(); template <> int Tensor1::allocate(); template <> int Tensor2::allocate(); template <> int Tensor3::allocate(); template <> int Tensor4::allocate(); template <> int Tensor5::allocate(); template <> int Tensor6::allocate(); template <> int Tensor0::allocate(); template <> int Tensor1::allocate(); template <> int Tensor2::allocate(); template <> int Tensor3::allocate(); template <> int Tensor4::allocate(); template <> int Tensor5::allocate(); template <> int Tensor6::allocate(); template <> int Tensor0::allocate(); template <> int Tensor1::allocate(); template <> int Tensor2::allocate(); template <> int Tensor3::allocate(); template <> int Tensor4::allocate(); template <> int Tensor5::allocate(); template <> 
int Tensor6::allocate(); template <> int Tensor0::allocate(); template <> int Tensor1::allocate(); template <> int Tensor2::allocate(); template <> int Tensor3::allocate(); template <> int Tensor4::allocate(); template <> int Tensor5::allocate(); template <> int Tensor6::allocate(); template <> int Tensor0::copyValueFrom(Tensor* src); template <> int Tensor1::copyValueFrom(Tensor* src); template <> int Tensor2::copyValueFrom(Tensor* src); template <> int Tensor3::copyValueFrom(Tensor* src); template <> int Tensor4::copyValueFrom(Tensor* src); template <> int Tensor5::copyValueFrom(Tensor* src); template <> int Tensor6::copyValueFrom(Tensor* src); template <> int Tensor0::copyValueFrom(Tensor* src); template <> int Tensor1::copyValueFrom(Tensor* src); template <> int Tensor2::copyValueFrom(Tensor* src); template <> int Tensor3::copyValueFrom(Tensor* src); template <> int Tensor4::copyValueFrom(Tensor* src); template <> int Tensor5::copyValueFrom(Tensor* src); template <> int Tensor6::copyValueFrom(Tensor* src); template <> int Tensor0::copyValueFrom(Tensor* src); template <> int Tensor1::copyValueFrom(Tensor* src); template <> int Tensor2::copyValueFrom(Tensor* src); template <> int Tensor3::copyValueFrom(Tensor* src); template <> int Tensor4::copyValueFrom(Tensor* src); template <> int Tensor5::copyValueFrom(Tensor* src); template <> int Tensor6::copyValueFrom(Tensor* src); template <> int Tensor0::copyValueFrom(Tensor* src); template <> int Tensor1::copyValueFrom(Tensor* src); template <> int Tensor2::copyValueFrom(Tensor* src); template <> int Tensor3::copyValueFrom(Tensor* src); template <> int Tensor4::copyValueFrom(Tensor* src); template <> int Tensor5::copyValueFrom(Tensor* src); template <> int Tensor6::copyValueFrom(Tensor* src); template <> int Tensor0::copyValueFrom(Tensor* src); template <> int Tensor1::copyValueFrom(Tensor* src); template <> int Tensor2::copyValueFrom(Tensor* src); template <> int Tensor3::copyValueFrom(Tensor* src); template <> int 
Tensor4::copyValueFrom(Tensor* src); template <> int Tensor5::copyValueFrom(Tensor* src); template <> int Tensor6::copyValueFrom(Tensor* src); template <> int Tensor0::setTensorValueUInt8(const size_t bufLen, const uint8_t* vals); template <> int Tensor1::setTensorValueUInt8(const size_t bufLen, const uint8_t* vals); template <> int Tensor2::setTensorValueUInt8(const size_t bufLen, const uint8_t* vals); template <> int Tensor3::setTensorValueUInt8(const size_t bufLen, const uint8_t* vals); template <> int Tensor4::setTensorValueUInt8(const size_t bufLen, const uint8_t* vals); template <> int Tensor5::setTensorValueUInt8(const size_t bufLen, const uint8_t* vals); template <> int Tensor6::setTensorValueUInt8(const size_t bufLen, const uint8_t* vals); template <> int Tensor0::setTensorValueInt8(const size_t bufLen, const int8_t* vals); template <> int Tensor1::setTensorValueInt8(const size_t bufLen, const int8_t* vals); template <> int Tensor2::setTensorValueInt8(const size_t bufLen, const int8_t* vals); template <> int Tensor3::setTensorValueInt8(const size_t bufLen, const int8_t* vals); template <> int Tensor4::setTensorValueInt8(const size_t bufLen, const int8_t* vals); template <> int Tensor5::setTensorValueInt8(const size_t bufLen, const int8_t* vals); template <> int Tensor6::setTensorValueInt8(const size_t bufLen, const int8_t* vals); template <> int Tensor0::setTensorValueUInt16(const size_t bufLen, const uint16_t* vals); template <> int Tensor1::setTensorValueUInt16(const size_t bufLen, const uint16_t* vals); template <> int Tensor2::setTensorValueUInt16(const size_t bufLen, const uint16_t* vals); template <> int Tensor3::setTensorValueUInt16(const size_t bufLen, const uint16_t* vals); template <> int Tensor4::setTensorValueUInt16(const size_t bufLen, const uint16_t* vals); template <> int Tensor5::setTensorValueUInt16(const size_t bufLen, const uint16_t* vals); template <> int Tensor6::setTensorValueUInt16(const size_t bufLen, const uint16_t* vals); template 
<> int Tensor0::setTensorValueInt16(const size_t bufLen, const int16_t* vals); template <> int Tensor1::setTensorValueInt16(const size_t bufLen, const int16_t* vals); template <> int Tensor2::setTensorValueInt16(const size_t bufLen, const int16_t* vals); template <> int Tensor3::setTensorValueInt16(const size_t bufLen, const int16_t* vals); template <> int Tensor4::setTensorValueInt16(const size_t bufLen, const int16_t* vals); template <> int Tensor5::setTensorValueInt16(const size_t bufLen, const int16_t* vals); template <> int Tensor6::setTensorValueInt16(const size_t bufLen, const int16_t* vals); template <> int Tensor0::setTensorValueInt32(const size_t bufLen, const int32_t* vals); template <> int Tensor1::setTensorValueInt32(const size_t bufLen, const int32_t* vals); template <> int Tensor2::setTensorValueInt32(const size_t bufLen, const int32_t* vals); template <> int Tensor3::setTensorValueInt32(const size_t bufLen, const int32_t* vals); template <> int Tensor4::setTensorValueInt32(const size_t bufLen, const int32_t* vals); template <> int Tensor5::setTensorValueInt32(const size_t bufLen, const int32_t* vals); template <> int Tensor6::setTensorValueInt32(const size_t bufLen, const int32_t* vals); template <> int Tensor0::getTensorValueUInt8(const size_t bufLen, uint8_t* vals) const; template <> int Tensor1::getTensorValueUInt8(const size_t bufLen, uint8_t* vals) const; template <> int Tensor2::getTensorValueUInt8(const size_t bufLen, uint8_t* vals) const; template <> int Tensor3::getTensorValueUInt8(const size_t bufLen, uint8_t* vals) const; template <> int Tensor4::getTensorValueUInt8(const size_t bufLen, uint8_t* vals) const; template <> int Tensor5::getTensorValueUInt8(const size_t bufLen, uint8_t* vals) const; template <> int Tensor6::getTensorValueUInt8(const size_t bufLen, uint8_t* vals) const; template <> int Tensor0::getTensorValueInt8(const size_t bufLen, int8_t* vals) const; template <> int Tensor1::getTensorValueInt8(const size_t bufLen, int8_t* 
vals) const; template <> int Tensor2::getTensorValueInt8(const size_t bufLen, int8_t* vals) const; template <> int Tensor3::getTensorValueInt8(const size_t bufLen, int8_t* vals) const; template <> int Tensor4::getTensorValueInt8(const size_t bufLen, int8_t* vals) const; template <> int Tensor5::getTensorValueInt8(const size_t bufLen, int8_t* vals) const; template <> int Tensor6::getTensorValueInt8(const size_t bufLen, int8_t* vals) const; template <> int Tensor0::getTensorValueUInt16(const size_t bufLen, uint16_t* vals) const; template <> int Tensor1::getTensorValueUInt16(const size_t bufLen, uint16_t* vals) const; template <> int Tensor2::getTensorValueUInt16(const size_t bufLen, uint16_t* vals) const; template <> int Tensor3::getTensorValueUInt16(const size_t bufLen, uint16_t* vals) const; template <> int Tensor4::getTensorValueUInt16(const size_t bufLen, uint16_t* vals) const; template <> int Tensor5::getTensorValueUInt16(const size_t bufLen, uint16_t* vals) const; template <> int Tensor6::getTensorValueUInt16(const size_t bufLen, uint16_t* vals) const; template <> int Tensor0::getTensorValueInt16(const size_t bufLen, int16_t* vals) const; template <> int Tensor1::getTensorValueInt16(const size_t bufLen, int16_t* vals) const; template <> int Tensor2::getTensorValueInt16(const size_t bufLen, int16_t* vals) const; template <> int Tensor3::getTensorValueInt16(const size_t bufLen, int16_t* vals) const; template <> int Tensor4::getTensorValueInt16(const size_t bufLen, int16_t* vals) const; template <> int Tensor5::getTensorValueInt16(const size_t bufLen, int16_t* vals) const; template <> int Tensor6::getTensorValueInt16(const size_t bufLen, int16_t* vals) const; template <> int Tensor0::getTensorValueInt32(const size_t bufLen, int32_t* vals) const; template <> int Tensor1::getTensorValueInt32(const size_t bufLen, int32_t* vals) const; template <> int Tensor2::getTensorValueInt32(const size_t bufLen, int32_t* vals) const; template <> int 
Tensor3::getTensorValueInt32(const size_t bufLen, int32_t* vals) const; template <> int Tensor4::getTensorValueInt32(const size_t bufLen, int32_t* vals) const; template <> int Tensor5::getTensorValueInt32(const size_t bufLen, int32_t* vals) const; template <> int Tensor6::getTensorValueInt32(const size_t bufLen, int32_t* vals) const; template <> int Tensor0::setTensorValueInt64(const size_t bufLen, const int64_t* vals); template <> int Tensor1::setTensorValueInt64(const size_t bufLen, const int64_t* vals); template <> int Tensor2::setTensorValueInt64(const size_t bufLen, const int64_t* vals); template <> int Tensor3::setTensorValueInt64(const size_t bufLen, const int64_t* vals); template <> int Tensor4::setTensorValueInt64(const size_t bufLen, const int64_t* vals); template <> int Tensor5::setTensorValueInt64(const size_t bufLen, const int64_t* vals); template <> int Tensor6::setTensorValueInt64(const size_t bufLen, const int64_t* vals); template <> int Tensor0::getTensorValueInt64(const size_t bufLen, int64_t* vals) const; template <> int Tensor1::getTensorValueInt64(const size_t bufLen, int64_t* vals) const; template <> int Tensor2::getTensorValueInt64(const size_t bufLen, int64_t* vals) const; template <> int Tensor3::getTensorValueInt64(const size_t bufLen, int64_t* vals) const; template <> int Tensor4::getTensorValueInt64(const size_t bufLen, int64_t* vals) const; template <> int Tensor5::getTensorValueInt64(const size_t bufLen, int64_t* vals) const; template <> int Tensor6::getTensorValueInt64(const size_t bufLen, int64_t* vals) const; template <> int Tensor0::setTensorValueFloat(const size_t bufLen, const float* vals); template <> int Tensor1::setTensorValueFloat(const size_t bufLen, const float* vals); template <> int Tensor2::setTensorValueFloat(const size_t bufLen, const float* vals); template <> int Tensor3::setTensorValueFloat(const size_t bufLen, const float* vals); template <> int Tensor4::setTensorValueFloat(const size_t bufLen, const float* vals); 
template <> int Tensor5::setTensorValueFloat(const size_t bufLen, const float* vals); template <> int Tensor6::setTensorValueFloat(const size_t bufLen, const float* vals); template <> int Tensor0::getTensorValueFloat(const size_t bufLen, float* vals) const; template <> int Tensor1::getTensorValueFloat(const size_t bufLen, float* vals) const; template <> int Tensor2::getTensorValueFloat(const size_t bufLen, float* vals) const; template <> int Tensor3::getTensorValueFloat(const size_t bufLen, float* vals) const; template <> int Tensor4::getTensorValueFloat(const size_t bufLen, float* vals) const; template <> int Tensor5::getTensorValueFloat(const size_t bufLen, float* vals) const; template <> int Tensor6::getTensorValueFloat(const size_t bufLen, float* vals) const; template <> int Tensor0::setTensorValueDouble(const size_t bufLen, const double* vals); template <> int Tensor1::setTensorValueDouble(const size_t bufLen, const double* vals); template <> int Tensor2::setTensorValueDouble(const size_t bufLen, const double* vals); template <> int Tensor3::setTensorValueDouble(const size_t bufLen, const double* vals); template <> int Tensor4::setTensorValueDouble(const size_t bufLen, const double* vals); template <> int Tensor5::setTensorValueDouble(const size_t bufLen, const double* vals); template <> int Tensor6::setTensorValueDouble(const size_t bufLen, const double* vals); template <> int Tensor0::getTensorValueDouble(const size_t bufLen, double* vals) const; template <> int Tensor1::getTensorValueDouble(const size_t bufLen, double* vals) const; template <> int Tensor2::getTensorValueDouble(const size_t bufLen, double* vals) const; template <> int Tensor3::getTensorValueDouble(const size_t bufLen, double* vals) const; template <> int Tensor4::getTensorValueDouble(const size_t bufLen, double* vals) const; template <> int Tensor5::getTensorValueDouble(const size_t bufLen, double* vals) const; template <> int Tensor6::getTensorValueDouble(const size_t bufLen, double* vals) 
const; template <> int Tensor0::setTensorValueBool(const size_t bufLen, const bool* vals); template <> int Tensor1::setTensorValueBool(const size_t bufLen, const bool* vals); template <> int Tensor2::setTensorValueBool(const size_t bufLen, const bool* vals); template <> int Tensor3::setTensorValueBool(const size_t bufLen, const bool* vals); template <> int Tensor4::setTensorValueBool(const size_t bufLen, const bool* vals); template <> int Tensor5::setTensorValueBool(const size_t bufLen, const bool* vals); template <> int Tensor6::setTensorValueBool(const size_t bufLen, const bool* vals); template <> int Tensor0::getTensorValueBool(const size_t bufLen, bool* vals) const; template <> int Tensor1::getTensorValueBool(const size_t bufLen, bool* vals) const; template <> int Tensor2::getTensorValueBool(const size_t bufLen, bool* vals) const; template <> int Tensor3::getTensorValueBool(const size_t bufLen, bool* vals) const; template <> int Tensor4::getTensorValueBool(const size_t bufLen, bool* vals) const; template <> int Tensor5::getTensorValueBool(const size_t bufLen, bool* vals) const; template <> int Tensor6::getTensorValueBool(const size_t bufLen, bool* vals) const; template <> int Tensor0::dumpTensor(FILE* out) const; template <> int Tensor1::dumpTensor(FILE* out) const; template <> int Tensor2::dumpTensor(FILE* out) const; template <> int Tensor3::dumpTensor(FILE* out) const; template <> int Tensor4::dumpTensor(FILE* out) const; template <> int Tensor5::dumpTensor(FILE* out) const; template <> int Tensor6::dumpTensor(FILE* out) const; template <> int Tensor0::dumpTensor(FILE* out) const; template <> int Tensor1::dumpTensor(FILE* out) const; template <> int Tensor2::dumpTensor(FILE* out) const; template <> int Tensor3::dumpTensor(FILE* out) const; template <> int Tensor4::dumpTensor(FILE* out) const; template <> int Tensor5::dumpTensor(FILE* out) const; template <> int Tensor6::dumpTensor(FILE* out) const; template <> int Tensor0::dumpTensor(FILE* out) const; 
// dumpTensor() specializations, continued.
template <> int Tensor1::dumpTensor(FILE* out) const;
template <> int Tensor2::dumpTensor(FILE* out) const;
template <> int Tensor3::dumpTensor(FILE* out) const;
template <> int Tensor4::dumpTensor(FILE* out) const;
template <> int Tensor5::dumpTensor(FILE* out) const;
template <> int Tensor6::dumpTensor(FILE* out) const;
template <> int Tensor0::dumpTensor(FILE* out) const;
template <> int Tensor1::dumpTensor(FILE* out) const;
template <> int Tensor2::dumpTensor(FILE* out) const;
template <> int Tensor3::dumpTensor(FILE* out) const;
template <> int Tensor4::dumpTensor(FILE* out) const;
template <> int Tensor5::dumpTensor(FILE* out) const;
template <> int Tensor6::dumpTensor(FILE* out) const;
template <> int Tensor0::dumpTensor(FILE* out) const;
template <> int Tensor1::dumpTensor(FILE* out) const;
template <> int Tensor2::dumpTensor(FILE* out) const;
template <> int Tensor3::dumpTensor(FILE* out) const;
template <> int Tensor4::dumpTensor(FILE* out) const;
template <> int Tensor5::dumpTensor(FILE* out) const;
template <> int Tensor6::dumpTensor(FILE* out) const;

// Creates the concrete Tensor subclass for a serialized dtype and rank.
class TensorFactory
{
public:
    // Build a new tensor named `tensorName_` with serialization dtype `dtype_`
    // and shape `shape_`.
    //
    // `dtype_` is first mapped to a TOSA_REF_TYPE with ConvertDType(). The
    // storage class is then chosen from that type group and from `rank`
    // (0-6; SHAPE tensors accept only 0 or 1). The original `dtype_` is
    // passed through to the constructor unchanged.
    //
    // Returns a tensor allocated with `new`. No reference is kept here, so
    // the caller takes ownership. Returns nullptr when the rank is outside the
    // range supported for the type group, or when the converted type has no
    // case below. TOSA_REF_TYPE_UNKNOWN and an out-of-range SHAPE rank also
    // trip assert(0) in builds where assertions are enabled.
    //
    // NOTE(review): in this copy the template arguments of each TensorN
    // (i.e. the element type used for each dtype group) are missing, so the
    // per-group element types cannot be read from here.
    static Tensor* newTensor(std::string tensorName_, DType dtype_, std::vector shape_, const uint32_t rank)
    {
        TOSA_REF_TYPE tensorDtype_ = ConvertDType(dtype_);
        switch (tensorDtype_)
        {
            // Floating-point formats up to 32 bits share one storage group.
            case TOSA_REF_TYPE_FP32:
            case TOSA_REF_TYPE_FP16:
            case TOSA_REF_TYPE_BF16:
            case TOSA_REF_TYPE_FP8E4M3:
            case TOSA_REF_TYPE_FP8E5M2:
                switch (rank)
                {
                    case 0:
                        return new Tensor0(tensorName_, dtype_, shape_);
                    case 1:
                        return new Tensor1(tensorName_, dtype_, shape_);
                    case 2:
                        return new Tensor2(tensorName_, dtype_, shape_);
                    case 3:
                        return new Tensor3(tensorName_, dtype_, shape_);
                    case 4:
                        return new Tensor4(tensorName_, dtype_, shape_);
                    case 5:
                        return new Tensor5(tensorName_, dtype_, shape_);
                    case 6:
                        return new Tensor6(tensorName_, dtype_, shape_);
                }
                break;
            // Integer formats up to 32 bits share one storage group.
            case TOSA_REF_TYPE_INT32:
            case TOSA_REF_TYPE_UINT8:
            case TOSA_REF_TYPE_INT4:
            case TOSA_REF_TYPE_INT8:
            case TOSA_REF_TYPE_INT16:
            case TOSA_REF_TYPE_UINT16:
                switch (rank)
                {
                    case 0:
                        return new Tensor0(tensorName_, dtype_, shape_);
                    case 1:
                        return new Tensor1(tensorName_, dtype_, shape_);
                    case 2:
                        return new Tensor2(tensorName_, dtype_, shape_);
                    case 3:
                        return new Tensor3(tensorName_, dtype_, shape_);
                    case 4:
                        return new Tensor4(tensorName_, dtype_, shape_);
                    case 5:
                        return new Tensor5(tensorName_, dtype_, shape_);
                    case 6:
                        return new Tensor6(tensorName_, dtype_, shape_);
                }
                break;
            case TOSA_REF_TYPE_INT48:
                switch (rank)
                {
                    case 0:
                        return new Tensor0(tensorName_, dtype_, shape_);
                    case 1:
                        return new Tensor1(tensorName_, dtype_, shape_);
                    case 2:
                        return new Tensor2(tensorName_, dtype_, shape_);
                    case 3:
                        return new Tensor3(tensorName_, dtype_, shape_);
                    case 4:
                        return new Tensor4(tensorName_, dtype_, shape_);
                    case 5:
                        return new Tensor5(tensorName_, dtype_, shape_);
                    case 6:
                        return new Tensor6(tensorName_, dtype_, shape_);
                }
                break;
            case TOSA_REF_TYPE_SHAPE:
                switch (rank)
                {
                    case 0:
                        return new Tensor0(tensorName_, dtype_, shape_);
                    case 1:
                        return new Tensor1(tensorName_, dtype_, shape_);
                    default:
                        // With assertions disabled, control falls out of
                        // both switches to `return nullptr`.
                        assert(0);    // shape tensors must have rank of 0 or 1
                }
                break;
            case TOSA_REF_TYPE_BOOL:
                switch (rank)
                {
                    case 0:
                        return new Tensor0(tensorName_, dtype_, shape_);
                    case 1:
                        return new Tensor1(tensorName_, dtype_, shape_);
                    case 2:
                        return new Tensor2(tensorName_, dtype_, shape_);
                    case 3:
                        return new Tensor3(tensorName_, dtype_, shape_);
                    case 4:
                        return new Tensor4(tensorName_, dtype_, shape_);
                    case 5:
                        return new Tensor5(tensorName_, dtype_, shape_);
                    case 6:
                        return new Tensor6(tensorName_, dtype_, shape_);
                }
                break;
            case TOSA_REF_TYPE_FP64:
                switch (rank)
                {
                    case 0:
                        return new Tensor0(tensorName_, dtype_, shape_);
                    case 1:
                        return new Tensor1(tensorName_, dtype_, shape_);
                    case 2:
                        return new Tensor2(tensorName_, dtype_, shape_);
                    case 3:
                        return new Tensor3(tensorName_, dtype_, shape_);
                    case 4:
                        return new Tensor4(tensorName_, dtype_, shape_);
                    case 5:
                        return new Tensor5(tensorName_, dtype_, shape_);
                    case 6:
                        return new Tensor6(tensorName_, dtype_, shape_);
                }
                break;
            case TOSA_REF_TYPE_UNKNOWN:
                assert(0);    // tensorDtype_ is uninitialized
                break;
        }
        // Unsupported type/rank combination.
        return nullptr;
    }
};
};    // namespace TosaReference

#endif