diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 26 |
1 files changed, 4 insertions, 22 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 70d7641f41..a2dcd63726 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -3860,38 +3860,20 @@ void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& wor { throw InvalidArgumentException(descriptorName + ": Invalid number of inputs."); } - if (workloadInfo.m_OutputTensorInfos.size() != 1) + if (workloadInfo.m_OutputTensorInfos.size() != 3) { throw InvalidArgumentException(descriptorName + ": Invalid number of outputs."); } std::vector<DataType> supportedTypes = { - DataType::Float32 + DataType::Float32, + DataType::QAsymmS8 }; // check for supported type of one input and match them with all the other input and output ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName); - // type matches all other inputs - for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i) - { - ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0], - workloadInfo.m_InputTensorInfos[i], - descriptorName, - "input_0", - "input_" + std::to_string(i)); - } - // type matches all other outputs - for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i) - { - ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0], - workloadInfo.m_OutputTensorInfos[i], - "LstmQueueDescriptor", - "input_0", - "output_" + std::to_string(i)); - } - // Making sure clipping parameters have valid values. // == 0 means no clipping // > 0 means clipping @@ -3936,7 +3918,7 @@ void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& wor descriptorName + " input_2"); // outputTensor - ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 3, (timeStep * n_batch * n_output), + ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 3, (timeStep * n_batch * n_output), descriptorName + " output_0"); // check that dimensions of inputs/outputs and QueueDescriptor data match with each other |