aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/armnnTfParser/TfParser.cpp6
1 files changed, 3 insertions, 3 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index d5372a598b..0d425257e8 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -1949,6 +1949,7 @@ ParsedTfOperationPtr TfParser::ParseResizeBilinear(const tensorflow::NodeDef& no
ResizeBilinearDescriptor desc;
desc.m_TargetHeight = static_cast<uint32_t> (sizeTensorData[0]);
desc.m_TargetWidth = static_cast<uint32_t> (sizeTensorData[1]);
+ desc.m_DataLayout = armnn::DataLayout::NHWC;
IConnectableLayer* layer = m_Network->AddResizeBilinearLayer(desc, nodeDef.name().c_str());
@@ -1960,13 +1961,12 @@ ParsedTfOperationPtr TfParser::ParseResizeBilinear(const tensorflow::NodeDef& no
unsigned int outChannels = inputTensorInfo.GetShape()[3];
unsigned int outHeight = desc.m_TargetHeight;
unsigned int outWidth = desc.m_TargetWidth;
- TensorShape outShape({outBatch, outChannels, outHeight, outWidth});
+ TensorShape outShape({outBatch, outHeight, outWidth, outChannels });
// The output DataType is always Float32, regardless of the input DataType.
const TensorInfo outputTensorInfo(outShape, armnn::DataType::Float32);
layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
- // TensorFlow ResizeBilinear input is always in BHWC format, so add swizzle and deswizzle layers.
- layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name());
+ inputSlot.Connect(layer->GetInputSlot(0));
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}