diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2021-03-03 11:21:43 -0800 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2021-04-27 16:01:59 -0700 |
commit | 550ccc52de231621c0bf0c05ae2a398eec37ff51 (patch) | |
tree | d4a5bd8d24560135784208c0fe35615b1d043249 /reference_model/src | |
parent | cf6224e6e8ba4fc2984de3e542538c38e27c9f57 (diff) | |
download | reference_model-550ccc52de231621c0bf0c05ae2a398eec37ff51.tar.gz |
Replace serialization/ and verif/ with MLPlatform's serialization_lib submodule
- Remove Usage and Format
- Run black on verif/*.py scripts
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Ie81515891eb0039540f614894f4b6b0e0e78ba74
Diffstat (limited to 'reference_model/src')
-rw-r--r-- | reference_model/src/main.cpp | 2 | ||||
-rw-r--r-- | reference_model/src/ops/control_flow.cc | 5 | ||||
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 63 | ||||
-rw-r--r-- | reference_model/src/subgraph_traverser.cc | 223 | ||||
-rw-r--r-- | reference_model/src/tensor.cc | 17 | ||||
-rw-r--r-- | reference_model/src/tensor.h | 163 |
6 files changed, 150 insertions, 323 deletions
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index ec2fdc9..240d913 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -15,8 +15,6 @@ #include <stdio.h> -#include "flatbuffers/idl.h" -#include "flatbuffers/util.h" #include "model_common.h" #include "ops/op_factory.h" #include "subgraph_traverser.h" diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index 9d5db40..827e01f 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -292,9 +292,8 @@ int OpWhileLoop::checkTensorAttributes() int OpWhileLoop::eval() { - TosaReference::Tensor0<bool> cond_output_ctensor( - std::string("cond_output"), DType_BOOL, std::vector<Usage>({ Usage_ACTIVATION }), - std::vector<Format>({ Format_UNKNOWN }), std::vector<int32_t>({}), false); + TosaReference::Tensor0<bool> cond_output_ctensor(std::string("cond_output"), DType_BOOL, + std::vector<int32_t>({})); cond_output_ctensor.allocate(); std::vector<TosaReference::Tensor*> cond_block_outputs; diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index d6cd1cd..b8c7ade 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -103,12 +103,6 @@ int OpAvgPool2d<Dtype>::checkTensorAttributes() in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); - if (!in->hasFormat(Format_NHWC)) - { - printNodeValidationError("OpAvgPool2d: unsupported tensor format"); - return 1; - } - if (attribute->padding().size() != 4) { printNodeValidationError("OpAvgPool2d: illegal size for attribute padding"); @@ -321,28 +315,11 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes() printNodeValidationError("OpConv2d: bias tensor must be rank 1"); } - if (inputs[1]->getIsConst() == 0) - { - printNodeValidationError("OpConv2d: weight tensor is not const typed"); - } - input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]); bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]); output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); - if (!input->hasFormat(Format_NHWC)) - { - printNodeValidationError("OpConv2d: unsupported input tensor format"); - return 1; - } - - if (!weight->hasFormat(Format_OHWI)) - { - printNodeValidationError("OpConv2d: unsupported weight tensor format"); - return 1; - } - if (attribute->padding().size() != 4) { printNodeValidationError("OpConv2d: illegal size for attribute padding"); @@ -530,28 +507,11 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes() printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1"); } - if (inputs[1]->getIsConst() == 0) - { - printNodeValidationError("OpDepthwiseConv2d: weight tensor is not const typed"); - } - input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]); bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]); output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); - if (!input->hasFormat(Format_NHWC)) - { - printNodeValidationError("OpDepthwiseConv2d: unsupported input tensor format"); - return 1; - } - - if (!weight->hasFormat(Format_HWIM)) - { - printNodeValidationError("OpDepthwiseConv2d: unsupported weight tensor format"); - return 1; - } - if (attribute->padding().size() != 4) { printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute padding"); @@ -881,12 +841,6 @@ int OpMaxPool2d<Dtype>::checkTensorAttributes() in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); - if (!in->hasFormat(Format_NHWC)) - { - printNodeValidationError("OpMaxPool2d: unsupported tensor format"); - return 1; - } - if (attribute->padding().size() != 4) { printNodeValidationError("OpMaxPool2d: illegal size for attribute padding"); @@ -1021,28 +975,11 @@ int OpTransposeConv2d<InDtype, OutDtype>::checkTensorAttributes() return 1; } - if (inputs[1]->getIsConst() == 0) - { - printNodeValidationError("OpTransposeConv2d: weight tensor is not const typed"); - } - input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]); bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]); output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); - if (!input->hasFormat(Format_NHWC)) - { - printNodeValidationError("OpTransposeConv2d: unsupported input tensor format"); - return 1; - } - - if (!weight->hasFormat(Format_OHWI)) - { - printNodeValidationError("OpTransposeConv2d: unsupported weight tensor format"); - return 1; - } - if (attribute->outpad().size() != 2) { printNodeValidationError("OpTransposeConv2d: illegal size for attribute outpad"); diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 082f802..5096ffa 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -103,110 +103,118 @@ int SubgraphTraverser::initializeGraph() for (auto op : block->GetOperators()) { // translated TosaSerializationOperator to GraphNode - DType in_dtype = DType_UNKNOWN, out_dtype = DType_UNKNOWN, weight_dtype = DType_UNKNOWN; - uint32_t in_rank = 0, out_rank = 0, weight_rank = 0; - for (auto name : op->GetInputTensorNames()) - { - - TosaSerializationTensor* ts = block->GetTensorByName(name); - ASSERT_MSG(ts, "SubgraphTraverser: fail to get tensor %s from TosaSerializationHandler", name.c_str()); - - if (ts->HasUsage(Usage_WEIGHT)) - { - weight_dtype = ts->GetDtype(); - weight_rank = ts->GetShape().size(); - } - else if (ts->HasUsage(Usage_INDEX)) - { - // do nothing, but this will prevent tensor's dtype/rank being wrongly used as template argument when initializing this op - } - else if (ts->HasUsage(Usage_ACTIVATION)) - { - if (ts->GetShape().size() >= in_rank) - { - in_dtype = ts->GetDtype(); - in_rank = ts->GetShape().size(); - } - } - } - - // if dtype/rank still not initialized with above pass, we initialize without Usage check - if (in_dtype == DType_UNKNOWN && in_rank == 0) - { - for (auto name : op->GetInputTensorNames()) - { - TosaSerializationTensor* ts = block->GetTensorByName(name); - ASSERT_MSG(ts, "SubgraphTraverser: fail to get tensor %s from TosaSerializationHandler", name.c_str()); - - if (ts->GetShape().size() >= in_rank) - { - in_dtype = ts->GetDtype(); - in_rank = ts->GetShape().size(); - } - } - } - - for (auto name : op->GetOutputTensorNames()) - { - - TosaSerializationTensor* ts = block->GetTensorByName(name); - ASSERT_MSG(ts, "SubgraphTraverser: fail to get tensor %s from TosaSerializationHandler", name.c_str()); - - out_dtype = ts->GetDtype(); - out_rank = ts->GetShape().size(); - } + DType input_dtype = DType_UNKNOWN; + DType output_dtype = DType_UNKNOWN; + DType weight_dtype = DType_UNKNOWN; + uint32_t input_rank = 0; + uint32_t output_rank = 0; + uint32_t weight_rank = 0; + int32_t input_index = -1; + int32_t weight_index = -1; + + switch (op->GetOp()) + { + case Op_CONV2D: + case Op_DEPTHWISE_CONV2D: + case Op_TRANSPOSE_CONV2D: + case Op_FULLY_CONNECTED: + input_index = 0; + weight_index = 1; + break; + case Op_SELECT: + input_index = 1; + break; + default: + if (!op->GetInputTensorNames().empty()) + input_index = 0; + break; + } + + if (input_index != -1) + { + ASSERT_MSG((size_t)input_index < op->GetInputTensorNames().size(), + "Op=%s, input_index %d must be within [0, num_input - 1]", EnumNamesOp()[op->GetOp()], + input_index); + + std::string input_name = op->GetInputTensorNames()[input_index]; + TosaSerializationTensor* input_tensor = block->GetTensorByName(input_name); + ASSERT_MSG(input_tensor, "SubgraphTraverser: fail to get input tensor %s from TosaSerializationHandler", + input_name.c_str()); + input_dtype = input_tensor->GetDtype(); + input_rank = input_tensor->GetShape().size(); + } + + if (weight_index != -1) + { + ASSERT_MSG((size_t)weight_index < op->GetInputTensorNames().size(), + "Op=%s, weight_index %d must be within [0, num_input - 1]", EnumNamesOp()[op->GetOp()], + weight_index); + std::string weight_name = op->GetInputTensorNames()[weight_index]; + TosaSerializationTensor* weight_tensor = block->GetTensorByName(weight_name); + ASSERT_MSG(weight_tensor, "SubgraphTraverser: fail to get weight tensor %s from TosaSerializationHandler", + weight_name.c_str()); + weight_dtype = weight_tensor->GetDtype(); + weight_rank = weight_tensor->GetShape().size(); + } + + std::string output_name = op->GetOutputTensorNames()[0]; + TosaSerializationTensor* output_tensor = block->GetTensorByName(output_name); + ASSERT_MSG(output_tensor, "SubgraphTraverser: fail to get output tensor %s from TosaSerializationHandler", + output_name.c_str()); + output_dtype = output_tensor->GetDtype(); + output_rank = output_tensor->GetShape().size(); DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx, EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size()); - GraphNode* cn = OpFactory::newOp(tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, in_dtype, in_rank, - out_dtype, out_rank, weight_dtype, weight_rank); - if (!cn) + GraphNode* node = OpFactory::newOp(tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, input_dtype, + input_rank, output_dtype, output_rank, weight_dtype, weight_rank); + if (!node) { - if (weight_dtype == DType_UNKNOWN && weight_rank == 0) + if (weight_index == -1) { fprintf(g_func_debug.func_debug_file, "OpFactory could not allocate op %8s input=(%s rank %d) -> (%s rank %d)", - EnumNamesOp()[op->GetOp()], EnumNamesDType()[in_dtype], in_rank, EnumNamesDType()[out_dtype], - out_rank); + EnumNamesOp()[op->GetOp()], EnumNamesDType()[input_dtype], input_rank, + EnumNamesDType()[output_dtype], output_rank); } else { fprintf(g_func_debug.func_debug_file, "OpFactory could not allocate op %8s input=(%s rank %d), weight=(%s rank %d) -> (%s rank %d)", - EnumNamesOp()[op->GetOp()], EnumNamesDType()[in_dtype], in_rank, EnumNamesDType()[weight_dtype], - weight_rank, EnumNamesDType()[out_dtype], out_rank); + EnumNamesOp()[op->GetOp()], EnumNamesDType()[input_dtype], input_rank, + EnumNamesDType()[weight_dtype], weight_rank, EnumNamesDType()[output_dtype], output_rank); } - for (auto ts : op->GetInputTensors()) + for (auto& ts : op->GetInputTensorNames()) { - fprintf(g_func_debug.func_debug_file, "Input: %s\n", ts->GetName().c_str()); + fprintf(g_func_debug.func_debug_file, "Input: %s\n", ts.c_str()); } - for (auto ts : op->GetOutputTensors()) + for (auto& ts : op->GetOutputTensorNames()) { - fprintf(g_func_debug.func_debug_file, "Output: %s\n", ts->GetName().c_str()); + fprintf(g_func_debug.func_debug_file, "Output: %s\n", ts.c_str()); } FATAL_ERROR("Unsupported operation type or rank."); } - for (auto name : op->GetInputTensorNames()) + for (auto& name : op->GetInputTensorNames()) { - cn->addInputName(name); + node->addInputName(name); } for (auto name : op->GetOutputTensorNames()) { - cn->addOutputName(name); + node->addOutputName(name); } - addNode(cn); + addNode(node); // if node doesn't have any inputs (i.e. CONST) // it should be ready for evaluation - if (op->GetInputTensorNames().empty() && !cn->getOnNextNodeList()) + if (op->GetInputTensorNames().empty() && !node->getOnNextNodeList()) { - addToNextNodeList(cn); + addToNextNodeList(node); } idx++; @@ -215,47 +223,40 @@ int SubgraphTraverser::initializeGraph() for (auto ts : block->GetTensors()) { - bool is_const = false; - if (ts->HasUsage(Usage_WEIGHT)) - { - is_const = true; - } - DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str()); - TosaReference::Tensor* ct = - TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetUsage(), ts->GetFormat(), ts->GetShape(), - is_const, ts->GetShape().size()); + TosaReference::Tensor* tensor = + TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); - if (ts->GetNpyFilePtr()) + if (!ts->GetNpyFilePtr().empty()) { - if (ct->allocate()) + if (tensor->allocate()) { - FATAL_ERROR("Fail to allocate Eigen tensor %s", ct->getName().c_str()); + FATAL_ERROR("Fail to allocate Eigen tensor %s", tensor->getName().c_str()); } bzero(tensor_fullname, sizeof(tensor_fullname)); snprintf(tensor_fullname, sizeof(tensor_fullname), "%s/%s", g_func_config.subgraph_dir, - ts->GetNpyFilePtr()->c_str()); - if (ct->readFromNpyFile(tensor_fullname)) + ts->GetNpyFilePtr().c_str()); + if (tensor->readFromNpyFile(tensor_fullname)) { - FATAL_ERROR("Cannot read input data into graph tensor %s from block %s", ct->getName().c_str(), + FATAL_ERROR("Cannot read input data into graph tensor %s from block %s", tensor->getName().c_str(), block->GetName().c_str()); } } // update this->tensors - addTensor(ct); + addTensor(tensor); } DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str()); for (auto& input_name : block->GetInputs()) { - TosaReference::Tensor* ct = findTensorByName(input_name); + TosaReference::Tensor* tensor = findTensorByName(input_name); DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str()); - if (ct) + if (tensor) { - ct->setIsSubgraphInput(); - inputTensors.push_back(ct); + tensor->setIsSubgraphInput(); + inputTensors.push_back(tensor); } else { @@ -266,12 +267,12 @@ int SubgraphTraverser::initializeGraph() DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str()); for (auto& output_name : block->GetOutputs()) { - TosaReference::Tensor* ct = findTensorByName(output_name); + TosaReference::Tensor* tensor = findTensorByName(output_name); DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str()); - if (ct) + if (tensor) { - ct->setIsSubgraphOutput(); - outputTensors.push_back(ct); + tensor->setIsSubgraphOutput(); + outputTensors.push_back(tensor); } else { @@ -333,12 +334,12 @@ int SubgraphTraverser::evaluateNextNode() WARNING("Node %lu has been evaluated %d times. Loop suspected.", currNode->getID(), currNode->getEvalCount()); } - for (auto ct : currNode->getOutputs()) + for (auto tensor : currNode->getOutputs()) { - if (!ct->is_allocated()) - if (ct->allocate()) + if (!tensor->is_allocated()) + if (tensor->allocate()) { - FATAL_ERROR("Fail to allocate Eigen tensor %s", ct->getName().c_str()); + FATAL_ERROR("Fail to allocate Eigen tensor %s", tensor->getName().c_str()); } } @@ -348,26 +349,26 @@ int SubgraphTraverser::evaluateNextNode() } // free input tensor if all of its consumers have all of their outputs ready and it's not block's output - for (auto ct : currNode->getInputs()) + for (auto tensor : currNode->getInputs()) { bool in_use = false; - for (auto cn : ct->getConsumers()) + for (auto node : tensor->getConsumers()) { - if (!cn->hasAllOutputsReady()) + if (!node->hasAllOutputsReady()) { in_use = true; } } for (auto name : block->GetOutputs()) { - if (name == ct->getName()) + if (name == tensor->getName()) { in_use = true; } } if (!in_use) { - ct->deallocate(); + tensor->deallocate(); } } @@ -433,29 +434,29 @@ int SubgraphTraverser::clearAllNodeMarkings() return false; } -int SubgraphTraverser::addTensor(TosaReference::Tensor* ct) +int SubgraphTraverser::addTensor(TosaReference::Tensor* tensor) { // Enforce no duplicate tensors/tensor names // O(N), but the number of tensors is small for (TosaReference::Tensor* currTensor : tensors) { - if (ct == currTensor || currTensor->getName() == ct->getName()) + if (tensor == currTensor || currTensor->getName() == tensor->getName()) { - FATAL_ERROR("Error: Duplicate tensor or tensor name being added to graph: %s\n", ct->getName().c_str()); + FATAL_ERROR("Error: Duplicate tensor or tensor name being added to graph: %s\n", tensor->getName().c_str()); return 1; } } - tensors.push_back(ct); + tensors.push_back(tensor); - if (ct->getIsSubgraphInput()) + if (tensor->getIsSubgraphInput()) { - inputTensors.push_back(ct); + inputTensors.push_back(tensor); } - if (ct->getIsSubgraphOutput()) + if (tensor->getIsSubgraphOutput()) { - outputTensors.push_back(ct); + outputTensors.push_back(tensor); } return 0; diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index da81bcd..1efebe3 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -22,17 +22,11 @@ using namespace tosa; TosaReference::Tensor::Tensor(std::string tensorName_, DType tensorDtype_, - const std::vector<Usage>& tensorUsage_, - const std::vector<Format>& tensorFormat_, - std::vector<int> shape_, - int isConst_) + std::vector<int> shape_) { tensorName = std::string(tensorName_); tensorDtype = tensorDtype_; - tensorUsage = std::vector<Usage>(tensorUsage_); - tensorFormat = std::vector<Format>(tensorFormat_); shape = std::vector<int>(shape_); - isConst = isConst_; producer = nullptr; isValid = false; consumers.clear(); @@ -74,17 +68,16 @@ int TosaReference::Tensor::addConsumer(GraphNode* node) int TosaReference::Tensor::dumpTensorParams(FILE* out) const { - fprintf(out, "Name: %s DType=%s Usage=%s isValid=%d Rank=%d Shape=%s\n", tensorName.c_str(), - EnumNamesDType()[getDtype()], getUsageAsString().c_str(), getIsValid(), getRank(), - getShapeAsString().c_str()); + fprintf(out, "Name: %s DType=%s isValid=%d Rank=%d Shape=%s\n", tensorName.c_str(), EnumNamesDType()[getDtype()], + getIsValid(), getRank(), getShapeAsString().c_str()); return 0; } int TosaReference::Tensor::dumpTensorParams(std::ostream& out) const { - out << "Name: " << getName() << " DType=" << EnumNamesDType()[getDtype()] << " Usage=" << getUsageAsString() - << " isValid=" << getIsValid() << " Rank=" << getRank() << " Shape=" << getShapeAsString() << "\n"; + out << "Name: " << getName() << " DType=" << EnumNamesDType()[getDtype()] << " isValid=" << getIsValid() + << " Rank=" << getRank() << " Shape=" << getShapeAsString() << "\n"; return 0; } diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index 4f77cfc..d39cc7c 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -35,10 +35,7 @@ class Tensor public: Tensor(std::string tensorName_, DType tensorDtype__, - const std::vector<Usage>& tensorUsage_, - const std::vector<Format>& tensorFormat_, - std::vector<int> shape_, - int isConst_); + std::vector<int> shape_); virtual ~Tensor(); @@ -75,11 +72,6 @@ public: return isValid; } - int getIsConst() const - { - return isConst; - } - GraphNode* getProducer() { return producer; @@ -111,62 +103,6 @@ public: return shape_str; } - const std::vector<Usage>& getUsage() const - { - return tensorUsage; - } - - bool hasUsage(Usage usage) const - { - for (auto& usg : tensorUsage) - { - if (usg == usage) - { - return true; - } - } - return false; - } - - std::string getUsageAsString() const - { - std::string usage_str("["); - for (auto& usg : tensorUsage) - { - usage_str += (std::string(EnumNamesUsage()[usg]) + ", "); - } - usage_str.append("]"); - return usage_str; - } - - const std::vector<Format>& getFormat() const - { - return tensorFormat; - } - - bool hasFormat(Format format) const - { - for (auto& fmt : tensorFormat) - { - if (fmt == format) - { - return true; - } - } - return false; - } - - std::string getFormatAsString() const - { - std::string format_str("["); - for (auto& fmt : tensorFormat) - { - format_str += (std::string(EnumNamesFormat()[fmt]) + ", "); - } - format_str.append("]"); - return format_str; - } - const uint32_t getElementCount() const { uint32_t elements = 1; @@ -282,9 +218,6 @@ public: protected: std::string tensorName; DType tensorDtype; - std::vector<Usage> tensorUsage; - std::vector<Format> tensorFormat; - int isConst; int isValid; std::vector<int> shape; int isSubgraphInput; @@ -309,11 +242,8 @@ class TensorTemplate : public Tensor public: TensorTemplate(std::string tensorName_, DType tensorDtype_, - const std::vector<Usage>& tensorUsage_, - const std::vector<Format>& tensorFormat_, - std::vector<int> shape_, - int isConst_) - : Tensor(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, isConst_) + std::vector<int> shape_) + : Tensor(tensorName_, tensorDtype_, shape_) { tensor = nullptr; } @@ -678,10 +608,7 @@ class TensorFactory public: static Tensor* newTensor(std::string tensorName_, DType tensorDtype_, - const std::vector<Usage>& tensorUsage_, - const std::vector<Format>& tensorFormat_, std::vector<int> shape_, - int isConst_, const uint32_t rank) { switch (tensorDtype_) @@ -690,26 +617,19 @@ public: switch (rank) { case 0: - return new Tensor0<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor0<float>(tensorName_, tensorDtype_, shape_); case 1: - return new Tensor1<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor1<float>(tensorName_, tensorDtype_, shape_); case 2: - return new Tensor2<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor2<float>(tensorName_, tensorDtype_, shape_); case 3: - return new Tensor3<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor3<float>(tensorName_, tensorDtype_, shape_); case 4: - return new Tensor4<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor4<float>(tensorName_, tensorDtype_, shape_); case 5: - return new Tensor5<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor5<float>(tensorName_, tensorDtype_, shape_); case 6: - return new Tensor6<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor6<float>(tensorName_, tensorDtype_, shape_); default: goto done; } @@ -721,26 +641,19 @@ public: switch (rank) { case 0: - return new Tensor0<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor0<int32_t>(tensorName_, tensorDtype_, shape_); case 1: - return new Tensor1<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor1<int32_t>(tensorName_, tensorDtype_, shape_); case 2: - return new Tensor2<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor2<int32_t>(tensorName_, tensorDtype_, shape_); case 3: - return new Tensor3<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor3<int32_t>(tensorName_, tensorDtype_, shape_); case 4: - return new Tensor4<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor4<int32_t>(tensorName_, tensorDtype_, shape_); case 5: - return new Tensor5<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor5<int32_t>(tensorName_, tensorDtype_, shape_); case 6: - return new Tensor6<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor6<int32_t>(tensorName_, tensorDtype_, shape_); default: goto done; } @@ -748,26 +661,19 @@ public: switch (rank) { case 0: - return new Tensor0<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor0<int64_t>(tensorName_, tensorDtype_, shape_); case 1: - return new Tensor1<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor1<int64_t>(tensorName_, tensorDtype_, shape_); case 2: - return new Tensor2<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor2<int64_t>(tensorName_, tensorDtype_, shape_); case 3: - return new Tensor3<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor3<int64_t>(tensorName_, tensorDtype_, shape_); case 4: - return new Tensor4<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor4<int64_t>(tensorName_, tensorDtype_, shape_); case 5: - return new Tensor5<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor5<int64_t>(tensorName_, tensorDtype_, shape_); case 6: - return new Tensor6<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor6<int64_t>(tensorName_, tensorDtype_, shape_); default: goto done; } @@ -775,26 +681,19 @@ public: switch (rank) { case 0: - return new Tensor0<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor0<bool>(tensorName_, tensorDtype_, shape_); case 1: - return new Tensor1<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor1<bool>(tensorName_, tensorDtype_, shape_); case 2: - return new Tensor2<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor2<bool>(tensorName_, tensorDtype_, shape_); case 3: - return new Tensor3<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor3<bool>(tensorName_, tensorDtype_, shape_); case 4: - return new Tensor4<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor4<bool>(tensorName_, tensorDtype_, shape_); case 5: - return new Tensor5<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor5<bool>(tensorName_, tensorDtype_, shape_); case 6: - return new Tensor6<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, - isConst_); + return new Tensor6<bool>(tensorName_, tensorDtype_, shape_); default: goto done; } |