Diffstat (limited to 'reference_model/src/tensor.h')
-rw-r--r--  reference_model/src/tensor.h  815
1 file changed, 815 insertions, 0 deletions
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
new file mode 100644
index 0000000..2fd37cd
--- /dev/null
+++ b/reference_model/src/tensor.h
@@ -0,0 +1,815 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TOSA_REFERENCE_TENSOR_H
+#define TOSA_REFERENCE_TENSOR_H
+
+#include "model_common.h"
+#include "ops/template_types.h"
+#include "tosa_generated.h"
+#include "tosa_serialization_handler.h"
+#include <Eigen/CXX11/Tensor>
+#include <list>
+#include <vector>
+
+using namespace tosa;
+
+namespace TosaReference
+{
+class GraphNode;
+
+class Tensor
+{
+public:
+ Tensor(std::string tensorName_,
+           DType tensorDtype_,
+ const std::vector<Usage>& tensorUsage_,
+ const std::vector<Format>& tensorFormat_,
+ std::vector<int> shape_,
+ int isConst_);
+
+ virtual ~Tensor();
+
+ int setIsSubgraphInput();
+ int setIsSubgraphOutput();
+
+ int getIsSubgraphInput() const
+ {
+ return isSubgraphInput;
+ }
+
+ int getIsSubgraphOutput() const
+ {
+ return isSubgraphOutput;
+ }
+
+ int setProducer(GraphNode* node);
+ int addConsumer(GraphNode* node);
+
+ int setIsValid()
+ {
+ isValid = 1;
+ return 0;
+ }
+
+ int clearIsValid()
+ {
+ isValid = 0;
+ return 0;
+ }
+
+ int getIsValid() const
+ {
+ return isValid;
+ }
+
+ int getIsConst() const
+ {
+ return isConst;
+ }
+
+ GraphNode* getProducer()
+ {
+ return producer;
+ }
+
+ std::vector<GraphNode*>& getConsumers()
+ {
+ return consumers;
+ }
+
+ const std::string& getName() const
+ {
+ return tensorName;
+ }
+
+ const std::vector<int>& getShape() const
+ {
+ return shape;
+ }
+
+ std::string getShapeAsString() const
+ {
+ std::string shape_str("[");
+ for (auto& dim : shape)
+ {
+ shape_str += (std::to_string(dim) + ", ");
+ }
+ shape_str.append("]");
+ return shape_str;
+ }
+
+ const std::vector<Usage>& getUsage() const
+ {
+ return tensorUsage;
+ }
+
+ bool hasUsage(Usage usage) const
+ {
+ for (auto& usg : tensorUsage)
+ {
+ if (usg == usage)
+ {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ std::string getUsageAsString() const
+ {
+ std::string usage_str("[");
+ for (auto& usg : tensorUsage)
+ {
+ usage_str += (std::string(EnumNamesUsage()[usg]) + ", ");
+ }
+ usage_str.append("]");
+ return usage_str;
+ }
+
+ const std::vector<Format>& getFormat() const
+ {
+ return tensorFormat;
+ }
+
+ bool hasFormat(Format format) const
+ {
+ for (auto& fmt : tensorFormat)
+ {
+ if (fmt == format)
+ {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ std::string getFormatAsString() const
+ {
+ std::string format_str("[");
+ for (auto& fmt : tensorFormat)
+ {
+ format_str += (std::string(EnumNamesFormat()[fmt]) + ", ");
+ }
+ format_str.append("]");
+ return format_str;
+ }
+
+ const uint32_t getElementCount() const
+ {
+ uint32_t elements = 1;
+ for (size_t i = 0; i < shape.size(); i++)
+ elements *= shape[i];
+
+ return elements;
+ }
+
+ // Comparison of rank and type with other tensors
+ const int matchRank(const Tensor& ref) const
+ {
+ return (ref.shape.size() == shape.size()) ? 0 : 1;
+ }
+
+ const int matchType(const Tensor& ref) const
+ {
+ return (ref.tensorDtype == tensorDtype) ? 0 : 1;
+ }
+
+ const int matchRankType(const Tensor& ref) const
+ {
+ return (matchType(ref) || matchRank(ref));
+ }
+
+ const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const
+ {
+ if (matchRankType(ref))
+ return 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ if (shape[i] != ref.shape[i])
+ {
+ if (!broadcastOk ||
+ // For broadcasts, at least one operand must have size 1
+ // if they don't both match
+ (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1)))
+ {
+ return 1;
+ }
+ }
+ }
+
+ return 0;
+ }
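+
+    // Broadcast example for matchRankTypeShape(): with broadcastOk == true,
+    // shapes [2, 3] and [1, 3] are compatible (the size-1 dimension
+    // broadcasts), while [2, 3] and [4, 3] are not.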
+
+ // Sometimes we might want to match several semi-compatible types,
+ // so just check rank and size here
+ const int matchRankSize(const Tensor& ref) const
+ {
+ if (matchRank(ref))
+ return 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ if (shape[i] != ref.shape[i])
+ return 1;
+ }
+
+ return 0;
+ }
+
+ // Unary check to make sure rank matches
+ const int checkRequiredRank(const int exactRank) const
+ {
+ return (shape.size() == (size_t)exactRank) ? 0 : 1;
+ }
+
+ const int checkRequiredRank(const int minRank, const int maxRank) const
+ {
+ return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 0 : 1;
+ }
+
+ const int getRank() const
+ {
+ return shape.size();
+ }
+
+ const DType getDtype() const
+ {
+ return tensorDtype;
+ }
+
+ virtual int dumpTensor(FILE* out) const = 0;
+ virtual int dumpTensorParams(FILE* out) const;
+ virtual int dumpTensorParams(std::ostream& out) const;
+
+ virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0;
+ virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0;
+ virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0;
+ virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0;
+ virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0;
+ virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0;
+ virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0;
+ virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0;
+
+ virtual int readFromNpyFile(const char* filename);
+ virtual int writeToNpyFile(const char* filename) const;
+ virtual int copyValueFrom(Tensor* tensor) = 0;
+
+ const char* bool_to_str(bool in) const
+ {
+ static const char* true_str = "true";
+ static const char* false_str = "false";
+ return in ? true_str : false_str;
+ }
+
+ virtual int allocate() = 0;
+ virtual int deallocate() = 0;
+ virtual bool is_allocated() = 0;
+
+protected:
+ std::string tensorName;
+ DType tensorDtype;
+ std::vector<Usage> tensorUsage;
+ std::vector<Format> tensorFormat;
+ int isConst;
+ int isValid;
+ std::vector<int> shape;
+ int isSubgraphInput;
+ int isSubgraphOutput;
+ bool isAllocated;
+
+ GraphNode* producer;
+ std::vector<GraphNode*> consumers;
+
+    // Note: the Eigen::Tensor is not declared in Tensor.
+    // Instead, the TensorTemplate subclass holds the templated tensor
+    // declaration, so that the graph manipulation tools are isolated
+    // from the templated tensor type.
+    //
+    // Operators need to be aware of the concrete TensorTemplate<EigenTensor<type, rank>>
+    // type so that they can operate on the right types (see the usage sketch
+    // after this class).
+};
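+
+// Illustrative usage sketch: an operator that knows its rank and element type
+// can fetch the typed Eigen tensor through the TensorTemplate subclass below.
+// Tensor2<float> is the rank-2/float alias expected from ops/template_types.h;
+// getInputTensor() is a hypothetical accessor.
+//
+//   Tensor* t = getInputTensor();                    // hypothetical
+//   auto* t2f = dynamic_cast<Tensor2<float>*>(t);    // rank-2, float tensor
+//   if (t2f && t2f->is_allocated())
+//   {
+//       auto& data = t2f->getTensor();               // typed Eigen tensor reference
+//       data.setConstant(0.0f);                      // operate on it directly
+//   }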
+
+template <class T>
+class TensorTemplate : public Tensor
+{
+public:
+ TensorTemplate(std::string tensorName_,
+ DType tensorDtype_,
+ const std::vector<Usage>& tensorUsage_,
+ const std::vector<Format>& tensorFormat_,
+ std::vector<int> shape_,
+ int isConst_)
+ : Tensor(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, isConst_)
+ {
+ tensor = nullptr;
+ }
+
+ virtual ~TensorTemplate()
+ {
+ deallocate();
+ }
+
+ virtual int allocate()
+ {
+ tensor = new T();
+ if (tensor)
+ return 0;
+ else
+ return 1;
+ }
+
+ virtual int deallocate()
+ {
+ if (tensor)
+ {
+ delete tensor;
+ }
+ tensor = nullptr;
+ return 0;
+ }
+
+ virtual bool is_allocated()
+ {
+ if (tensor)
+ {
+ return true;
+ }
+ return false;
+ }
+
+ T& getTensor()
+ {
+ return *tensor;
+ }
+
+ virtual int dumpTensor(FILE* out) const;
+
+ virtual int setTensorValueFloat(const size_t bufLen, const float* vals);
+ virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+ virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+ virtual int setTensorValueBool(const size_t bufLen, const bool* vals);
+ virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const;
+ virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const;
+ virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const;
+ virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const;
+
+ virtual int copyValueFrom(Tensor* tensor);
+
+protected:
+ T* tensor;
+};
+
+// allocate() template specializations to allocate the different tensor sizes.
+// They are declared here, before the factory uses them, and defined in the
+// .cc file; a sketch of one such definition follows.
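+//
+// A sketch of one such definition (illustrative only; ETensor2<> is assumed to
+// be the rank-2 Eigen alias from ops/template_types.h):
+//
+//   template <>
+//   int Tensor2<float>::allocate()
+//   {
+//       // size the underlying Eigen tensor from the shape vector
+//       tensor = new ETensor2<float>(shape[0], shape[1]);
+//       return tensor ? 0 : 1;
+//   }
+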
+template <>
+int Tensor0<float>::allocate();
+template <>
+int Tensor1<float>::allocate();
+template <>
+int Tensor2<float>::allocate();
+template <>
+int Tensor3<float>::allocate();
+template <>
+int Tensor4<float>::allocate();
+template <>
+int Tensor5<float>::allocate();
+template <>
+int Tensor6<float>::allocate();
+
+template <>
+int Tensor0<int32_t>::allocate();
+template <>
+int Tensor1<int32_t>::allocate();
+template <>
+int Tensor2<int32_t>::allocate();
+template <>
+int Tensor3<int32_t>::allocate();
+template <>
+int Tensor4<int32_t>::allocate();
+template <>
+int Tensor5<int32_t>::allocate();
+template <>
+int Tensor6<int32_t>::allocate();
+
+template <>
+int Tensor0<int64_t>::allocate();
+template <>
+int Tensor1<int64_t>::allocate();
+template <>
+int Tensor2<int64_t>::allocate();
+template <>
+int Tensor3<int64_t>::allocate();
+template <>
+int Tensor4<int64_t>::allocate();
+template <>
+int Tensor5<int64_t>::allocate();
+template <>
+int Tensor6<int64_t>::allocate();
+
+template <>
+int Tensor0<bool>::allocate();
+template <>
+int Tensor1<bool>::allocate();
+template <>
+int Tensor2<bool>::allocate();
+template <>
+int Tensor3<bool>::allocate();
+template <>
+int Tensor4<bool>::allocate();
+template <>
+int Tensor5<bool>::allocate();
+template <>
+int Tensor6<bool>::allocate();
+
+template <>
+int Tensor0<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor1<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor2<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor3<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor4<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor5<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor6<float>::copyValueFrom(Tensor* src);
+
+template <>
+int Tensor0<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor1<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor2<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor3<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor4<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor5<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor6<int32_t>::copyValueFrom(Tensor* src);
+
+template <>
+int Tensor0<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor1<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor2<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor3<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor4<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor5<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor6<int64_t>::copyValueFrom(Tensor* src);
+
+template <>
+int Tensor0<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor1<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor2<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor3<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor4<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor5<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor6<bool>::copyValueFrom(Tensor* src);
+
+template <>
+int Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+
+template <>
+int Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+
+template <>
+int Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+
+template <>
+int Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+
+template <>
+int Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+
+template <>
+int Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+
+template <>
+int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+
+template <>
+int Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+
+// dumpTensor() specializations for each supported element type and rank;
+// float is the primary type dumped today, but the other element types are
+// declared as well.
+template <>
+int Tensor0<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor1<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor2<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor3<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor4<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor5<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor6<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor0<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor1<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor2<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor3<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor4<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor5<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor6<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor0<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor1<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor2<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor3<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor4<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor5<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor6<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor0<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor1<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor2<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor3<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor4<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor5<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor6<bool>::dumpTensor(FILE* out) const;
+
+class TensorFactory
+{
+public:
+ static Tensor* newTensor(std::string tensorName_,
+ DType tensorDtype_,
+ const std::vector<Usage>& tensorUsage_,
+ const std::vector<Format>& tensorFormat_,
+ std::vector<int> shape_,
+ int isConst_,
+ const uint32_t rank)
+ {
+ switch (tensorDtype_)
+ {
+ case DType_FLOAT:
+ switch (rank)
+ {
+ case 0:
+ return new Tensor0<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 1:
+ return new Tensor1<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 2:
+ return new Tensor2<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 3:
+ return new Tensor3<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 4:
+ return new Tensor4<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 5:
+ return new Tensor5<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 6:
+ return new Tensor6<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ default:
+ goto done;
+ }
+ case DType_INT32:
+ case DType_AINT8:
+ case DType_UINT8:
+ case DType_INT4:
+ case DType_INT8:
+ case DType_INT16:
+ switch (rank)
+ {
+ case 0:
+ return new Tensor0<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 1:
+ return new Tensor1<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 2:
+ return new Tensor2<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 3:
+ return new Tensor3<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 4:
+ return new Tensor4<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 5:
+ return new Tensor5<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 6:
+ return new Tensor6<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ default:
+ goto done;
+ }
+ case DType_INT48:
+ switch (rank)
+ {
+ case 0:
+ return new Tensor0<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 1:
+ return new Tensor1<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 2:
+ return new Tensor2<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 3:
+ return new Tensor3<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 4:
+ return new Tensor4<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 5:
+ return new Tensor5<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 6:
+ return new Tensor6<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ default:
+ goto done;
+ }
+ case DType_BOOL:
+ switch (rank)
+ {
+ case 0:
+ return new Tensor0<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 1:
+ return new Tensor1<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 2:
+ return new Tensor2<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 3:
+ return new Tensor3<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 4:
+ return new Tensor4<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 5:
+ return new Tensor5<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 6:
+ return new Tensor6<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ default:
+ goto done;
+ }
+ default:
+ goto done;
+ }
+
+ done:
+ FATAL_ERROR("Unsupported tensor name=%s, type=%s, rank=%d", tensorName_.c_str(), EnumNamesDType()[tensorDtype_],
+ rank);
+ }
+
+ static Tensor* newTensor(DType type, const std::vector<int> shape);
+};
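+
+// Illustrative creation sketch: build a rank-2 FLOAT tensor through the
+// factory and allocate its storage.  The Usage/Format values below are
+// placeholders taken from the serialization schema.
+//
+//   std::vector<int> shape = { 4, 8 };
+//   Tensor* t = TensorFactory::newTensor("example", DType_FLOAT,
+//                                        { Usage_ACTIVATION }, { Format_UNKNOWN },
+//                                        shape, /* isConst */ 0, shape.size());
+//   if (t && t->allocate() == 0)
+//   {
+//       // ready for setTensorValueFloat() / getTensorValueFloat()
+//   }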
+}; // namespace TosaReference
+
+#endif