diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 39 |
1 files changed, 34 insertions, 5 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index a1d00c6945..1505078b77 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -271,6 +271,20 @@ void ValidateDataTypes(const TensorInfo& info, } } +//--------------------------------------------------------------- +void ValidateTensorDataTypesMatch(const TensorInfo& first, + const TensorInfo& second, + std::string const& descName, + std::string const& firstName, + std::string const& secondName) +{ + if (first.GetDataType() != second.GetDataType()) + { + throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName + + " must have identical data types."); + } +} + } //namespace void QueueDescriptor::ValidateInputsOutputs(const std::string& descName, @@ -1275,25 +1289,40 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateNumInputs(workloadInfo, "MeanQueueDescriptor", 1); - ValidateNumOutputs(workloadInfo, "MeanQueueDescriptor", 1); + const std::string meanQueueDescString = "MeanQueueDescriptor"; + + ValidateNumInputs(workloadInfo, meanQueueDescString, 1); + ValidateNumOutputs(workloadInfo, meanQueueDescString, 1); + + std::vector<DataType> supportedTypes = + { + DataType::Float32, + DataType::Float16, + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 + }; const TensorInfo& input = workloadInfo.m_InputTensorInfos[0]; const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0]; + // First check if input tensor data type is supported, then + // check if this data type matches the output tensor data type + ValidateDataTypes(input, supportedTypes, meanQueueDescString); + ValidateTensorDataTypesMatch(input, output, meanQueueDescString, "input", "output"); + if (m_Parameters.m_KeepDims) { - ValidateTensorNumDimensions(output, "MeanQueueDescriptor", input.GetNumDimensions(), "output"); + ValidateTensorNumDimensions(output, meanQueueDescString, input.GetNumDimensions(), "output"); } else if (m_Parameters.m_Axis.empty()) { - ValidateTensorNumDimensions(output, "MeanQueueDescriptor", 1, "output"); + ValidateTensorNumDimensions(output, meanQueueDescString, 1, "output"); } else { auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size()); ValidateTensorNumDimensions(output, - "MeanQueueDescriptor", + meanQueueDescString, outputDim > 0 ? outputDim : 1, "output"); } |