Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r--  ConversionUtils.hpp  98
1 file changed, 77 insertions, 21 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 3432d9f8..e5f99ed4 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -3034,26 +3034,72 @@ bool ConvertFullyConnected(const HalOperation& operation, const HalModel& model,
const armnn::TensorInfo& inputInfo = input.GetTensorInfo();
const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
- ConstTensorPin weightsPin = DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 1);
- ConstTensorPin biasPin = ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 2, model, data); // 1D
+ LayerInputHandle weightsInput = LayerInputHandle();
+ const HalOperand* weightsOperand = GetInputOperand<HalPolicy>(operation, 1, model);
+ if (!weightsOperand)
+ {
+ return Fail("%s: Could not read weights", __func__);
+ }
+
+ const armnn::TensorInfo& weightsInfo = GetTensorInfoForOperand(*weightsOperand);
+ bool constantWeights = IsOperandConstant<HalPolicy>(*weightsOperand);
- if (!weightsPin.IsValid())
+ armnn::Optional<armnn::ConstTensor> optionalWeights = armnn::EmptyOptional();
+ if (!constantWeights)
{
- return Fail("%s: Operation has invalid weights", __func__);
+ weightsInput = ConvertToLayerInputHandle<HalPolicy>(operation, 1, model, data);
+ if (!weightsInput.IsValid())
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+ }
+ else
+ {
+ ConstTensorPin weightsPin = DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 1);
+ if (!weightsPin.IsValid())
+ {
+ return Fail("%s: Operation has invalid weights", __func__);
+ }
+ optionalWeights = armnn::Optional<armnn::ConstTensor>(weightsPin.GetConstTensor());
}
- if (!biasPin.IsValid())
+ LayerInputHandle biasInput = LayerInputHandle();
+ const HalOperand* biasOperand = GetInputOperand<HalPolicy>(operation, 2, model);
+ if (!biasOperand)
{
- return Fail("%s: Operation has invalid bias", __func__);
+ return Fail("%s: Could not read bias", __func__);
}
+ armnn::TensorInfo biasInfo = GetTensorInfoForOperand(*biasOperand);
+ bool constantBias = IsOperandConstant<HalPolicy>(*biasOperand);
- armnn::ConstTensor weights = weightsPin.GetConstTensor();
- armnn::ConstTensor bias = biasPin.GetConstTensor();
- armnn::TensorInfo reshapedInfo = inputInfo;
+ armnn::Optional<armnn::ConstTensor> optionalBias = armnn::EmptyOptional();
+ if (!constantBias)
+ {
+ biasInput = ConvertToLayerInputHandle<HalPolicy>(operation, 2, model, data);
+ if (!biasInput.IsValid())
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+ }
+ else
+ {
+ ConstTensorPin biasPin = ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 2, model, data); // 1D
+ if (!biasPin.IsValid())
+ {
+ return Fail("%s: Operation has invalid bias", __func__);
+ }
+ optionalBias = armnn::Optional<armnn::ConstTensor>(biasPin.GetConstTensor());
+ }
+ if ((constantWeights && !constantBias) || (!constantWeights && constantBias))
+ {
+ return Fail("%s: Non-compatible weights and bias", __func__);
+ }
+
+ armnn::TensorInfo reshapedInfo = inputInfo;
try
{
- reshapedInfo.SetShape(FlattenFullyConnectedInput(inputInfo.GetShape(), weights.GetInfo().GetShape()));
+ reshapedInfo.SetShape(FlattenFullyConnectedInput(inputInfo.GetShape(), weightsInfo.GetShape()));
}
catch (const std::exception& e)
{
@@ -3061,7 +3107,7 @@ bool ConvertFullyConnected(const HalOperation& operation, const HalModel& model,
}
// ensuring that the bias value is within 1% of the weights input (small float differences can exist)
- SanitizeBiasQuantizationScale(bias.GetInfo(), weights.GetInfo(), reshapedInfo);
+ SanitizeBiasQuantizationScale(biasInfo, weightsInfo, reshapedInfo);
ActivationFn activationFunction;
if (!GetInputActivationFunction<HalPolicy>(operation, 3, activationFunction, model, data))
@@ -3072,12 +3118,13 @@ bool ConvertFullyConnected(const HalOperation& operation, const HalModel& model,
armnn::FullyConnectedDescriptor desc;
desc.m_TransposeWeightMatrix = true;
desc.m_BiasEnabled = true;
+ desc.m_ConstantWeights = constantWeights;
bool isSupported = false;
auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
{
if (!VerifyFullyConnectedShapes(reshapedInfo.GetShape(),
- weights.GetInfo().GetShape(),
+ weightsInfo.GetShape(),
outputInfo.GetShape(),
desc.m_TransposeWeightMatrix))
{
@@ -3087,14 +3134,14 @@ bool ConvertFullyConnected(const HalOperation& operation, const HalModel& model,
}
FORWARD_LAYER_SUPPORT_FUNC(__func__,
- IsFullyConnectedSupported,
- data.m_Backends,
- isSupported,
- reshapedInfo,
- outputInfo,
- weights.GetInfo(),
- bias.GetInfo(),
- desc);
+ IsFullyConnectedSupported,
+ data.m_Backends,
+ isSupported,
+ reshapedInfo,
+ outputInfo,
+ weightsInfo,
+ biasInfo,
+ desc);
};
if(!IsDynamicTensor(outputInfo))
@@ -3112,7 +3159,9 @@ bool ConvertFullyConnected(const HalOperation& operation, const HalModel& model,
}
armnn::IConnectableLayer* startLayer =
- data.m_Network->AddFullyConnectedLayer(desc, weights, armnn::Optional<armnn::ConstTensor>(bias));
+ data.m_Network->AddFullyConnectedLayer(desc,
+ optionalWeights,
+ optionalBias);
if (inputInfo.GetNumDimensions() > 2U)
{
@@ -3130,6 +3179,13 @@ bool ConvertFullyConnected(const HalOperation& operation, const HalModel& model,
input.Connect(startLayer->GetInputSlot(0));
}
+ // connect non-constant weights and bias inputs
+ if (!desc.m_ConstantWeights)
+ {
+ weightsInput.Connect(startLayer->GetInputSlot(1));
+ biasInput.Connect(startLayer->GetInputSlot(2));
+ }
+
return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *startLayer, model,
data, nullptr, validateFunc, activationFunction);
}
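
For context, here is a minimal sketch (not part of the diff) of what the non-constant-weights path above builds at the Arm NN API level, assuming the AddFullyConnectedLayer overload that takes Optional<ConstTensor> weights and bias; the layer names and tensor shapes are illustrative only:

    #include <armnn/ArmNN.hpp>

    armnn::IConnectableLayer* AddRuntimeWeightFullyConnected(armnn::INetwork& network)
    {
        armnn::FullyConnectedDescriptor desc;
        desc.m_TransposeWeightMatrix = true;
        desc.m_BiasEnabled           = true;
        desc.m_ConstantWeights       = false; // weights arrive as a network input

        // With non-constant weights, no ConstTensor is passed to the layer...
        armnn::IConnectableLayer* fc =
            network.AddFullyConnectedLayer(desc,
                                           armnn::EmptyOptional(),
                                           armnn::EmptyOptional(),
                                           "fc");

        // ...and data, weights and bias are instead wired to input slots
        // 0, 1 and 2, exactly as the converter does when
        // !desc.m_ConstantWeights.
        armnn::IConnectableLayer* data    = network.AddInputLayer(0, "data");
        armnn::IConnectableLayer* weights = network.AddInputLayer(1, "weights");
        armnn::IConnectableLayer* bias    = network.AddInputLayer(2, "bias");

        data->GetOutputSlot(0).SetTensorInfo(
            armnn::TensorInfo({1, 8}, armnn::DataType::Float32));
        weights->GetOutputSlot(0).SetTensorInfo(
            armnn::TensorInfo({4, 8}, armnn::DataType::Float32)); // [outputs, inputs]
        bias->GetOutputSlot(0).SetTensorInfo(
            armnn::TensorInfo({4}, armnn::DataType::Float32));

        data->GetOutputSlot(0).Connect(fc->GetInputSlot(0));
        weights->GetOutputSlot(0).Connect(fc->GetInputSlot(1));
        bias->GetOutputSlot(0).Connect(fc->GetInputSlot(2));

        return fc;
    }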