aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2022-01-28 16:18:54 +0000
committermike.kelly <mike.kelly@arm.com>2022-04-22 15:46:50 +0000
commit5880b911bf4b7fd8308c93e299d77ac78f282c19 (patch)
treeb256346d6fc78e78735cc50ec822286f809dd37f
parent4dae5794644b44be8c93bc6db553a205551bc077 (diff)
downloadarmnn-5880b911bf4b7fd8308c93e299d77ac78f282c19.tar.gz
MLCE-604 Add Unidirectional Sequence Lstm support to TFLite
* Added Unidirectional Sequence Lstm support to TFLite Parser * Added support for float operations with int8 weights to TFLite Parser * Added to Conv2d, Conv3D, DepthwiseConv2D, FullyConnected, TransposeConv and UnidirectionalSequenceLstm * Renamed subgraphIndex to subgraph to fix name-shadowing warning. Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I818976ab88abc05dcb4bad246fb4108e6e879283
-rw-r--r--include/armnn/Descriptors.hpp38
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp476
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.hpp15
-rw-r--r--src/armnnTfLiteParser/test/Conv2D.cpp72
-rw-r--r--src/armnnTfLiteParser/test/FullyConnected.cpp36
5 files changed, 580 insertions, 57 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 280c18e78c..4c2242e1ad 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -1086,17 +1086,29 @@ struct LstmDescriptor : BaseDescriptor
, m_ProjectionEnabled(false)
, m_LayerNormEnabled(false)
, m_TimeMajor(false)
+ , m_InputIntermediateScale(0.0)
+ , m_ForgetIntermediateScale(0.0)
+ , m_CellIntermediateScale(0.0)
+ , m_OutputIntermediateScale(0.0)
+ , m_HiddenStateZeroPoint(0)
+ , m_HiddenStateScale(0.0)
{}
bool operator ==(const LstmDescriptor& rhs) const
{
    // Two descriptors are equal when every configuration field matches,
    // including the intermediate/hidden-state quantization parameters added
    // for UnidirectionalSequenceLstm support.
    const bool coreEqual = m_ActivationFunc    == rhs.m_ActivationFunc    &&
                           m_ClippingThresCell == rhs.m_ClippingThresCell &&
                           m_ClippingThresProj == rhs.m_ClippingThresProj &&
                           m_CifgEnabled       == rhs.m_CifgEnabled       &&
                           m_PeepholeEnabled   == rhs.m_PeepholeEnabled   &&
                           m_LayerNormEnabled  == rhs.m_LayerNormEnabled  &&
                           m_TimeMajor         == rhs.m_TimeMajor;

    const bool quantEqual = m_InputIntermediateScale  == rhs.m_InputIntermediateScale  &&
                            m_ForgetIntermediateScale == rhs.m_ForgetIntermediateScale &&
                            m_CellIntermediateScale   == rhs.m_CellIntermediateScale   &&
                            m_OutputIntermediateScale == rhs.m_OutputIntermediateScale &&
                            m_HiddenStateZeroPoint    == rhs.m_HiddenStateZeroPoint    &&
                            m_HiddenStateScale        == rhs.m_HiddenStateScale;

    return coreEqual && quantEqual;
}
/// @brief The activation function to use.
@@ -1116,6 +1128,18 @@ struct LstmDescriptor : BaseDescriptor
bool m_LayerNormEnabled;
/// Enable/disable time major
bool m_TimeMajor;
+ /// Input intermediate quantization scale
+ float m_InputIntermediateScale;
+ /// Forget intermediate quantization scale
+ float m_ForgetIntermediateScale;
+ /// Cell intermediate quantization scale
+ float m_CellIntermediateScale;
+ /// Output intermediate quantization scale
+ float m_OutputIntermediateScale;
+ /// Hidden State zero point
+ int32_t m_HiddenStateZeroPoint;
+ /// Hidden State quantization scale
+ float m_HiddenStateScale;
};
using UnidirectionalSequenceLstmDescriptor = LstmDescriptor;
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index fddd93adc8..44dcacc3db 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -6,6 +6,7 @@
#include "TfLiteParser.hpp"
#include "armnnTfLiteParser/Version.hpp"
+#include "armnn/LstmParams.hpp"
#include <armnn/BackendOptions.hpp>
#include <armnn/Descriptors.hpp>
@@ -33,11 +34,9 @@
#include <fmt/format.h>
#include <algorithm>
-#include <fstream>
#include <iostream>
#include <limits>
#include <numeric>
-#include <sstream>
#define ARMNN_THROW_PARSE_EXCEPTION(msg) \
{ \
@@ -375,6 +374,11 @@ std::vector<unsigned int> AsUnsignedVector(const std::vector<int32_t>& in)
return result;
}
/// Returns true when a TFLite operand index refers to an actual tensor.
/// TFLite marks omitted optional operands with a negative tensor index.
bool IsOptionalOperandPresent(int input)
{
    return input >= 0;
}
+
void CalcPadding(uint32_t inputSize,
uint32_t filterSize,
uint32_t stride,
@@ -738,6 +742,8 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt
m_ParserFunctions[tflite::BuiltinOperator_TANH] = &TfLiteParserImpl::ParseTanH;
m_ParserFunctions[tflite::BuiltinOperator_TRANSPOSE] = &TfLiteParserImpl::ParseTranspose;
m_ParserFunctions[tflite::BuiltinOperator_TRANSPOSE_CONV] = &TfLiteParserImpl::ParseTransposeConv;
+ m_ParserFunctions[tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM]
+ = &TfLiteParserImpl::ParseUnidirectionalSequenceLSTM;
m_ParserFunctions[tflite::BuiltinOperator_UNPACK] = &TfLiteParserImpl::ParseUnpack;
// register supported custom operators
@@ -749,6 +755,9 @@ void TfLiteParserImpl::ResetParser()
m_Network = armnn::INetworkPtr(nullptr, nullptr);
m_Model = nullptr;
m_SubgraphConnections.clear();
+ m_OverridenOutputShapes.clear();
+ m_ConstantsToDequantize.clear();
+ m_ConstantsToBeCreated.clear();
}
INetworkPtr TfLiteParserImpl::CreateNetworkFromBinaryFile(const char* graphFile)
@@ -869,10 +878,47 @@ INetworkPtr TfLiteParserImpl::CreateNetworkFromModel()
}
}
}
-
return std::move(m_Network);
}
+std::unique_ptr<float[]> AsFloatArray(TfLiteParserImpl::BufferRawPtr bufferPtr,
+ const TensorInfo& tensorInfo)
+{
+ if (tensorInfo.GetDataType() == DataType::QAsymmS8 || tensorInfo.GetDataType() == DataType::QSymmS8 ||
+ tensorInfo.GetDataType() == DataType::QAsymmU8)
+ {
+ std::unique_ptr<float[]> buffer(new float[tensorInfo.GetNumElements()]);
+
+ if (tensorInfo.HasPerAxisQuantization())
+ {
+ unsigned int axis = tensorInfo.GetQuantizationDim().value();
+ auto axisDimensionality = tensorInfo.GetShape()[axis];
+ auto axisFactor = armnnUtils::GetNumElementsAfter(tensorInfo.GetShape(), axis);
+
+ for (unsigned int i = 0; i < tensorInfo.GetNumDimensions(); ++i)
+ {
+ unsigned int axisIndex = (i / axisFactor) % axisDimensionality;
+ buffer[i] = Dequantize<int8_t>(bufferPtr->data[i], tensorInfo.GetQuantizationScales()[axisIndex],
+ tensorInfo.GetQuantizationOffset());
+ }
+ }
+ else
+ {
+ for (unsigned int i = 0; i < tensorInfo.GetNumElements(); ++i)
+ {
+ buffer[i] = Dequantize<int8_t>(bufferPtr->data[i], tensorInfo.GetQuantizationScale(),
+ tensorInfo.GetQuantizationOffset());
+ }
+ }
+ return buffer;
+ }
+ throw ParseException(
+ fmt::format("Unsupported input/weights combination: Input {} not supported with Weights {}",
+ GetDataTypeName(DataType::Float32),
+ GetDataTypeName(tensorInfo.GetDataType()),
+ CHECK_LOCATION().AsString()));
+}
+
void TfLiteParserImpl::RegisterProducerOfTensor(size_t subgraphIndex,
size_t tensorIndex,
armnn::IOutputSlot* slot)
@@ -1050,7 +1096,7 @@ void TfLiteParserImpl::ParseConv2D(size_t subgraphIndex, size_t operatorIndex)
CalcPadding(inputWidth, filterWidth, desc.m_StrideX,
desc.m_DilationX, desc.m_PadLeft, desc.m_PadRight, options->padding);
- auto filterTensorAndData = CreateConstTensorNonPermuted(inputs[1], filterTensorInfo);
+ auto filterTensorAndData = CreateConstTensorNonPermuted(inputs[1], filterTensorInfo, inputTensorInfo.GetDataType());
armnn::IConnectableLayer* layer = nullptr;
auto layerName = fmt::format("Conv2D:{}:{}", subgraphIndex, operatorIndex);
@@ -1059,16 +1105,16 @@ void TfLiteParserImpl::ParseConv2D(size_t subgraphIndex, size_t operatorIndex)
{
desc.m_BiasEnabled = true;
armnn::TensorInfo biasTensorInfo = ToTensorInfo(inputs[2]);
- auto biasTensorAndData = CreateConstTensorNonPermuted(inputs[2], biasTensorInfo);
+ auto biasTensorAndData = CreateConstTensorNonPermuted(inputs[2], biasTensorInfo, inputTensorInfo.GetDataType());
layer = m_Network->AddConvolution2dLayer(desc,
- filterTensorAndData,
- Optional<ConstTensor>(biasTensorAndData),
+ filterTensorAndData.first,
+ Optional<ConstTensor>(biasTensorAndData.first),
layerName.c_str());
}
else
{
layer = m_Network->AddConvolution2dLayer(desc,
- filterTensorAndData,
+ filterTensorAndData.first,
EmptyOptional(),
layerName.c_str());
}
@@ -1136,7 +1182,7 @@ void TfLiteParserImpl::ParseConv3D(size_t subgraphIndex, size_t operatorIndex)
CalcPadding(inputWidth, filterWidth, desc.m_StrideX,
desc.m_DilationX, desc.m_PadLeft, desc.m_PadRight, options->padding);
- auto filterTensorAndData = CreateConstTensorNonPermuted(inputs[1], filterTensorInfo);
+ auto filterTensorAndData = CreateConstTensorNonPermuted(inputs[1], filterTensorInfo, inputTensorInfo.GetDataType());
auto layerName = fmt::format("Conv3D:{}:{}", subgraphIndex, operatorIndex);
@@ -1209,7 +1255,7 @@ void TfLiteParserImpl::ParseDepthwiseConv2D(size_t subgraphIndex, size_t operato
desc.m_DilationX, desc.m_PadLeft, desc.m_PadRight, options->padding);
// ArmNN uses the same filter tensor layout at TfLite [1, H, W, O] no need for any permutation
- auto filterTensor = CreateConstTensorNonPermuted(inputs[1], filterTensorInfo);
+ auto filterTensor = CreateConstTensorNonPermuted(inputs[1], filterTensorInfo, inputTensorInfo.GetDataType());
armnn::IConnectableLayer* layer = nullptr;
auto layerName = fmt::format("DepthwiseConv2D:{}:{}", subgraphIndex, operatorIndex);
@@ -1217,16 +1263,16 @@ void TfLiteParserImpl::ParseDepthwiseConv2D(size_t subgraphIndex, size_t operato
{
desc.m_BiasEnabled = true;
TensorInfo biasTensorInfo = ToTensorInfo(inputs[2]);
- auto biasTensorAndData = CreateConstTensorNonPermuted(inputs[2], biasTensorInfo);
+ auto biasTensorAndData = CreateConstTensorNonPermuted(inputs[2], biasTensorInfo, inputTensorInfo.GetDataType());
layer = m_Network->AddDepthwiseConvolution2dLayer(desc,
- filterTensor,
- Optional<ConstTensor>(biasTensorAndData),
+ filterTensor.first,
+ Optional<ConstTensor>(biasTensorAndData.first),
layerName.c_str());
}
else
{
layer = m_Network->AddDepthwiseConvolution2dLayer(desc,
- filterTensor,
+ filterTensor.first,
EmptyOptional(),
layerName.c_str());
}
@@ -1453,7 +1499,7 @@ void TfLiteParserImpl::ParseTransposeConv(size_t subgraphIndex, size_t operatorI
desc.m_PadRight,
options->padding);
- auto filterTensorAndData = CreateConstTensorNonPermuted(inputs[1], filterTensorInfo);
+ auto filterTensorAndData = CreateConstTensorNonPermuted(inputs[1], filterTensorInfo, inputTensorInfo.GetDataType());
armnn::IConnectableLayer* layer = nullptr;
auto layerName = fmt::format("TransposeConv:{}:{}", subgraphIndex, operatorIndex);
@@ -1461,16 +1507,16 @@ void TfLiteParserImpl::ParseTransposeConv(size_t subgraphIndex, size_t operatorI
if (desc.m_BiasEnabled)
{
auto biasTensorInfo = ToTensorInfo(inputs[3]);
- auto biasConstTensor = CreateConstTensorNonPermuted(inputs[3], biasTensorInfo);
+ auto biasConstTensor = CreateConstTensorNonPermuted(inputs[3], biasTensorInfo, inputTensorInfo.GetDataType());
layer = m_Network->AddTransposeConvolution2dLayer(desc,
- filterTensorAndData,
- biasConstTensor,
+ filterTensorAndData.first,
+ biasConstTensor.first,
layerName.c_str());
}
else
{
layer = m_Network->AddTransposeConvolution2dLayer(desc,
- filterTensorAndData,
+ filterTensorAndData.first,
EmptyOptional(),
layerName.c_str());
}
@@ -2436,10 +2482,11 @@ void TfLiteParserImpl::ParsePrelu(size_t subgraphIndex, size_t operatorIndex)
armnn::IInputSlot* slot = &(layer->GetInputSlot(0));
RegisterConsumerOfTensor(subgraphIndex, inputTensorIndexes[0], slot);
- auto alphaTensorAndData = CreateConstTensorNonPermuted(inputs[1], alphaTensorInfo);
+ auto alphaTensorAndData = CreateConstTensorNonPermuted(inputs[1], alphaTensorInfo,
+ inputTensorInfo.GetDataType());
std::string constLayerName = fmt::format("Constant:{}", inputs[1]->name);
IConnectableLayer* constLayer =
- m_Network->AddConstantLayer(alphaTensorAndData, constLayerName.c_str());
+ m_Network->AddConstantLayer(alphaTensorAndData.first, constLayerName.c_str());
ARMNN_ASSERT(constLayer != nullptr);
constLayer->GetOutputSlot(0).SetTensorInfo(alphaTensorInfo);
@@ -2933,25 +2980,40 @@ void TfLiteParserImpl::ParseFullyConnected(size_t subgraphIndex, size_t operator
// Add the first input tensor to the registration list
std::vector<unsigned int> tensorIndexesToRegister = {inputTensorIndexes[0]};
std::vector<unsigned int> ignoreInputWhenRegister = {};
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
desc.m_ConstantWeights = IsConstTensor(inputs[1]);
// Add the weights input to the registration list, constant layers will be added by SetupConstantLayers if constant.
tensorIndexesToRegister.emplace_back(inputTensorIndexes[1]);
+ if (desc.m_ConstantWeights && inputTensorInfo.GetDataType() == DataType::Float32 &&
+ (filterTensorInfo.GetDataType() == DataType::QAsymmU8 ||
+ filterTensorInfo.GetDataType() == DataType::QAsymmS8))
+ {
+ m_ConstantsToDequantize.emplace_back(inputs[1]->buffer);
+ }
+
if (inputs.size() == 3)
{
desc.m_BiasEnabled = true;
+ armnn::TensorInfo biasTensorInfo = ToTensorInfo(inputs[2]);
// Add the biases input to the registration list, constant layer will be added by SetupConstantLayers.
tensorIndexesToRegister.emplace_back(inputTensorIndexes[2]);
+
+ if (desc.m_ConstantWeights && inputTensorInfo.GetDataType() == DataType::Float32 &&
+ (biasTensorInfo.GetDataType() == DataType::QAsymmU8 ||
+ biasTensorInfo.GetDataType() == DataType::QAsymmS8))
+ {
+ m_ConstantsToDequantize.emplace_back(inputs[2]->buffer);
+ }
}
// Filters and biases are always passed to fully connected as inputs
layer = m_Network->AddFullyConnectedLayer(desc, layerName.c_str());
ARMNN_ASSERT(layer != nullptr);
- armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
unsigned int startingSlotIndex = 0;
if (inputTensorInfo.GetNumDimensions() > 2)
@@ -3120,6 +3182,278 @@ void TfLiteParserImpl::ParsePack(size_t subgraphIndex, size_t operatorIndex)
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
}
/// Parses a TFLite UNIDIRECTIONAL_SEQUENCE_LSTM operator and adds an ArmNN
/// UnidirectionalSequenceLstm layer to the network.
/// Optional operands are flagged with a negative tensor index in the
/// flatbuffer (see IsOptionalOperandPresent); which of them are present
/// determines the CIFG / peephole / projection / layer-norm configuration.
void TfLiteParserImpl::ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, size_t operatorIndex)
{
    CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);

    auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
    auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);

    if (inputs.size() < 2)
    {
        throw ParseException("UnidirectionalSequenceLSTM must have at least 2 input.");
    }

    const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
    const auto& subgraphPtr = m_Model->subgraphs[subgraphIndex];
    const auto nodeParams = operatorPtr->builtin_options.AsUnidirectionalSequenceLSTMOptions();
    CHECK_SUPPORTED_FUSED_ACTIVATION(nodeParams, subgraphIndex, operatorIndex);
    auto inputTensorInfo = ToTensorInfo(inputs[0]);
    auto outputTensorInfo = ToTensorInfo(outputs[0]);

    // Set the params structure for the AddUnidirectionalSequenceLstmLayer call
    // Please refer to each operand at
    // https://www.tensorflow.org/mlir/tfl_ops#tflunidirectional_sequence_lstm_tflunidirectionalsequencelstmop
    armnn::LstmInputParams params;

    // Input-to-gate weights; operand 1 is only present when CIFG is disabled.
    if (IsOptionalOperandPresent(operatorPtr->inputs[1]))
    {
        params.m_InputToInputWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[1]].get(),
                                                            inputTensorInfo).first;
    }

    params.m_InputToForgetWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[2]].get(),
                                                         inputTensorInfo).first;
    params.m_InputToCellWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[3]].get(),
                                                       inputTensorInfo).first;
    params.m_InputToOutputWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[4]].get(),
                                                         inputTensorInfo).first;

    // Recurrent weight tensors of size {n_cell, n_output}
    if (IsOptionalOperandPresent(operatorPtr->inputs[5]))
    {
        params.m_RecurrentToInputWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[5]].get(),
                                                                inputTensorInfo).first;
    }

    params.m_RecurrentToForgetWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[6]].get(),
                                                             inputTensorInfo).first;
    params.m_RecurrentToCellWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[7]].get(),
                                                           inputTensorInfo).first;
    params.m_RecurrentToOutputWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[8]].get(),
                                                             inputTensorInfo).first;

    // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
    if (IsOptionalOperandPresent(operatorPtr->inputs[9]))
    {
        params.m_CellToInputWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[9]].get(),
                                                           inputTensorInfo).first;
    }

    if (IsOptionalOperandPresent(operatorPtr->inputs[10]))
    {
        params.m_CellToForgetWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[10]].get(),
                                                            inputTensorInfo).first;
    }

    if (IsOptionalOperandPresent(operatorPtr->inputs[11]))
    {
        params.m_CellToOutputWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[11]].get(),
                                                            inputTensorInfo).first;
    }

    // Gates bias tensors of size {n_cell}
    if (IsOptionalOperandPresent(operatorPtr->inputs[12]))
    {
        params.m_InputGateBias = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[12]].get(),
                                                      inputTensorInfo).first;
    }

    params.m_ForgetGateBias = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[13]].get(),
                                                   inputTensorInfo).first;
    params.m_CellBias = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[14]].get(),
                                             inputTensorInfo).first;
    params.m_OutputGateBias = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[15]].get(),
                                                   inputTensorInfo).first;

    // Projection weight tensor of size {n_output, n_cell}
    if (IsOptionalOperandPresent(operatorPtr->inputs[16]))
    {
        params.m_ProjectionWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[16]].get(),
                                                          inputTensorInfo).first;
    }
    // Projection bias tensor of size {n_output}
    if (IsOptionalOperandPresent(operatorPtr->inputs[17]))
    {
        params.m_ProjectionBias = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[17]].get(),
                                                       inputTensorInfo).first;
    }

    // These state tensors are defined as variable tensors, and will be modified by this op.
    // They are recorded in m_ConstantsToBeCreated so that SetupConstantLayers
    // later materializes constant layers feeding these two inputs.
    armnn::TensorInfo outputStateInInfo = ToTensorInfo(subgraphPtr->tensors[operatorPtr->inputs[18]].get());
    m_ConstantsToBeCreated.push_back(operatorPtr->inputs[18]);
    armnn::TensorInfo cellStateInInfo = ToTensorInfo(subgraphPtr->tensors[operatorPtr->inputs[19]].get());
    m_ConstantsToBeCreated.push_back(operatorPtr->inputs[19]);

    // Layer norm coefficient tensors of size {n_cell}, representing a diagonal matrix.
    // Operands 20-23 only exist in newer models, hence the extra size checks.
    if (inputs.size() >= 21 && IsOptionalOperandPresent(operatorPtr->inputs[20]))
    {
        params.m_InputLayerNormWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[20]].get(),
                                                              inputTensorInfo).first;
    }

    if (inputs.size() >= 22 && IsOptionalOperandPresent(operatorPtr->inputs[21]))
    {
        params.m_ForgetLayerNormWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[21]].get(),
                                                               inputTensorInfo).first;
    }

    if (inputs.size() >= 23 && IsOptionalOperandPresent(operatorPtr->inputs[22]))
    {
        params.m_CellLayerNormWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[22]].get(),
                                                             inputTensorInfo).first;
    }

    if (inputs.size() >= 24 && IsOptionalOperandPresent(operatorPtr->inputs[23]))
    {
        params.m_OutputLayerNormWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[23]].get(),
                                                               inputTensorInfo).first;
    }

    // set the layer descriptor
    armnn::UnidirectionalSequenceLstmDescriptor desc;
    desc.m_ActivationFunc = nodeParams->fused_activation_function;
    desc.m_ClippingThresCell = nodeParams->cell_clip;
    desc.m_ClippingThresProj = nodeParams->proj_clip;
    // CIFG mode is inferred from the absence of any of the input-gate tensors.
    desc.m_CifgEnabled = (params.m_InputToInputWeights == nullptr
                          || params.m_RecurrentToInputWeights == nullptr
                          || params.m_InputGateBias == nullptr);
    desc.m_PeepholeEnabled = (params.m_CellToForgetWeights != nullptr || params.m_CellToOutputWeights != nullptr);
    desc.m_ProjectionEnabled = (params.m_ProjectionWeights != nullptr);
    desc.m_LayerNormEnabled = (params.m_InputLayerNormWeights != nullptr
                               || params.m_ForgetLayerNormWeights != nullptr
                               || params.m_CellLayerNormWeights != nullptr
                               || params.m_OutputLayerNormWeights != nullptr);
    desc.m_TimeMajor = nodeParams->time_major;

    // NOTE(review): operatorPtr->intermediates is indexed unconditionally here
    // and below (indices 0-4); this assumes the flatbuffer always carries the
    // five intermediate tensors - TODO confirm for models exported without them.
    if (desc.m_LayerNormEnabled)
    {
        // Intermediate tensors only carry quantization metadata; their scales
        // feed the corresponding descriptor fields.
        auto inputIntermediate = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[0]].get(),
                                                      inputTensorInfo).first;
        auto inputIntermediateTensorInfo = inputIntermediate->GetInfo();
        desc.m_InputIntermediateScale = inputIntermediateTensorInfo.GetQuantizationScale();

        auto forgetIntermediate = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[1]].get(),
                                                       inputTensorInfo).first;
        auto forgetIntermediateTensorInfo = forgetIntermediate->GetInfo();
        desc.m_ForgetIntermediateScale = forgetIntermediateTensorInfo.GetQuantizationScale();

        auto cellIntermediate = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[2]].get(),
                                                     inputTensorInfo).first;
        auto cellIntermediateTensorInfo = cellIntermediate->GetInfo();
        desc.m_CellIntermediateScale = cellIntermediateTensorInfo.GetQuantizationScale();

        auto outputIntermediate = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[3]].get(),
                                                       inputTensorInfo).first;
        auto outputIntermediateTensorInfo = outputIntermediate->GetInfo();
        desc.m_OutputIntermediateScale = outputIntermediateTensorInfo.GetQuantizationScale();
    }
    else
    {
        // Default intermediate scale (2^-12) used when layer norm is disabled.
        float defaultIntermediate = std::pow(2, -12);
        desc.m_InputIntermediateScale = defaultIntermediate;
        desc.m_ForgetIntermediateScale = defaultIntermediate;
        desc.m_CellIntermediateScale = defaultIntermediate;
        desc.m_OutputIntermediateScale = defaultIntermediate;
    }

    auto hiddentensor = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[4]].get(),
                                             inputTensorInfo).first;

    desc.m_HiddenStateScale = hiddentensor->GetInfo().GetQuantizationScale();
    desc.m_HiddenStateZeroPoint = hiddentensor->GetInfo().GetQuantizationOffset();

    unsigned int batchSize = inputTensorInfo.GetShape()[0];
    unsigned int outputSize = outputTensorInfo.GetShape()[2];
    unsigned int numUnits = cellStateInInfo.GetShape()[1];

    armnn::DataType dataType = inputTensorInfo.GetDataType();
    float qScale = inputTensorInfo.GetQuantizationScale();
    // NOTE(review): GetQuantizationOffset() yields an integer offset; it is
    // held in a float here and converted back when constructing the
    // TensorInfos below - confirm this round-trip is intended.
    float qOffset = inputTensorInfo.GetQuantizationOffset();

    // The scratch buffer holds 3 gate buffers with CIFG enabled, 4 without.
    armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 3}, dataType, qScale, qOffset);
    if (!desc.m_CifgEnabled)
    {
        scratchBufferTensorInfo = armnn::TensorInfo({batchSize, numUnits * 4}, dataType, qScale, qOffset);
    }
    armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits},
                                             cellStateInInfo.GetDataType(),
                                             cellStateInInfo.GetQuantizationScale(),
                                             cellStateInInfo.GetQuantizationOffset());
    armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, dataType, qScale, qOffset);

    // Mirror the gathered tensors into the LstmInputParamsInfo used for
    // validation; mandatory tensors first, then the optional groups guarded
    // by the corresponding descriptor flags.
    armnn::LstmInputParamsInfo paramsInfo;
    paramsInfo.m_InputToForgetWeights = &(params.m_InputToForgetWeights->GetInfo());
    paramsInfo.m_InputToCellWeights = &(params.m_InputToCellWeights->GetInfo());
    paramsInfo.m_InputToOutputWeights = &(params.m_InputToOutputWeights->GetInfo());
    paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo());
    paramsInfo.m_RecurrentToCellWeights = &(params.m_RecurrentToCellWeights->GetInfo());
    paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo());
    paramsInfo.m_ForgetGateBias = &(params.m_ForgetGateBias->GetInfo());
    paramsInfo.m_CellBias = &(params.m_CellBias->GetInfo());
    paramsInfo.m_OutputGateBias = &(params.m_OutputGateBias->GetInfo());

    if (!desc.m_CifgEnabled)
    {
        paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
        paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
        if (params.m_CellToInputWeights != nullptr)
        {
            paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
        }
        paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo());
    }

    if (desc.m_ProjectionEnabled)
    {
        paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo());
        if (params.m_ProjectionBias != nullptr)
        {
            paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo());
        }
    }

    if (desc.m_PeepholeEnabled)
    {
        paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
        paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
    }

    if (desc.m_LayerNormEnabled)
    {
        if(!desc.m_CifgEnabled)
        {
            paramsInfo.m_InputLayerNormWeights = &(params.m_InputLayerNormWeights->GetInfo());
        }
        paramsInfo.m_ForgetLayerNormWeights = &(params.m_ForgetLayerNormWeights->GetInfo());
        paramsInfo.m_CellLayerNormWeights = &(params.m_CellLayerNormWeights->GetInfo());
        paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo());
    }

    // NOTE(review): layerName is built but not passed to the Add call below.
    auto layerName = fmt::format("UnidirectionalSequenceLSTM:{}:{}", subgraphIndex, operatorIndex);
    armnn::IConnectableLayer* layer = m_Network->AddUnidirectionalSequenceLstmLayer(desc, params);
    ARMNN_ASSERT(layer != nullptr);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    auto inputTensorIndexes = AsUnsignedVector({operatorPtr->inputs[0],
                                                operatorPtr->inputs[18],
                                                operatorPtr->inputs[19]});
    RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0],
                                                             inputTensorIndexes[1],
                                                             inputTensorIndexes[2]});

    auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));

    // Slots 0/1 carry the updated output/cell state; slot 2 is the sequence output.
    layer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo);
    layer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo);
    layer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo);

    // Only the sequence output (slot 2) is registered as producing the
    // operator's output tensor.
    unsigned int tensorIndex = outputTensorIndexes[0];
    armnn::IOutputSlot* slot = &(layer->GetOutputSlot(2));
    RegisterProducerOfTensor(subgraphIndex, tensorIndex, slot);
}
+
void TfLiteParserImpl::ParseUnpack(size_t subgraphIndex, size_t operatorIndex)
{
CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
@@ -4222,11 +4556,11 @@ void TfLiteParserImpl::SetupOutputLayers(size_t subgraphIndex)
}
}
-void TfLiteParserImpl::SetupConstantLayers(size_t subgraphIndex)
+void TfLiteParserImpl::SetupConstantLayers(size_t subgraph)
{
- CHECK_SUBGRAPH(m_Model, subgraphIndex);
+ CHECK_SUBGRAPH(m_Model, subgraph);
- const auto& subgraphPtr = m_Model->subgraphs[subgraphIndex];
+ const auto & subgraphPtr = m_Model->subgraphs[subgraph];
for (unsigned int subgraphIndex = 0; subgraphIndex < m_SubgraphConnections.size(); ++subgraphIndex)
{
for (unsigned int tensorIndex = 0; tensorIndex < m_SubgraphConnections[subgraphIndex].size(); ++tensorIndex)
@@ -4236,10 +4570,42 @@ void TfLiteParserImpl::SetupConstantLayers(size_t subgraphIndex)
{
TensorRawPtr tensorPtr = subgraphPtr->tensors[tensorIndex].get();
- if(IsConstTensor(tensorPtr))
+ if (IsConstTensor(tensorPtr))
{
armnn::TensorInfo tensorInfo = ToTensorInfo(tensorPtr);
- auto tensorAndData = CreateConstTensorNonPermuted(tensorPtr, tensorInfo);
+ armnn::DataType dataType = tensorInfo.GetDataType();
+
+ if (std::find(m_ConstantsToDequantize.begin(), m_ConstantsToDequantize.end(), tensorPtr->buffer)
+ != m_ConstantsToDequantize.end())
+ {
+ dataType = DataType::Float32;
+ }
+ auto tensorAndData = CreateConstTensorNonPermuted(tensorPtr, tensorInfo, dataType);
+
+ std::string layerName = fmt::format("Constant:{}", tensorPtr->name);
+ IConnectableLayer *layer = m_Network->AddConstantLayer(tensorAndData.first, layerName.c_str());
+
+ layer->GetOutputSlot(0).SetTensorInfo(tensorAndData.first.GetInfo());
+ RegisterOutputSlots(subgraphIndex,
+ VIRTUAL_OPERATOR_ID,
+ layer,
+ { tensorIndex });
+ }
+ else if (ShouldConstantTensorBeCreated(tensorIndex))
+ {
+ armnn::TensorInfo tensorInfo = ToTensorInfo(tensorPtr);
+ armnn::DataType dataType = tensorInfo.GetDataType();
+
+ if (std::find(m_ConstantsToDequantize.begin(), m_ConstantsToDequantize.end(), tensorPtr->buffer)
+ != m_ConstantsToDequantize.end())
+ {
+ dataType = DataType::Float32;
+ }
+ // Make sure isConstant flag is set.
+ tensorInfo.SetConstant();
+ tensorInfo.SetDataType(dataType);
+
+ auto tensorAndData = ConstTensor(tensorInfo, std::vector<uint8_t>(tensorInfo.GetNumBytes()));
std::string layerName = fmt::format("Constant:{}", tensorPtr->name);
IConnectableLayer* layer = m_Network->AddConstantLayer(tensorAndData, layerName.c_str());
@@ -4248,7 +4614,7 @@ void TfLiteParserImpl::SetupConstantLayers(size_t subgraphIndex)
RegisterOutputSlots(subgraphIndex,
VIRTUAL_OPERATOR_ID,
layer,
- { tensorIndex });
+ {tensorIndex});
}
else
{
@@ -4286,6 +4652,13 @@ TfLiteParserImpl::CreateConstTensorAndStoreData(TfLiteParserImpl::BufferRawPtr b
return std::make_pair(constData.first, std::move(storage));
}
+bool TfLiteParserImpl::ShouldConstantTensorBeCreated(unsigned int tensorIndex)
+{
+ // If the TensorIndex appears in the list of ConstantsToBeCreated then return true
+ return (std::find(m_ConstantsToBeCreated.begin(), m_ConstantsToBeCreated.end(), tensorIndex)
+ != m_ConstantsToBeCreated.end());
+}
+
bool TfLiteParserImpl::IsConstTensor(TensorRawPtr tensorPtr)
{
CHECK_TENSOR_PTR(tensorPtr);
@@ -4364,6 +4737,53 @@ armnn::ConstTensor TfLiteParserImpl::CreateConstTensorNonPermuted(TensorRawPtr t
return ConstTensor(tensorInfo, bufferPtr->data.data());
}
+std::pair<armnn::ConstTensor, std::unique_ptr<float[]>>
+TfLiteParserImpl::CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
+ armnn::TensorInfo& tensorInfo,
+ armnn::DataType inputDataType)
+{
+ CHECK_TENSOR_PTR(tensorPtr);
+ auto bufferPtr = GetBuffer(m_Model, tensorPtr->buffer);
+ CHECK_BUFFER_SIZE(bufferPtr, tensorInfo, tensorPtr->buffer);
+
+ // Make sure isConstant flag is set.
+ tensorInfo.SetConstant();
+
+ if (inputDataType == DataType::Float32 && tensorInfo.GetDataType() != DataType::Float32)
+ {
+ TensorInfo constTensorInfo(tensorInfo.GetShape(), DataType::Float32, 0.0f, 0, true);
+ std::unique_ptr<float[]> data = AsFloatArray(bufferPtr, tensorInfo);
+ return std::make_pair(ConstTensor(constTensorInfo, data.get()), std::move(data));
+ }
+ else
+ {
+ return std::make_pair(ConstTensor(tensorInfo, bufferPtr->data.data()), std::unique_ptr<float[]>());
+ }
+}
+
/// Creates a heap-allocated ConstTensor for an LSTM weight/bias operand.
/// If the network input is Float32 but the stored tensor is quantized, the
/// data is dequantized first; the owning float buffer is returned as the
/// pair's second element and the ConstTensor only views it.
/// NOTE(review): the ConstTensor is allocated with raw `new` and its address
/// is stored into LstmInputParams by the caller; nothing visible here deletes
/// it - presumably left alive for the network's lifetime, but confirm.
std::pair<armnn::ConstTensor*, std::unique_ptr<float[]>>
TfLiteParserImpl::CreateConstTensorPtr(TensorRawPtr tensorPtr, armnn::TensorInfo& inputTensorInfo)
{
    CHECK_TENSOR_PTR(tensorPtr);
    armnn::TensorInfo tensorInfo = ToTensorInfo(tensorPtr);
    auto bufferPtr = GetBuffer(m_Model, tensorPtr->buffer);
    CHECK_BUFFER_SIZE(bufferPtr, tensorInfo, tensorPtr->buffer);

    // Make sure isConstant flag is set.
    tensorInfo.SetConstant();

    if (inputTensorInfo.GetDataType() == DataType::Float32 && tensorInfo.GetDataType() != DataType::Float32)
    {
        // Float32 network with quantized constants: dequantize into a float buffer.
        TensorInfo constTensorInfo(tensorInfo.GetShape(), DataType::Float32, 0.0f, 0, true);
        std::unique_ptr<float[]> data = AsFloatArray(bufferPtr, tensorInfo);
        return std::make_pair(new ConstTensor(constTensorInfo, data.get()), std::move(data));
    }
    else
    {
        // Types already match: the ConstTensor can view the flatbuffer data directly.
        return std::make_pair(new ConstTensor(tensorInfo, bufferPtr->data.data()), std::unique_ptr<float[]>());
    }
}
+
BindingPointInfo TfLiteParserImpl::GetNetworkInputBindingInfo(size_t subgraphId,
const std::string& name) const
{
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index 474393cbe6..8c9674a5a6 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -183,6 +183,7 @@ private:
void ParseTanH(size_t subgraphIndex, size_t operatorIndex);
void ParseTranspose(size_t subgraphIndex, size_t operatorIndex);
void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex);
+ void ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, size_t operatorIndex);
void ParseUnpack(size_t subgraphIndex, size_t operatorIndex);
void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
@@ -234,13 +235,19 @@ private:
std::unique_ptr<int32_t[]> m_Int32Data;
};
+ bool ShouldConstantTensorBeCreated(unsigned int tensorIndex);
bool IsConstTensor(TensorRawPtr tensorPtr);
armnn::ConstTensor CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
armnn::TensorInfo& tensorInfo);
+
std::pair<armnn::ConstTensor, SupportedDataStorage>
CreateConstTensorPermuted(TensorRawPtr tensorPtr,
armnn::TensorInfo& tensorInfo,
armnn::Optional<armnn::PermutationVector&> permutationVector);
+ std::pair<armnn::ConstTensor, std::unique_ptr<float[]>>
+ CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
+ armnn::TensorInfo& tensorInfo,
+ armnn::DataType inputDataType);
template<typename T>
std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage>
@@ -248,6 +255,9 @@ private:
TfLiteParserImpl::TensorRawPtr tensorPtr,
armnn::TensorInfo& tensorInfo,
armnn::Optional<armnn::PermutationVector&> permutationVector);
+ std::pair<armnn::ConstTensor*, std::unique_ptr<float[]>>
+ CreateConstTensorPtr(TensorRawPtr tensorPtr,
+ armnn::TensorInfo& inputTensorInfo);
// Settings for configuring the TfLiteParser
armnn::Optional<ITfLiteParser::TfLiteParserOptions> m_Options;
@@ -274,9 +284,12 @@ private:
/// The first index is the subgraph ID, the second index is the tensor ID
std::vector<TensorConnections> m_SubgraphConnections;
- /// This is used in case that the model does not speciry the output.
+ /// This is used in case that the model does not specify the output.
/// The shape can be calculated from the options.
std::vector<std::vector<unsigned int>> m_OverridenOutputShapes;
+
+ std::vector<unsigned int> m_ConstantsToDequantize;
+ std::vector<unsigned int> m_ConstantsToBeCreated;
};
}
diff --git a/src/armnnTfLiteParser/test/Conv2D.cpp b/src/armnnTfLiteParser/test/Conv2D.cpp
index c25e62bb00..45c4a43519 100644
--- a/src/armnnTfLiteParser/test/Conv2D.cpp
+++ b/src/armnnTfLiteParser/test/Conv2D.cpp
@@ -104,18 +104,21 @@ TEST_CASE_FIXTURE(SimpleConv2DFixture, "ParseSimpleConv2D")
struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture
{
- explicit Conv2DWithBiasesFixture(const std::string & inputShape,
- const std::string & outputShape,
- const std::string & filterShape,
- const std::string & filterData,
- const std::string & biasShape,
- const std::string & biasData,
- const std::string & strides,
- const std::string & activation="NONE",
- const std::string & filterScale="1.0",
- const std::string & filterZeroPoint="0",
- const std::string & outputScale="2.0",
- const std::string & outputZeroPoint="0")
+ explicit Conv2DWithBiasesFixture(const std::string& inputShape,
+ const std::string& outputShape,
+ const std::string& filterShape,
+ const std::string& filterData,
+ const std::string& biasShape,
+ const std::string& biasData,
+ const std::string& strides,
+ const std::string& activation="NONE",
+ const std::string& filterScale="1.0",
+ const std::string& filterZeroPoint="0",
+ const std::string& outputScale="2.0",
+ const std::string& outputZeroPoint="0",
+ const std::string& dataType = "UINT8",
+ const std::string& filterDataType = "UINT8",
+ const std::string& biasDataType = "INT32")
{
m_JsonString = R"(
{
@@ -125,7 +128,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture
"tensors": [
{
"shape": )" + inputShape + R"(,
- "type": "UINT8",
+ "type": )" + dataType + R"(,
"buffer": 0,
"name": "inputTensor",
"quantization": {
@@ -137,7 +140,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + outputShape + R"(,
- "type": "UINT8",
+ "type": )" + dataType + R"(,
"buffer": 1,
"name": "outputTensor",
"quantization": {
@@ -149,7 +152,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + filterShape + R"( ,
- "type": "UINT8",
+ "type": )" + filterDataType + R"(,
"buffer": 2,
"name": "filterTensor",
"quantization": {
@@ -161,7 +164,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + biasShape + R"( ,
- "type": "INT32",
+ "type": )" + biasDataType + R"(,
"buffer": 3,
"name": "biasTensor",
"quantization": {
@@ -662,4 +665,41 @@ TEST_CASE_FIXTURE(PerChannelConv2DFixture, "ParsePerChannelConv2D")
});
}
+struct Conv2FloatWithInt8WeightsAndBiasesFixture : Conv2DWithBiasesFixture
+{
+ Conv2FloatWithInt8WeightsAndBiasesFixture()
+ : Conv2DWithBiasesFixture("[ 1, 2, 2, 1 ]", // inputShape
+ "[ 1, 2, 2, 1 ]", // outputShape
+ "[ 1, 2, 2, 1 ]", // filterShape
+ "[ 2,1, 0,6 ]", // filterData
+ "[ 1 ]", // biasShape
+ "[ 10, 0, 0, 0 ]", // biasData
+ "1", // stride w and h
+ "NONE", // activation
+ "1.0", // filterScale
+ "0", // filterZeroPoint
+ "2.0", // outputScale
+ "0", // outputZeroPoint
+ "FLOAT32", // dataType
+ "INT8", // filterDataType
+ "INT8") // biasDataType
+ {}
+};
+
+TEST_CASE_FIXTURE(Conv2FloatWithInt8WeightsAndBiasesFixture, "ParseConv2FloatWithInt8WeightsAndBiasesFixture")
+{
+ RunTest<4, armnn::DataType::Float32>(
+ 0,
+ {
+ 1, 2,
+ 3, 4,
+ },
+ {
+ (1*2 + 2*1 + 3*0 + 4*6 + 10),
+ (2*2 + 0*1 + 4*0 + 0*6 + 10),
+ (3*2 + 4*1 + 0*0 + 0*6 + 10),
+ (4*2 + 0*1 + 0*0 + 0*6 + 10)
+ });
+}
+
}
diff --git a/src/armnnTfLiteParser/test/FullyConnected.cpp b/src/armnnTfLiteParser/test/FullyConnected.cpp
index fc000bf95b..108b878e20 100644
--- a/src/armnnTfLiteParser/test/FullyConnected.cpp
+++ b/src/armnnTfLiteParser/test/FullyConnected.cpp
@@ -15,7 +15,10 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture
const std::string& filterShape,
const std::string& filterData,
const std::string biasShape = "",
- const std::string biasData = "")
+ const std::string biasData = "",
+ const std::string dataType = "UINT8",
+ const std::string weightsDataType = "UINT8",
+ const std::string biasDataType = "INT32")
{
std::string inputTensors = "[ 0, 2 ]";
std::string biasTensor = "";
@@ -26,7 +29,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture
biasTensor = R"(
{
"shape": )" + biasShape + R"( ,
- "type": "INT32",
+ "type": )" + biasDataType + R"(,
"buffer": 3,
"name": "biasTensor",
"quantization": {
@@ -47,7 +50,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture
"tensors": [
{
"shape": )" + inputShape + R"(,
- "type": "UINT8",
+ "type": )" + dataType + R"(,
"buffer": 0,
"name": "inputTensor",
"quantization": {
@@ -59,7 +62,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + outputShape + R"(,
- "type": "UINT8",
+ "type": )" + dataType + R"(,
"buffer": 1,
"name": "outputTensor",
"quantization": {
@@ -71,7 +74,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + filterShape + R"(,
- "type": "UINT8",
+ "type": )" + weightsDataType + R"(,
"buffer": 2,
"name": "filterTensor",
"quantization": {
@@ -353,4 +356,27 @@ TEST_CASE_FIXTURE(FullyConnectedNonConstWeightsNoBias, "ParseFullyConnectedNonCo
{{"output", { 20 }}});
}
+struct FullyConnectedWeightsBiasFloat : FullyConnectedFixture
+{
+ FullyConnectedWeightsBiasFloat()
+ : FullyConnectedFixture("[ 1, 4, 1, 1 ]", // inputShape
+ "[ 1, 1 ]", // outputShape
+ "[ 1, 4 ]", // filterShape
+ "[ 2, 3, 4, 5 ]", // filterData
+ "[ 1 ]", // biasShape
+                                "[ 10, 0, 0, 0 ]",     // biasData
+ "FLOAT32", // input and output dataType
+ "INT8", // weights dataType
+ "FLOAT32") // bias dataType
+ {}
+};
+
+TEST_CASE_FIXTURE(FullyConnectedWeightsBiasFloat, "FullyConnectedWeightsBiasFloat")
+{
+ RunTest<2, armnn::DataType::Float32>(
+ 0,
+ { 10, 20, 30, 40 },
+ { 400 });
+}
+
}