aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/tensor.h
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-03-03 11:21:43 -0800
committerKevin Cheng <kevin.cheng@arm.com>2021-04-27 16:01:59 -0700
commit550ccc52de231621c0bf0c05ae2a398eec37ff51 (patch)
treed4a5bd8d24560135784208c0fe35615b1d043249 /reference_model/src/tensor.h
parentcf6224e6e8ba4fc2984de3e542538c38e27c9f57 (diff)
downloadreference_model-550ccc52de231621c0bf0c05ae2a398eec37ff51.tar.gz
Replace serialization/ and verif/ with MLPlatform's serialization_lib submodule
- Remove Usage and Format - Run black on verif/*.py scripts Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: Ie81515891eb0039540f614894f4b6b0e0e78ba74
Diffstat (limited to 'reference_model/src/tensor.h')
-rw-r--r--reference_model/src/tensor.h163
1 files changed, 31 insertions, 132 deletions
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 4f77cfc..d39cc7c 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -35,10 +35,7 @@ class Tensor
public:
Tensor(std::string tensorName_,
DType tensorDtype__,
- const std::vector<Usage>& tensorUsage_,
- const std::vector<Format>& tensorFormat_,
- std::vector<int> shape_,
- int isConst_);
+ std::vector<int> shape_);
virtual ~Tensor();
@@ -75,11 +72,6 @@ public:
return isValid;
}
- int getIsConst() const
- {
- return isConst;
- }
-
GraphNode* getProducer()
{
return producer;
@@ -111,62 +103,6 @@ public:
return shape_str;
}
- const std::vector<Usage>& getUsage() const
- {
- return tensorUsage;
- }
-
- bool hasUsage(Usage usage) const
- {
- for (auto& usg : tensorUsage)
- {
- if (usg == usage)
- {
- return true;
- }
- }
- return false;
- }
-
- std::string getUsageAsString() const
- {
- std::string usage_str("[");
- for (auto& usg : tensorUsage)
- {
- usage_str += (std::string(EnumNamesUsage()[usg]) + ", ");
- }
- usage_str.append("]");
- return usage_str;
- }
-
- const std::vector<Format>& getFormat() const
- {
- return tensorFormat;
- }
-
- bool hasFormat(Format format) const
- {
- for (auto& fmt : tensorFormat)
- {
- if (fmt == format)
- {
- return true;
- }
- }
- return false;
- }
-
- std::string getFormatAsString() const
- {
- std::string format_str("[");
- for (auto& fmt : tensorFormat)
- {
- format_str += (std::string(EnumNamesFormat()[fmt]) + ", ");
- }
- format_str.append("]");
- return format_str;
- }
-
const uint32_t getElementCount() const
{
uint32_t elements = 1;
@@ -282,9 +218,6 @@ public:
protected:
std::string tensorName;
DType tensorDtype;
- std::vector<Usage> tensorUsage;
- std::vector<Format> tensorFormat;
- int isConst;
int isValid;
std::vector<int> shape;
int isSubgraphInput;
@@ -309,11 +242,8 @@ class TensorTemplate : public Tensor
public:
TensorTemplate(std::string tensorName_,
DType tensorDtype_,
- const std::vector<Usage>& tensorUsage_,
- const std::vector<Format>& tensorFormat_,
- std::vector<int> shape_,
- int isConst_)
- : Tensor(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, isConst_)
+ std::vector<int> shape_)
+ : Tensor(tensorName_, tensorDtype_, shape_)
{
tensor = nullptr;
}
@@ -678,10 +608,7 @@ class TensorFactory
public:
static Tensor* newTensor(std::string tensorName_,
DType tensorDtype_,
- const std::vector<Usage>& tensorUsage_,
- const std::vector<Format>& tensorFormat_,
std::vector<int> shape_,
- int isConst_,
const uint32_t rank)
{
switch (tensorDtype_)
@@ -690,26 +617,19 @@ public:
switch (rank)
{
case 0:
- return new Tensor0<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor0<float>(tensorName_, tensorDtype_, shape_);
case 1:
- return new Tensor1<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor1<float>(tensorName_, tensorDtype_, shape_);
case 2:
- return new Tensor2<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor2<float>(tensorName_, tensorDtype_, shape_);
case 3:
- return new Tensor3<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor3<float>(tensorName_, tensorDtype_, shape_);
case 4:
- return new Tensor4<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor4<float>(tensorName_, tensorDtype_, shape_);
case 5:
- return new Tensor5<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor5<float>(tensorName_, tensorDtype_, shape_);
case 6:
- return new Tensor6<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor6<float>(tensorName_, tensorDtype_, shape_);
default:
goto done;
}
@@ -721,26 +641,19 @@ public:
switch (rank)
{
case 0:
- return new Tensor0<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor0<int32_t>(tensorName_, tensorDtype_, shape_);
case 1:
- return new Tensor1<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor1<int32_t>(tensorName_, tensorDtype_, shape_);
case 2:
- return new Tensor2<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor2<int32_t>(tensorName_, tensorDtype_, shape_);
case 3:
- return new Tensor3<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor3<int32_t>(tensorName_, tensorDtype_, shape_);
case 4:
- return new Tensor4<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor4<int32_t>(tensorName_, tensorDtype_, shape_);
case 5:
- return new Tensor5<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor5<int32_t>(tensorName_, tensorDtype_, shape_);
case 6:
- return new Tensor6<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor6<int32_t>(tensorName_, tensorDtype_, shape_);
default:
goto done;
}
@@ -748,26 +661,19 @@ public:
switch (rank)
{
case 0:
- return new Tensor0<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor0<int64_t>(tensorName_, tensorDtype_, shape_);
case 1:
- return new Tensor1<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor1<int64_t>(tensorName_, tensorDtype_, shape_);
case 2:
- return new Tensor2<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor2<int64_t>(tensorName_, tensorDtype_, shape_);
case 3:
- return new Tensor3<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor3<int64_t>(tensorName_, tensorDtype_, shape_);
case 4:
- return new Tensor4<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor4<int64_t>(tensorName_, tensorDtype_, shape_);
case 5:
- return new Tensor5<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor5<int64_t>(tensorName_, tensorDtype_, shape_);
case 6:
- return new Tensor6<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor6<int64_t>(tensorName_, tensorDtype_, shape_);
default:
goto done;
}
@@ -775,26 +681,19 @@ public:
switch (rank)
{
case 0:
- return new Tensor0<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor0<bool>(tensorName_, tensorDtype_, shape_);
case 1:
- return new Tensor1<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor1<bool>(tensorName_, tensorDtype_, shape_);
case 2:
- return new Tensor2<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor2<bool>(tensorName_, tensorDtype_, shape_);
case 3:
- return new Tensor3<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor3<bool>(tensorName_, tensorDtype_, shape_);
case 4:
- return new Tensor4<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor4<bool>(tensorName_, tensorDtype_, shape_);
case 5:
- return new Tensor5<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor5<bool>(tensorName_, tensorDtype_, shape_);
case 6:
- return new Tensor6<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor6<bool>(tensorName_, tensorDtype_, shape_);
default:
goto done;
}