aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp15
1 files changed, 9 insertions, 6 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 289f780fba..37fda3e210 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1246,7 +1246,13 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
{
const std::string descriptorName{"Convolution2dQueueDescriptor"};
- ValidateNumInputs(workloadInfo, descriptorName, 1);
+ uint32_t numInputs = 2;
+ if (m_Parameters.m_BiasEnabled)
+ {
+ numInputs = 3;
+ }
+
+ ValidateNumInputs(workloadInfo, descriptorName, numInputs);
ValidateNumOutputs(workloadInfo, descriptorName, 1);
const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
@@ -1255,9 +1261,8 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
- ValidatePointer(m_Weight, descriptorName, "weight");
+ const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
- const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
@@ -1265,9 +1270,7 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
Optional<TensorInfo> optionalBiasTensorInfo;
if (m_Parameters.m_BiasEnabled)
{
- ValidatePointer(m_Bias, descriptorName, "bias");
-
- optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
+ optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");