about summary refs log tree commit diff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- src/backends/backendsCommon/WorkloadData.cpp 109
1 files changed, 106 insertions, 3 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index e1a369af7c..201cc7d1ec 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -338,6 +338,102 @@ void ValidateTensorNumElementsMatch(const TensorInfo& first,
}
}
+// Checks the weight tensor's data type against the input tensor's data type.
+//
+// inputInfo  - tensor info of the layer input; its data type selects the rule applied.
+// weightInfo - tensor info of the weights being checked.
+// descName   - descriptor name forwarded to the underlying validation helpers.
+//
+// For a QuantisedAsymm8 input the weights are checked (via ValidateDataTypes) against the list
+// { QuantisedAsymm8, QuantizedSymm8PerAxis }. For every other input type the check is delegated to
+// ValidateTensorDataTypesMatch for the "input"/"weight" pair.
+void ValidateWeightDataType(const TensorInfo& inputInfo,
+                            const TensorInfo& weightInfo,
+                            const std::string& descName)
+{
+    // Non-quantised-asymmetric inputs: hand over to the input/weight type-match check.
+    if (inputInfo.GetDataType() != DataType::QuantisedAsymm8)
+    {
+        ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
+        return;
+    }
+
+    // QuantisedAsymm8 input: weights may use either of the two 8-bit quantised types below.
+    const std::vector<DataType> permittedWeightTypes =
+    {
+        DataType::QuantisedAsymm8,
+        DataType::QuantizedSymm8PerAxis
+    };
+
+    ValidateDataTypes(weightInfo, permittedWeightTypes, descName);
+}
+
+
+// Verifies that a tensor has a quantization dimension set and that it equals 0.
+//
+// tensorInfo - tensor info whose quantization dimension is inspected.
+// descName   - descriptor name, used as the prefix of the error message.
+// tensorName - tensor label ("weight", "bias", ...) inserted into the error message.
+//
+// Throws InvalidArgumentException when the quantization dimension is unset or is not 0.
+void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
+                                          const std::string& descName,
+                                          const std::string& tensorName)
+{
+    const Optional<unsigned int>& axis = tensorInfo.GetQuantizationDim();
+
+    // First failure mode: no quantization dimension recorded at all.
+    if (!axis.has_value())
+    {
+        const std::string missingDimMessage = boost::str(
+            boost::format("%1%: Quantization dimension for per-axis quantization not set on tensor %2%.")
+            % descName % tensorName);
+        throw InvalidArgumentException(missingDimMessage);
+    }
+
+    // Second failure mode: a dimension is recorded but it is not 0.
+    const unsigned int axisValue = axis.value();
+    if (axisValue != 0)
+    {
+        const std::string wrongDimMessage = boost::str(
+            boost::format("%1%: Quantization dimension for per-axis quantization expected to be 0 on tensor %2%, "
+                          "but got: %3%") % descName % tensorName % axisValue);
+        throw InvalidArgumentException(wrongDimMessage);
+    }
+}
+
+
+// Verifies that a tensor's quantization offset is 0.
+//
+// tensorInfo - tensor info whose quantization offset is inspected.
+// descName   - descriptor name, used as the prefix of the error message.
+// tensorName - tensor label inserted into the error message.
+//
+// Throws InvalidArgumentException when the offset is non-zero.
+void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
+                                       const std::string& descName,
+                                       const std::string& tensorName)
+{
+    const int32_t offset = tensorInfo.GetQuantizationOffset();
+    if (offset == 0)
+    {
+        // Offset is as required; nothing to report.
+        return;
+    }
+
+    const std::string message = boost::str(
+        boost::format("%1%: Quantization offset for per-axis quantization expected to be 0 on tensor %2%, "
+                      "but got: %3%") % descName % tensorName % offset);
+    throw InvalidArgumentException(message);
+}
+
+
+// Validates per-axis quantization settings for a layer's weight tensor and, if present, its bias tensor.
+//
+// inputInfo        - tensor info of the layer input.
+// outputInfo       - tensor info of the layer output.
+// weightInfo       - tensor info of the weights; nothing is checked unless it reports per-axis quantization.
+// optionalBiasInfo - bias tensor info, or an empty Optional when there is no bias to check.
+// descName         - descriptor name used in error messages and forwarded to the helper checks.
+//
+// Throws InvalidArgumentException directly when the input/output data types do not allow per-axis
+// quantization, or when a bias is supplied without per-axis quantization parameters. The helper
+// checks invoked below report their own failures.
+void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
+                                 const TensorInfo& outputInfo,
+                                 const TensorInfo& weightInfo,
+                                 const Optional<TensorInfo>& optionalBiasInfo,
+                                 const std::string& descName)
+{
+    // Weights without per-axis quantization need no further checks here.
+    if (!weightInfo.HasPerAxisQuantization())
+    {
+        return;
+    }
+
+    // Per-axis weights are only accepted when input and output are both QuantisedAsymm8.
+    const DataType inType  = inputInfo.GetDataType();
+    const DataType outType = outputInfo.GetDataType();
+    const bool perAxisNotAllowed = (inType != DataType::QuantisedAsymm8) || (inType != outType);
+    if (perAxisNotAllowed)
+    {
+        throw InvalidArgumentException(boost::str(
+            boost::format("%1%: Per-axis quantization parameters set on tensor %2%, "
+                          "but data type does not support per-axis quantization.") % descName % "weight"));
+    }
+
+    // Weight tensor: data type, quantization dimension and quantization offset.
+    ValidateTensorDataType(weightInfo, DataType::QuantizedSymm8PerAxis, descName, "weight");
+    ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
+    ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");
+
+    // No bias supplied: validation is complete.
+    if (!optionalBiasInfo.has_value())
+    {
+        return;
+    }
+
+    // A supplied bias must carry per-axis quantization parameters as well.
+    const TensorInfo& bias = optionalBiasInfo.value();
+    if (!bias.HasPerAxisQuantization())
+    {
+        throw InvalidArgumentException(boost::str(
+            boost::format("%1%: Per-axis quantization parameters not set on bias tensor, despite being set on "
+                          "weight tensor.") % descName));
+    }
+
+    // Bias tensor: data type, quantization dimension and quantization offset.
+    ValidateTensorDataType(bias, DataType::Signed32, descName, "bias");
+    ValidatePerAxisQuantizationDimension(bias, descName, "bias");
+    ValidatePerAxisQuantizationOffset(bias, descName, "bias");
+}
+
+
} // anonymous namespace
void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
@@ -1040,19 +1136,26 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
- ValidateTensorDataTypesMatch(inputTensorInfo, weightTensorInfo, descriptorName, "input", "weight");
+ ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
+ Optional<TensorInfo> optionalBiasTensorInfo;
if (m_Parameters.m_BiasEnabled)
{
ValidatePointer(m_Bias, descriptorName, "bias");
- const TensorInfo& biasTensorInfo = m_Bias->GetTensorInfo();
- ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
+ optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
+ const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
}
+ ValidatePerAxisQuantization(inputTensorInfo,
+ outputTensorInfo,
+ weightTensorInfo,
+ optionalBiasTensorInfo,
+ descriptorName);
+
std::vector<DataType> supportedTypes =
{
DataType::Float32,