diff options
Diffstat (limited to 'reference_model/src/tensor.h')
-rw-r--r-- | reference_model/src/tensor.h | 163 |
1 files changed, 31 insertions, 132 deletions
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index 4f77cfc..d39cc7c 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -35,10 +35,7 @@ class Tensor public: Tensor(std::string tensorName_, DType tensorDtype__, - const std::vector<Usage>& tensorUsage_, - const std::vector<Format>& tensorFormat_, - std::vector<int> shape_, - int isConst_); + std::vector<int> shape_); virtual ~Tensor(); @@ -75,11 +72,6 @@ public: return isValid; } - int getIsConst() const - { - return isConst; - } - GraphNode* getProducer() { return producer; @@ -111,62 +103,6 @@ public: return shape_str; } - const std::vector<Usage>& getUsage() const - { - return tensorUsage; - } - - bool hasUsage(Usage usage) const - { - for (auto& usg : tensorUsage) - { - if (usg == usage) - { - return true; - } - } - return false; - } - - std::string getUsageAsString() const - { - std::string usage_str("["); - for (auto& usg : tensorUsage) - { - usage_str += (std::string(EnumNamesUsage()[usg]) + ", "); - } - usage_str.append("]"); - return usage_str; - } - - const std::vector<Format>& getFormat() const - { - return tensorFormat; - } - - bool hasFormat(Format format) const - { - for (auto& fmt : tensorFormat) - { - if (fmt == format) - { - return true; - } - } - return false; - } - - std::string getFormatAsString() const - { - std::string format_str("["); - for (auto& fmt : tensorFormat) - { - format_str += (std::string(EnumNamesFormat()[fmt]) + ", "); - } - format_str.append("]"); - return format_str; - } - const uint32_t getElementCount() const { uint32_t elements = 1; @@ -282,9 +218,6 @@ public: protected: std::string tensorName; DType tensorDtype; - std::vector<Usage> tensorUsage; - std::vector<Format> tensorFormat; - int isConst; int isValid; std::vector<int> shape; int isSubgraphInput; @@ -309,11 +242,8 @@ class TensorTemplate : public Tensor public: TensorTemplate(std::string tensorName_, DType tensorDtype_, - const std::vector<Usage>& tensorUsage_, - const std::vector<Format>& tensorFormat_, - std::vector<int> shape_, - int isConst_) - : Tensor(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, isConst_) + std::vector<int> shape_) + : Tensor(tensorName_, tensorDtype_, shape_) { tensor = nullptr; } @@ -678,10 +608,7 @@ class TensorFactory public: static Tensor* newTensor(std::string tensorName_, DType tensorDtype_, - const std::vector<Usage>& tensorUsage_, - const std::vector<Format>& tensorFormat_, std::vector<int> shape_, - int isConst_, const uint32_t rank) { switch (tensorDtype_) @@ -690,26 +617,19 @@ public: switch (rank) { case 0: - return new Tensor0<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor0<float>(tensorName_, tensorDtype_, shape_); case 1: - return new Tensor1<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor1<float>(tensorName_, tensorDtype_, shape_); case 2: - return new Tensor2<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor2<float>(tensorName_, tensorDtype_, shape_); case 3: - return new Tensor3<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor3<float>(tensorName_, tensorDtype_, shape_); case 4: - return new Tensor4<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor4<float>(tensorName_, tensorDtype_, shape_); case 5: - return new Tensor5<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor5<float>(tensorName_, tensorDtype_, shape_); case 6: - return new Tensor6<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor6<float>(tensorName_, tensorDtype_, shape_); default: goto done; } @@ -721,26 +641,19 @@ public: switch (rank) { case 0: - return new Tensor0<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor0<int32_t>(tensorName_, tensorDtype_, shape_); case 1: - return new Tensor1<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor1<int32_t>(tensorName_, tensorDtype_, shape_); case 2: - return new Tensor2<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor2<int32_t>(tensorName_, tensorDtype_, shape_); case 3: - return new Tensor3<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor3<int32_t>(tensorName_, tensorDtype_, shape_); case 4: - return new Tensor4<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor4<int32_t>(tensorName_, tensorDtype_, shape_); case 5: - return new Tensor5<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor5<int32_t>(tensorName_, tensorDtype_, shape_); case 6: - return new Tensor6<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor6<int32_t>(tensorName_, tensorDtype_, shape_); default: goto done; } @@ -748,26 +661,19 @@ public: switch (rank) { case 0: - return new Tensor0<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor0<int64_t>(tensorName_, tensorDtype_, shape_); case 1: - return new Tensor1<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor1<int64_t>(tensorName_, tensorDtype_, shape_); case 2: - return new Tensor2<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor2<int64_t>(tensorName_, tensorDtype_, shape_); case 3: - return new Tensor3<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor3<int64_t>(tensorName_, tensorDtype_, shape_); case 4: - return new Tensor4<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor4<int64_t>(tensorName_, tensorDtype_, shape_); case 5: - return new Tensor5<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor5<int64_t>(tensorName_, tensorDtype_, shape_); case 6: - return new Tensor6<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor6<int64_t>(tensorName_, tensorDtype_, shape_); default: goto done; } @@ -775,26 +681,19 @@ public: switch (rank) { case 0: - return new Tensor0<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor0<bool>(tensorName_, tensorDtype_, shape_); case 1: - return new Tensor1<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor1<bool>(tensorName_, tensorDtype_, shape_); case 2: - return new Tensor2<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor2<bool>(tensorName_, tensorDtype_, shape_); case 3: - return new Tensor3<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor3<bool>(tensorName_, tensorDtype_, shape_); case 4: - return new Tensor4<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor4<bool>(tensorName_, tensorDtype_, shape_); case 5: - return new Tensor5<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor5<bool>(tensorName_, tensorDtype_, shape_); case 6: - return new Tensor6<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor6<bool>(tensorName_, tensorDtype_, shape_); default: goto done; } |