author     Kevin Cheng <kevin.cheng@arm.com>  2021-03-03 11:21:43 -0800
committer  Kevin Cheng <kevin.cheng@arm.com>  2021-04-27 16:01:59 -0700
commit     550ccc52de231621c0bf0c05ae2a398eec37ff51 (patch)
tree       d4a5bd8d24560135784208c0fe35615b1d043249 /reference_model
parent     cf6224e6e8ba4fc2984de3e542538c38e27c9f57 (diff)
download   reference_model-550ccc52de231621c0bf0c05ae2a398eec37ff51.tar.gz
Replace serialization/ and verif/ with MLPlatform's serialization_lib submodule
- Remove Usage and Format
- Run black on verif/*.py scripts

Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Ie81515891eb0039540f614894f4b6b0e0e78ba74
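For context on the API change carried by the hunks below: with Usage and Format dropped from the serialization schema, tensors in the reference model are constructed from just a name, a dtype and a shape, and the const/weight bookkeeping is gone. A minimal sketch of the new call sites, mirroring the control_flow.cc and subgraph_traverser.cc hunks (the include path and surrounding setup are illustrative, not part of this commit):

    #include "tensor.h"   // TosaReference::Tensor0, TensorFactory (path illustrative)

    // Old signature (removed): Tensor0<bool>(name, dtype, {Usage_...}, {Format_...}, shape, isConst)
    // New signature: name, dtype and shape only.
    TosaReference::Tensor0<bool> cond_output_ctensor(std::string("cond_output"), DType_BOOL,
                                                     std::vector<int32_t>({}));

    // TensorFactory::newTensor likewise loses the usage/format/is_const arguments:
    // TosaReference::Tensor* t =
    //     TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());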
Diffstat (limited to 'reference_model')
-rw-r--r--  reference_model/CMakeLists.txt            |   8
-rw-r--r--  reference_model/src/main.cpp              |   2
-rw-r--r--  reference_model/src/ops/control_flow.cc   |   5
-rw-r--r--  reference_model/src/ops/tensor_ops.cc     |  63
-rw-r--r--  reference_model/src/subgraph_traverser.cc | 223
-rw-r--r--  reference_model/src/tensor.cc             |  17
-rw-r--r--  reference_model/src/tensor.h              | 163
7 files changed, 154 insertions, 327 deletions
diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt
index 0ba8afb..153a5bd 100644
--- a/reference_model/CMakeLists.txt
+++ b/reference_model/CMakeLists.txt
@@ -26,8 +26,8 @@ else()
set(CMAKE_CXX_FLAGS "-Wall -Wno-ignored-attributes")
endif()
-set(FLATBUFFERS_DIR "../thirdparty/flatbuffers/")
-set(SERIALIZATION_DIR "../serialization")
+set(FLATBUFFERS_DIR "../thirdparty/serialization_lib/third_party/flatbuffers/")
+set(SERIALIZATION_DIR "../thirdparty/serialization_lib/")
set (CXX_SOURCE
src/main.cpp
@@ -64,13 +64,13 @@ target_include_directories(tosa_reference_model
${FLATBUFFERS_DIR}/include
../thirdparty/eigen/
../thirdparty/eigen/unsupported/
- ${SERIALIZATION_DIR}
+ ${SERIALIZATION_DIR}/include
)
target_link_libraries(tosa_reference_model
PRIVATE
+ tosa_serialization_lib
flatbuffers
- tosa_serialization
)
install (TARGETS tosa_reference_model DESTINATION bin)
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index ec2fdc9..240d913 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -15,8 +15,6 @@
#include <stdio.h>
-#include "flatbuffers/idl.h"
-#include "flatbuffers/util.h"
#include "model_common.h"
#include "ops/op_factory.h"
#include "subgraph_traverser.h"
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index 9d5db40..827e01f 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -292,9 +292,8 @@ int OpWhileLoop::checkTensorAttributes()
int OpWhileLoop::eval()
{
- TosaReference::Tensor0<bool> cond_output_ctensor(
- std::string("cond_output"), DType_BOOL, std::vector<Usage>({ Usage_ACTIVATION }),
- std::vector<Format>({ Format_UNKNOWN }), std::vector<int32_t>({}), false);
+ TosaReference::Tensor0<bool> cond_output_ctensor(std::string("cond_output"), DType_BOOL,
+ std::vector<int32_t>({}));
cond_output_ctensor.allocate();
std::vector<TosaReference::Tensor*> cond_block_outputs;
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index d6cd1cd..b8c7ade 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -103,12 +103,6 @@ int OpAvgPool2d<Dtype>::checkTensorAttributes()
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- if (!in->hasFormat(Format_NHWC))
- {
- printNodeValidationError("OpAvgPool2d: unsupported tensor format");
- return 1;
- }
-
if (attribute->padding().size() != 4)
{
printNodeValidationError("OpAvgPool2d: illegal size for attribute padding");
@@ -321,28 +315,11 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
printNodeValidationError("OpConv2d: bias tensor must be rank 1");
}
- if (inputs[1]->getIsConst() == 0)
- {
- printNodeValidationError("OpConv2d: weight tensor is not const typed");
- }
-
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
- if (!input->hasFormat(Format_NHWC))
- {
- printNodeValidationError("OpConv2d: unsupported input tensor format");
- return 1;
- }
-
- if (!weight->hasFormat(Format_OHWI))
- {
- printNodeValidationError("OpConv2d: unsupported weight tensor format");
- return 1;
- }
-
if (attribute->padding().size() != 4)
{
printNodeValidationError("OpConv2d: illegal size for attribute padding");
@@ -530,28 +507,11 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
}
- if (inputs[1]->getIsConst() == 0)
- {
- printNodeValidationError("OpDepthwiseConv2d: weight tensor is not const typed");
- }
-
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
- if (!input->hasFormat(Format_NHWC))
- {
- printNodeValidationError("OpDepthwiseConv2d: unsupported input tensor format");
- return 1;
- }
-
- if (!weight->hasFormat(Format_HWIM))
- {
- printNodeValidationError("OpDepthwiseConv2d: unsupported weight tensor format");
- return 1;
- }
-
if (attribute->padding().size() != 4)
{
printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute padding");
@@ -881,12 +841,6 @@ int OpMaxPool2d<Dtype>::checkTensorAttributes()
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- if (!in->hasFormat(Format_NHWC))
- {
- printNodeValidationError("OpMaxPool2d: unsupported tensor format");
- return 1;
- }
-
if (attribute->padding().size() != 4)
{
printNodeValidationError("OpMaxPool2d: illegal size for attribute padding");
@@ -1021,28 +975,11 @@ int OpTransposeConv2d<InDtype, OutDtype>::checkTensorAttributes()
return 1;
}
- if (inputs[1]->getIsConst() == 0)
- {
- printNodeValidationError("OpTransposeConv2d: weight tensor is not const typed");
- }
-
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
- if (!input->hasFormat(Format_NHWC))
- {
- printNodeValidationError("OpTransposeConv2d: unsupported input tensor format");
- return 1;
- }
-
- if (!weight->hasFormat(Format_OHWI))
- {
- printNodeValidationError("OpTransposeConv2d: unsupported weight tensor format");
- return 1;
- }
-
if (attribute->outpad().size() != 2)
{
printNodeValidationError("OpTransposeConv2d: illegal size for attribute outpad");
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 082f802..5096ffa 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -103,110 +103,118 @@ int SubgraphTraverser::initializeGraph()
for (auto op : block->GetOperators())
{
// translated TosaSerializationOperator to GraphNode
- DType in_dtype = DType_UNKNOWN, out_dtype = DType_UNKNOWN, weight_dtype = DType_UNKNOWN;
- uint32_t in_rank = 0, out_rank = 0, weight_rank = 0;
- for (auto name : op->GetInputTensorNames())
- {
-
- TosaSerializationTensor* ts = block->GetTensorByName(name);
- ASSERT_MSG(ts, "SubgraphTraverser: fail to get tensor %s from TosaSerializationHandler", name.c_str());
-
- if (ts->HasUsage(Usage_WEIGHT))
- {
- weight_dtype = ts->GetDtype();
- weight_rank = ts->GetShape().size();
- }
- else if (ts->HasUsage(Usage_INDEX))
- {
- // do nothing, but this will prevent tensor's dtype/rank being wrongly used as template argument when initializing this op
- }
- else if (ts->HasUsage(Usage_ACTIVATION))
- {
- if (ts->GetShape().size() >= in_rank)
- {
- in_dtype = ts->GetDtype();
- in_rank = ts->GetShape().size();
- }
- }
- }
-
- // if dtype/rank still not initialized with above pass, we initialize without Usage check
- if (in_dtype == DType_UNKNOWN && in_rank == 0)
- {
- for (auto name : op->GetInputTensorNames())
- {
- TosaSerializationTensor* ts = block->GetTensorByName(name);
- ASSERT_MSG(ts, "SubgraphTraverser: fail to get tensor %s from TosaSerializationHandler", name.c_str());
-
- if (ts->GetShape().size() >= in_rank)
- {
- in_dtype = ts->GetDtype();
- in_rank = ts->GetShape().size();
- }
- }
- }
-
- for (auto name : op->GetOutputTensorNames())
- {
-
- TosaSerializationTensor* ts = block->GetTensorByName(name);
- ASSERT_MSG(ts, "SubgraphTraverser: fail to get tensor %s from TosaSerializationHandler", name.c_str());
-
- out_dtype = ts->GetDtype();
- out_rank = ts->GetShape().size();
- }
+ DType input_dtype = DType_UNKNOWN;
+ DType output_dtype = DType_UNKNOWN;
+ DType weight_dtype = DType_UNKNOWN;
+ uint32_t input_rank = 0;
+ uint32_t output_rank = 0;
+ uint32_t weight_rank = 0;
+ int32_t input_index = -1;
+ int32_t weight_index = -1;
+
+ switch (op->GetOp())
+ {
+ case Op_CONV2D:
+ case Op_DEPTHWISE_CONV2D:
+ case Op_TRANSPOSE_CONV2D:
+ case Op_FULLY_CONNECTED:
+ input_index = 0;
+ weight_index = 1;
+ break;
+ case Op_SELECT:
+ input_index = 1;
+ break;
+ default:
+ if (!op->GetInputTensorNames().empty())
+ input_index = 0;
+ break;
+ }
+
+ if (input_index != -1)
+ {
+ ASSERT_MSG((size_t)input_index < op->GetInputTensorNames().size(),
+ "Op=%s, input_index %d must be within [0, num_input - 1]", EnumNamesOp()[op->GetOp()],
+ input_index);
+
+ std::string input_name = op->GetInputTensorNames()[input_index];
+ TosaSerializationTensor* input_tensor = block->GetTensorByName(input_name);
+ ASSERT_MSG(input_tensor, "SubgraphTraverser: fail to get input tensor %s from TosaSerializationHandler",
+ input_name.c_str());
+ input_dtype = input_tensor->GetDtype();
+ input_rank = input_tensor->GetShape().size();
+ }
+
+ if (weight_index != -1)
+ {
+ ASSERT_MSG((size_t)weight_index < op->GetInputTensorNames().size(),
+ "Op=%s, weight_index %d must be within [0, num_input - 1]", EnumNamesOp()[op->GetOp()],
+ weight_index);
+ std::string weight_name = op->GetInputTensorNames()[weight_index];
+ TosaSerializationTensor* weight_tensor = block->GetTensorByName(weight_name);
+ ASSERT_MSG(weight_tensor, "SubgraphTraverser: fail to get weight tensor %s from TosaSerializationHandler",
+ weight_name.c_str());
+ weight_dtype = weight_tensor->GetDtype();
+ weight_rank = weight_tensor->GetShape().size();
+ }
+
+ std::string output_name = op->GetOutputTensorNames()[0];
+ TosaSerializationTensor* output_tensor = block->GetTensorByName(output_name);
+ ASSERT_MSG(output_tensor, "SubgraphTraverser: fail to get output tensor %s from TosaSerializationHandler",
+ output_name.c_str());
+ output_dtype = output_tensor->GetDtype();
+ output_rank = output_tensor->GetShape().size();
DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size());
- GraphNode* cn = OpFactory::newOp(tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, in_dtype, in_rank,
- out_dtype, out_rank, weight_dtype, weight_rank);
- if (!cn)
+ GraphNode* node = OpFactory::newOp(tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, input_dtype,
+ input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
+ if (!node)
{
- if (weight_dtype == DType_UNKNOWN && weight_rank == 0)
+ if (weight_index == -1)
{
fprintf(g_func_debug.func_debug_file,
"OpFactory could not allocate op %8s input=(%s rank %d) -> (%s rank %d)",
- EnumNamesOp()[op->GetOp()], EnumNamesDType()[in_dtype], in_rank, EnumNamesDType()[out_dtype],
- out_rank);
+ EnumNamesOp()[op->GetOp()], EnumNamesDType()[input_dtype], input_rank,
+ EnumNamesDType()[output_dtype], output_rank);
}
else
{
fprintf(g_func_debug.func_debug_file,
"OpFactory could not allocate op %8s input=(%s rank %d), weight=(%s rank %d) -> (%s rank %d)",
- EnumNamesOp()[op->GetOp()], EnumNamesDType()[in_dtype], in_rank, EnumNamesDType()[weight_dtype],
- weight_rank, EnumNamesDType()[out_dtype], out_rank);
+ EnumNamesOp()[op->GetOp()], EnumNamesDType()[input_dtype], input_rank,
+ EnumNamesDType()[weight_dtype], weight_rank, EnumNamesDType()[output_dtype], output_rank);
}
- for (auto ts : op->GetInputTensors())
+ for (auto& ts : op->GetInputTensorNames())
{
- fprintf(g_func_debug.func_debug_file, "Input: %s\n", ts->GetName().c_str());
+ fprintf(g_func_debug.func_debug_file, "Input: %s\n", ts.c_str());
}
- for (auto ts : op->GetOutputTensors())
+ for (auto& ts : op->GetOutputTensorNames())
{
- fprintf(g_func_debug.func_debug_file, "Output: %s\n", ts->GetName().c_str());
+ fprintf(g_func_debug.func_debug_file, "Output: %s\n", ts.c_str());
}
FATAL_ERROR("Unsupported operation type or rank.");
}
- for (auto name : op->GetInputTensorNames())
+ for (auto& name : op->GetInputTensorNames())
{
- cn->addInputName(name);
+ node->addInputName(name);
}
for (auto name : op->GetOutputTensorNames())
{
- cn->addOutputName(name);
+ node->addOutputName(name);
}
- addNode(cn);
+ addNode(node);
// if node doesn't have any inputs (i.e. CONST)
// it should be ready for evaluation
- if (op->GetInputTensorNames().empty() && !cn->getOnNextNodeList())
+ if (op->GetInputTensorNames().empty() && !node->getOnNextNodeList())
{
- addToNextNodeList(cn);
+ addToNextNodeList(node);
}
idx++;
@@ -215,47 +223,40 @@ int SubgraphTraverser::initializeGraph()
for (auto ts : block->GetTensors())
{
- bool is_const = false;
- if (ts->HasUsage(Usage_WEIGHT))
- {
- is_const = true;
- }
-
DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
- TosaReference::Tensor* ct =
- TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetUsage(), ts->GetFormat(), ts->GetShape(),
- is_const, ts->GetShape().size());
+ TosaReference::Tensor* tensor =
+ TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
- if (ts->GetNpyFilePtr())
+ if (!ts->GetNpyFilePtr().empty())
{
- if (ct->allocate())
+ if (tensor->allocate())
{
- FATAL_ERROR("Fail to allocate Eigen tensor %s", ct->getName().c_str());
+ FATAL_ERROR("Fail to allocate Eigen tensor %s", tensor->getName().c_str());
}
bzero(tensor_fullname, sizeof(tensor_fullname));
snprintf(tensor_fullname, sizeof(tensor_fullname), "%s/%s", g_func_config.subgraph_dir,
- ts->GetNpyFilePtr()->c_str());
- if (ct->readFromNpyFile(tensor_fullname))
+ ts->GetNpyFilePtr().c_str());
+ if (tensor->readFromNpyFile(tensor_fullname))
{
- FATAL_ERROR("Cannot read input data into graph tensor %s from block %s", ct->getName().c_str(),
+ FATAL_ERROR("Cannot read input data into graph tensor %s from block %s", tensor->getName().c_str(),
block->GetName().c_str());
}
}
// update this->tensors
- addTensor(ct);
+ addTensor(tensor);
}
DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str());
for (auto& input_name : block->GetInputs())
{
- TosaReference::Tensor* ct = findTensorByName(input_name);
+ TosaReference::Tensor* tensor = findTensorByName(input_name);
DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str());
- if (ct)
+ if (tensor)
{
- ct->setIsSubgraphInput();
- inputTensors.push_back(ct);
+ tensor->setIsSubgraphInput();
+ inputTensors.push_back(tensor);
}
else
{
@@ -266,12 +267,12 @@ int SubgraphTraverser::initializeGraph()
DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str());
for (auto& output_name : block->GetOutputs())
{
- TosaReference::Tensor* ct = findTensorByName(output_name);
+ TosaReference::Tensor* tensor = findTensorByName(output_name);
DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str());
- if (ct)
+ if (tensor)
{
- ct->setIsSubgraphOutput();
- outputTensors.push_back(ct);
+ tensor->setIsSubgraphOutput();
+ outputTensors.push_back(tensor);
}
else
{
@@ -333,12 +334,12 @@ int SubgraphTraverser::evaluateNextNode()
WARNING("Node %lu has been evaluated %d times. Loop suspected.", currNode->getID(), currNode->getEvalCount());
}
- for (auto ct : currNode->getOutputs())
+ for (auto tensor : currNode->getOutputs())
{
- if (!ct->is_allocated())
- if (ct->allocate())
+ if (!tensor->is_allocated())
+ if (tensor->allocate())
{
- FATAL_ERROR("Fail to allocate Eigen tensor %s", ct->getName().c_str());
+ FATAL_ERROR("Fail to allocate Eigen tensor %s", tensor->getName().c_str());
}
}
@@ -348,26 +349,26 @@ int SubgraphTraverser::evaluateNextNode()
}
// free input tensor if all of its consumers have all of their outputs ready and it's not block's output
- for (auto ct : currNode->getInputs())
+ for (auto tensor : currNode->getInputs())
{
bool in_use = false;
- for (auto cn : ct->getConsumers())
+ for (auto node : tensor->getConsumers())
{
- if (!cn->hasAllOutputsReady())
+ if (!node->hasAllOutputsReady())
{
in_use = true;
}
}
for (auto name : block->GetOutputs())
{
- if (name == ct->getName())
+ if (name == tensor->getName())
{
in_use = true;
}
}
if (!in_use)
{
- ct->deallocate();
+ tensor->deallocate();
}
}
@@ -433,29 +434,29 @@ int SubgraphTraverser::clearAllNodeMarkings()
return false;
}
-int SubgraphTraverser::addTensor(TosaReference::Tensor* ct)
+int SubgraphTraverser::addTensor(TosaReference::Tensor* tensor)
{
// Enforce no duplicate tensors/tensor names
// O(N), but the number of tensors is small
for (TosaReference::Tensor* currTensor : tensors)
{
- if (ct == currTensor || currTensor->getName() == ct->getName())
+ if (tensor == currTensor || currTensor->getName() == tensor->getName())
{
- FATAL_ERROR("Error: Duplicate tensor or tensor name being added to graph: %s\n", ct->getName().c_str());
+ FATAL_ERROR("Error: Duplicate tensor or tensor name being added to graph: %s\n", tensor->getName().c_str());
return 1;
}
}
- tensors.push_back(ct);
+ tensors.push_back(tensor);
- if (ct->getIsSubgraphInput())
+ if (tensor->getIsSubgraphInput())
{
- inputTensors.push_back(ct);
+ inputTensors.push_back(tensor);
}
- if (ct->getIsSubgraphOutput())
+ if (tensor->getIsSubgraphOutput())
{
- outputTensors.push_back(ct);
+ outputTensors.push_back(tensor);
}
return 0;
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index da81bcd..1efebe3 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -22,17 +22,11 @@ using namespace tosa;
TosaReference::Tensor::Tensor(std::string tensorName_,
DType tensorDtype_,
- const std::vector<Usage>& tensorUsage_,
- const std::vector<Format>& tensorFormat_,
- std::vector<int> shape_,
- int isConst_)
+ std::vector<int> shape_)
{
tensorName = std::string(tensorName_);
tensorDtype = tensorDtype_;
- tensorUsage = std::vector<Usage>(tensorUsage_);
- tensorFormat = std::vector<Format>(tensorFormat_);
shape = std::vector<int>(shape_);
- isConst = isConst_;
producer = nullptr;
isValid = false;
consumers.clear();
@@ -74,17 +68,16 @@ int TosaReference::Tensor::addConsumer(GraphNode* node)
int TosaReference::Tensor::dumpTensorParams(FILE* out) const
{
- fprintf(out, "Name: %s DType=%s Usage=%s isValid=%d Rank=%d Shape=%s\n", tensorName.c_str(),
- EnumNamesDType()[getDtype()], getUsageAsString().c_str(), getIsValid(), getRank(),
- getShapeAsString().c_str());
+ fprintf(out, "Name: %s DType=%s isValid=%d Rank=%d Shape=%s\n", tensorName.c_str(), EnumNamesDType()[getDtype()],
+ getIsValid(), getRank(), getShapeAsString().c_str());
return 0;
}
int TosaReference::Tensor::dumpTensorParams(std::ostream& out) const
{
- out << "Name: " << getName() << " DType=" << EnumNamesDType()[getDtype()] << " Usage=" << getUsageAsString()
- << " isValid=" << getIsValid() << " Rank=" << getRank() << " Shape=" << getShapeAsString() << "\n";
+ out << "Name: " << getName() << " DType=" << EnumNamesDType()[getDtype()] << " isValid=" << getIsValid()
+ << " Rank=" << getRank() << " Shape=" << getShapeAsString() << "\n";
return 0;
}
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 4f77cfc..d39cc7c 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -35,10 +35,7 @@ class Tensor
public:
Tensor(std::string tensorName_,
DType tensorDtype__,
- const std::vector<Usage>& tensorUsage_,
- const std::vector<Format>& tensorFormat_,
- std::vector<int> shape_,
- int isConst_);
+ std::vector<int> shape_);
virtual ~Tensor();
@@ -75,11 +72,6 @@ public:
return isValid;
}
- int getIsConst() const
- {
- return isConst;
- }
-
GraphNode* getProducer()
{
return producer;
@@ -111,62 +103,6 @@ public:
return shape_str;
}
- const std::vector<Usage>& getUsage() const
- {
- return tensorUsage;
- }
-
- bool hasUsage(Usage usage) const
- {
- for (auto& usg : tensorUsage)
- {
- if (usg == usage)
- {
- return true;
- }
- }
- return false;
- }
-
- std::string getUsageAsString() const
- {
- std::string usage_str("[");
- for (auto& usg : tensorUsage)
- {
- usage_str += (std::string(EnumNamesUsage()[usg]) + ", ");
- }
- usage_str.append("]");
- return usage_str;
- }
-
- const std::vector<Format>& getFormat() const
- {
- return tensorFormat;
- }
-
- bool hasFormat(Format format) const
- {
- for (auto& fmt : tensorFormat)
- {
- if (fmt == format)
- {
- return true;
- }
- }
- return false;
- }
-
- std::string getFormatAsString() const
- {
- std::string format_str("[");
- for (auto& fmt : tensorFormat)
- {
- format_str += (std::string(EnumNamesFormat()[fmt]) + ", ");
- }
- format_str.append("]");
- return format_str;
- }
-
const uint32_t getElementCount() const
{
uint32_t elements = 1;
@@ -282,9 +218,6 @@ public:
protected:
std::string tensorName;
DType tensorDtype;
- std::vector<Usage> tensorUsage;
- std::vector<Format> tensorFormat;
- int isConst;
int isValid;
std::vector<int> shape;
int isSubgraphInput;
@@ -309,11 +242,8 @@ class TensorTemplate : public Tensor
public:
TensorTemplate(std::string tensorName_,
DType tensorDtype_,
- const std::vector<Usage>& tensorUsage_,
- const std::vector<Format>& tensorFormat_,
- std::vector<int> shape_,
- int isConst_)
- : Tensor(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, isConst_)
+ std::vector<int> shape_)
+ : Tensor(tensorName_, tensorDtype_, shape_)
{
tensor = nullptr;
}
@@ -678,10 +608,7 @@ class TensorFactory
public:
static Tensor* newTensor(std::string tensorName_,
DType tensorDtype_,
- const std::vector<Usage>& tensorUsage_,
- const std::vector<Format>& tensorFormat_,
std::vector<int> shape_,
- int isConst_,
const uint32_t rank)
{
switch (tensorDtype_)
@@ -690,26 +617,19 @@ public:
switch (rank)
{
case 0:
- return new Tensor0<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor0<float>(tensorName_, tensorDtype_, shape_);
case 1:
- return new Tensor1<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor1<float>(tensorName_, tensorDtype_, shape_);
case 2:
- return new Tensor2<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor2<float>(tensorName_, tensorDtype_, shape_);
case 3:
- return new Tensor3<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor3<float>(tensorName_, tensorDtype_, shape_);
case 4:
- return new Tensor4<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor4<float>(tensorName_, tensorDtype_, shape_);
case 5:
- return new Tensor5<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor5<float>(tensorName_, tensorDtype_, shape_);
case 6:
- return new Tensor6<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor6<float>(tensorName_, tensorDtype_, shape_);
default:
goto done;
}
@@ -721,26 +641,19 @@ public:
switch (rank)
{
case 0:
- return new Tensor0<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor0<int32_t>(tensorName_, tensorDtype_, shape_);
case 1:
- return new Tensor1<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor1<int32_t>(tensorName_, tensorDtype_, shape_);
case 2:
- return new Tensor2<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor2<int32_t>(tensorName_, tensorDtype_, shape_);
case 3:
- return new Tensor3<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor3<int32_t>(tensorName_, tensorDtype_, shape_);
case 4:
- return new Tensor4<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor4<int32_t>(tensorName_, tensorDtype_, shape_);
case 5:
- return new Tensor5<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor5<int32_t>(tensorName_, tensorDtype_, shape_);
case 6:
- return new Tensor6<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor6<int32_t>(tensorName_, tensorDtype_, shape_);
default:
goto done;
}
@@ -748,26 +661,19 @@ public:
switch (rank)
{
case 0:
- return new Tensor0<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor0<int64_t>(tensorName_, tensorDtype_, shape_);
case 1:
- return new Tensor1<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor1<int64_t>(tensorName_, tensorDtype_, shape_);
case 2:
- return new Tensor2<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor2<int64_t>(tensorName_, tensorDtype_, shape_);
case 3:
- return new Tensor3<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor3<int64_t>(tensorName_, tensorDtype_, shape_);
case 4:
- return new Tensor4<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor4<int64_t>(tensorName_, tensorDtype_, shape_);
case 5:
- return new Tensor5<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor5<int64_t>(tensorName_, tensorDtype_, shape_);
case 6:
- return new Tensor6<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor6<int64_t>(tensorName_, tensorDtype_, shape_);
default:
goto done;
}
@@ -775,26 +681,19 @@ public:
switch (rank)
{
case 0:
- return new Tensor0<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor0<bool>(tensorName_, tensorDtype_, shape_);
case 1:
- return new Tensor1<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor1<bool>(tensorName_, tensorDtype_, shape_);
case 2:
- return new Tensor2<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor2<bool>(tensorName_, tensorDtype_, shape_);
case 3:
- return new Tensor3<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor3<bool>(tensorName_, tensorDtype_, shape_);
case 4:
- return new Tensor4<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor4<bool>(tensorName_, tensorDtype_, shape_);
case 5:
- return new Tensor5<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor5<bool>(tensorName_, tensorDtype_, shape_);
case 6:
- return new Tensor6<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
- isConst_);
+ return new Tensor6<bool>(tensorName_, tensorDtype_, shape_);
default:
goto done;
}