aboutsummaryrefslogtreecommitdiff
path: root/src/armnn
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn')
-rw-r--r--src/armnn/SerializeLayerParameters.cpp4
-rw-r--r--src/armnn/layers/ResizeLayer.cpp5
2 files changed, 9 insertions, 0 deletions
diff --git a/src/armnn/SerializeLayerParameters.cpp b/src/armnn/SerializeLayerParameters.cpp
index 76b92f3f9d..e4bf094b7c 100644
--- a/src/armnn/SerializeLayerParameters.cpp
+++ b/src/armnn/SerializeLayerParameters.cpp
@@ -293,6 +293,8 @@ void StringifyLayerParameters<ResizeBilinearDescriptor>::Serialize(ParameterStri
fn("TargetWidth", std::to_string(desc.m_TargetWidth));
fn("TargetHeight", std::to_string(desc.m_TargetHeight));
fn("DataLayout", GetDataLayoutName(desc.m_DataLayout));
+ fn("AlignCorners", std::to_string(desc.m_AlignCorners));
+ fn("HalfPixelCenters", std::to_string(desc.m_HalfPixelCenters));
}
void StringifyLayerParameters<ResizeDescriptor>::Serialize(ParameterStringifyFunction& fn,
@@ -302,6 +304,8 @@ void StringifyLayerParameters<ResizeDescriptor>::Serialize(ParameterStringifyFun
fn("TargetHeight", std::to_string(desc.m_TargetHeight));
fn("ResizeMethod", GetResizeMethodAsCString(desc.m_Method));
fn("DataLayout", GetDataLayoutName(desc.m_DataLayout));
+ fn("AlignCorners", std::to_string(desc.m_AlignCorners));
+ fn("HalfPixelCenters", std::to_string(desc.m_HalfPixelCenters));
}
void StringifyLayerParameters<SpaceToBatchNdDescriptor>::Serialize(ParameterStringifyFunction& fn,
diff --git a/src/armnn/layers/ResizeLayer.cpp b/src/armnn/layers/ResizeLayer.cpp
index 9654e58b43..b16adeb860 100644
--- a/src/armnn/layers/ResizeLayer.cpp
+++ b/src/armnn/layers/ResizeLayer.cpp
@@ -50,6 +50,11 @@ std::vector<TensorShape> ResizeLayer::InferOutputShapes(const std::vector<Tensor
TensorShape( { outBatch, outHeight, outWidth, outChannels } ) :
TensorShape( { outBatch, outChannels, outHeight, outWidth });
+ if (m_Param.m_HalfPixelCenters && m_Param.m_AlignCorners)
+ {
+ throw LayerValidationException("ResizeLayer: AlignCorners cannot be true when HalfPixelCenters is true");
+ }
+
return std::vector<TensorShape>({ tensorShape });
}