diff options
Diffstat (limited to 'src/backends/tosaCommon/operatorMappings/Conv2dOperator.cpp')
-rw-r--r-- | src/backends/tosaCommon/operatorMappings/Conv2dOperator.cpp | 26 |
1 files changed, 18 insertions, 8 deletions
diff --git a/src/backends/tosaCommon/operatorMappings/Conv2dOperator.cpp b/src/backends/tosaCommon/operatorMappings/Conv2dOperator.cpp index 9c095d627f..dadd91b227 100644 --- a/src/backends/tosaCommon/operatorMappings/Conv2dOperator.cpp +++ b/src/backends/tosaCommon/operatorMappings/Conv2dOperator.cpp @@ -39,19 +39,23 @@ TosaSerializationBasicBlock* ConvertConv2dToTosaOperator(const Layer* layer, } // Get the layer connected to the output slot and determine unique layer name. - Layer& connectedLayer = layer->GetOutputSlot().GetConnection(0)->GetOwningLayer(); - - outputName = GenerateUniqueName(connectedLayer, 0); + outputName = GenerateUniqueOutputName(*layer, 0); } std::vector<TosaSerializationTensor*> tensors; std::vector<TosaSerializationOperator*> operators; // Setup input Tensor - std::vector<int32_t> inputShape0 = GetTosaTensorShape(inputs[0]->GetShape()); - DType inputDType0 = ArmNNToDType(inputs[0]->GetDataType()); + // Only add tensor if connected layer is an input layer. + // As intermediate or constant tensors will be created separately. + // There also can't be duplicate tensors. + if(inputNames[0].find("input0_") != std::string::npos) + { + std::vector<int32_t> inputShape0 = GetTosaTensorShape(inputs[0]->GetShape()); + DType inputDType0 = ArmNNToDType(inputs[0]->GetDataType()); - tensors.push_back(new TosaSerializationTensor(inputNames[0], inputShape0, inputDType0, {})); + tensors.push_back(new TosaSerializationTensor(inputNames[0], inputShape0, inputDType0, {})); + } // Only add input tensors if weights and bias are not constant or if running validation. // Constant tensors will be created in the ConvertConstantToTosaOperator function. @@ -80,12 +84,18 @@ TosaSerializationBasicBlock* ConvertConv2dToTosaOperator(const Layer* layer, operators.push_back(new TosaSerializationOperator(Op_CONST, Attribute_NONE, nullptr, {}, {constantName})); + // The size of the bias must match the channels dimension, so get the correct index. + unsigned int index = (conv2dDescriptor->m_DataLayout == DataLayout::NHWC) ? 3 : 1; + std::vector<uint8_t> uint8Data; - std::vector<float> data = { 0.0 }; + std::vector<float> data(outputs[0]->GetShape()[index], 0.0f); TosaSerializationHandler::ConvertF32toU8(data, uint8Data); - tensors.push_back(new TosaSerializationTensor(constantName, {1}, DType_FP32, uint8Data)); + tensors.push_back(new TosaSerializationTensor(constantName, + {static_cast<int32_t>(outputs[0]->GetShape()[index])}, + DType_FP32, + uint8Data)); inputNames.emplace_back(constantName); } |