From f0a6dec75832604d5ab18242dc216852821a8279 Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Thu, 25 Mar 2021 07:46:55 +0000 Subject: IVGCVSW-5736 and IVGCVSW-5743 'NonConstWeights: Update front-end and TfLiteDelegate support for FullyConnected Operator' * Added front-end support for non-const weights for FULLY_CONNECTED operator * Added FULLY_CONNECTED end-to-end test * Updated FULLY_CONNECTED operator support in TfLite Arm NN Delegate for non-const weights * Updated the version numbers Signed-off-by: Sadik Armagan Change-Id: Iffa5b9aa9297aca4c02d923cce4636c88ac21faa --- src/backends/backendsCommon/WorkloadData.cpp | 38 ++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 8 deletions(-) (limited to 'src/backends/backendsCommon/WorkloadData.cpp') diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 90db57f953..2c5303c019 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1022,7 +1022,16 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c { const std::string descriptorName{"FullyConnectedQueueDescriptor"}; - ValidateNumInputs(workloadInfo, descriptorName, 1); + uint32_t numInputs = 1; + if (!m_Parameters.m_ConstantWeights) + { + numInputs = 2; + if (m_Parameters.m_BiasEnabled) + { + numInputs = 3; + } + } + ValidateNumInputs(workloadInfo, descriptorName, numInputs); ValidateNumOutputs(workloadInfo, descriptorName, 1); const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0]; @@ -1035,19 +1044,32 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions."); } - ValidatePointer(m_Weight, descriptorName, "weight"); - - const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo(); + TensorInfo weightTensorInfo; + if (m_Parameters.m_ConstantWeights) + { + ValidatePointer(m_Weight, descriptorName, "weight"); + weightTensorInfo = m_Weight->GetTensorInfo(); + } + else + { + weightTensorInfo = workloadInfo.m_InputTensorInfos[1]; + } ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight"); if (m_Parameters.m_BiasEnabled) { - ValidatePointer(m_Bias, descriptorName, "bias"); - + TensorInfo biasTensorInfo; + if (m_Parameters.m_ConstantWeights) + { + ValidatePointer(m_Bias, descriptorName, "bias"); + biasTensorInfo = m_Bias->GetTensorInfo(); + } + else + { + biasTensorInfo = workloadInfo.m_InputTensorInfos[2]; + } // Validates type and quantization values. - const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo(); ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName); - ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias"); ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias"); } -- cgit v1.2.1