diff options
author | Jim Flynn <jim.flynn@arm.com> | 2019-03-19 17:22:29 +0000 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2019-03-21 16:09:19 +0000 |
commit | 11af375a5a6bf88b4f3b933a86d53000b0d91ed0 (patch) | |
tree | f4f4db5192b275be44d96d96c7f3c8c10f15b3f1 /src/armnnSerializer/ArmnnSchema.fbs | |
parent | db059fd50f9afb398b8b12cd4592323fc8f60d7f (diff) | |
download | armnn-11af375a5a6bf88b4f3b933a86d53000b0d91ed0.tar.gz |
IVGCVSW-2694: serialize/deserialize LSTM
* added serialize/deserialize methods for LSTM and tests
Change-Id: Ic59557f03001c496008c4bef92c2e0406e1fbc6c
Signed-off-by: Nina Drozd <nina.drozd@arm.com>
Signed-off-by: Jim Flynn <jim.flynn@arm.com>
Diffstat (limited to 'src/armnnSerializer/ArmnnSchema.fbs')
-rw-r--r-- | src/armnnSerializer/ArmnnSchema.fbs | 44 |
1 files changed, 42 insertions, 2 deletions
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index a11eeadf12..2cceaae031 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -115,7 +115,8 @@ enum LayerType : uint { Merger = 30, L2Normalization = 31, Splitter = 32, - DetectionPostProcess = 33 + DetectionPostProcess = 33, + Lstm = 34 } // Base layer table to be used as part of other layers @@ -475,6 +476,44 @@ table DetectionPostProcessDescriptor { scaleH:float; } +table LstmInputParams { + inputToForgetWeights:ConstTensor; + inputToCellWeights:ConstTensor; + inputToOutputWeights:ConstTensor; + recurrentToForgetWeights:ConstTensor; + recurrentToCellWeights:ConstTensor; + recurrentToOutputWeights:ConstTensor; + forgetGateBias:ConstTensor; + cellBias:ConstTensor; + outputGateBias:ConstTensor; + + inputToInputWeights:ConstTensor; + recurrentToInputWeights:ConstTensor; + cellToInputWeights:ConstTensor; + inputGateBias:ConstTensor; + + projectionWeights:ConstTensor; + projectionBias:ConstTensor; + + cellToForgetWeights:ConstTensor; + cellToOutputWeights:ConstTensor; +} + +table LstmDescriptor { + activationFunc:uint; + clippingThresCell:float; + clippingThresProj:float; + cifgEnabled:bool = true; + peepholeEnabled:bool = false; + projectionEnabled:bool = false; +} + +table LstmLayer { + base:LayerBase; + descriptor:LstmDescriptor; + inputParams:LstmInputParams; +} + union Layer { ActivationLayer, AdditionLayer, @@ -509,7 +548,8 @@ union Layer { MergerLayer, L2NormalizationLayer, SplitterLayer, - DetectionPostProcessLayer + DetectionPostProcessLayer, + LstmLayer } table AnyLayer { |