diff options
Diffstat (limited to 'src/armnnTfParser')
-rw-r--r-- | src/armnnTfParser/TfParser.cpp | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index d5372a598b..0d425257e8 100644 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -1949,6 +1949,7 @@ ParsedTfOperationPtr TfParser::ParseResizeBilinear(const tensorflow::NodeDef& no ResizeBilinearDescriptor desc; desc.m_TargetHeight = static_cast<uint32_t> (sizeTensorData[0]); desc.m_TargetWidth = static_cast<uint32_t> (sizeTensorData[1]); + desc.m_DataLayout = armnn::DataLayout::NHWC; IConnectableLayer* layer = m_Network->AddResizeBilinearLayer(desc, nodeDef.name().c_str()); @@ -1960,13 +1961,12 @@ ParsedTfOperationPtr TfParser::ParseResizeBilinear(const tensorflow::NodeDef& no unsigned int outChannels = inputTensorInfo.GetShape()[3]; unsigned int outHeight = desc.m_TargetHeight; unsigned int outWidth = desc.m_TargetWidth; - TensorShape outShape({outBatch, outChannels, outHeight, outWidth}); + TensorShape outShape({outBatch, outHeight, outWidth, outChannels }); // The output DataType is always Float32, regardless of the input DataType. const TensorInfo outputTensorInfo(outShape, armnn::DataType::Float32); layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); - // TensorFlow ResizeBilinear input is always in BHWC format, so add swizzle and deswizzle layers. - layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name()); + inputSlot.Connect(layer->GetInputSlot(0)); return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer); } |