aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
authorMatthew Sloyan <matthew.sloyan@arm.com>2021-07-13 19:46:11 +0100
committerMatthew Sloyan <matthew.sloyan@arm.com>2021-08-06 09:25:26 +0000
commit81beae3a870004795275e9266bc43d845b9f78db (patch)
tree70af86f3c36c8e330c72770e6f1419ca7b2a4bb8 /src/backends/backendsCommon/WorkloadData.cpp
parent95e9efc28ce70a8cda93e722f5ce90ebc96bdd95 (diff)
downloadarmnn-81beae3a870004795275e9266bc43d845b9f78db.tar.gz
IVGCVSW-6119 ConstTensorsAsInput: FullyConnected
* Constant weights and biases are now stored as Constant layers. * Updated Serializer, Deserializer and unit tests to reflect this. * Updated TfLiteDelegate, TfLiteParser and OnnxParser. * Updated Schema with IsConstant and ConstantTensorsAsInputs. * Updated Ref backend to handle constant weights and bias as inputs rather than reading from member variables. * Added dynamic or constant input EndToEnd tests. !android-nn-driver:5959 Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com> Change-Id: Ibf3cf437df1100e4b322b0d303c575c6339f9696
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp41
1 files changed, 9 insertions, 32 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 3fe0823b03..319cdb106b 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1041,15 +1041,12 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
{
const std::string descriptorName{"FullyConnectedQueueDescriptor"};
- uint32_t numInputs = 1;
- if (!m_Parameters.m_ConstantWeights)
+ uint32_t numInputs = 2;
+ if (m_Parameters.m_BiasEnabled)
{
- numInputs = 2;
- if (m_Parameters.m_BiasEnabled)
- {
- numInputs = 3;
- }
+ numInputs = 3;
}
+
ValidateNumInputs(workloadInfo, descriptorName, numInputs);
ValidateNumOutputs(workloadInfo, descriptorName, 1);
@@ -1063,30 +1060,12 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
}
- TensorInfo weightTensorInfo;
- if (m_Parameters.m_ConstantWeights)
- {
- ValidatePointer(m_Weight, descriptorName, "weight");
- weightTensorInfo = m_Weight->GetTensorInfo();
- }
- else
- {
- weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
- }
+ TensorInfo weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
if (m_Parameters.m_BiasEnabled)
{
- TensorInfo biasTensorInfo;
- if (m_Parameters.m_ConstantWeights)
- {
- ValidatePointer(m_Bias, descriptorName, "bias");
- biasTensorInfo = m_Bias->GetTensorInfo();
- }
- else
- {
- biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
- }
+ TensorInfo biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
// Validates type and quantization values.
ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
@@ -1894,11 +1873,9 @@ void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
};
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
-
- if (inputTensorInfo != outputTensorInfo)
- {
- throw InvalidArgumentException(descriptorName + ": Input and output tensor infos do not match.");
- }
+ ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+ ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+ ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
}
void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const