diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 38 |
1 files changed, 30 insertions, 8 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 90db57f953..2c5303c019 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1022,7 +1022,16 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c { const std::string descriptorName{"FullyConnectedQueueDescriptor"}; - ValidateNumInputs(workloadInfo, descriptorName, 1); + uint32_t numInputs = 1; + if (!m_Parameters.m_ConstantWeights) + { + numInputs = 2; + if (m_Parameters.m_BiasEnabled) + { + numInputs = 3; + } + } + ValidateNumInputs(workloadInfo, descriptorName, numInputs); ValidateNumOutputs(workloadInfo, descriptorName, 1); const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0]; @@ -1035,19 +1044,32 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions."); } - ValidatePointer(m_Weight, descriptorName, "weight"); - - const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo(); + TensorInfo weightTensorInfo; + if (m_Parameters.m_ConstantWeights) + { + ValidatePointer(m_Weight, descriptorName, "weight"); + weightTensorInfo = m_Weight->GetTensorInfo(); + } + else + { + weightTensorInfo = workloadInfo.m_InputTensorInfos[1]; + } ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight"); if (m_Parameters.m_BiasEnabled) { - ValidatePointer(m_Bias, descriptorName, "bias"); - + TensorInfo biasTensorInfo; + if (m_Parameters.m_ConstantWeights) + { + ValidatePointer(m_Bias, descriptorName, "bias"); + biasTensorInfo = m_Bias->GetTensorInfo(); + } + else + { + biasTensorInfo = workloadInfo.m_InputTensorInfos[2]; + } // Validates type and quantization values. - const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo(); ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName); - ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias"); ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias"); } |