diff options
Diffstat (limited to 'src/backends/tosaCommon/test/OneToOneMappingTests.cpp')
-rw-r--r-- | src/backends/tosaCommon/test/OneToOneMappingTests.cpp | 102 |
1 files changed, 102 insertions, 0 deletions
diff --git a/src/backends/tosaCommon/test/OneToOneMappingTests.cpp b/src/backends/tosaCommon/test/OneToOneMappingTests.cpp index 0d19a328d6..2b0c1e55c7 100644 --- a/src/backends/tosaCommon/test/OneToOneMappingTests.cpp +++ b/src/backends/tosaCommon/test/OneToOneMappingTests.cpp @@ -438,6 +438,108 @@ TEST_CASE("GetTosaMappingFromLayer_SliceLayer") } +TEST_CASE("GetTosaMapping_TransposeConv2dLayer") +{ + const TensorInfo inputInfo ({ 1, 7, 7, 1 }, DataType::Float32); + const TensorInfo outputInfo({ 1, 9, 9, 1 }, DataType::Float32); + const TensorInfo weightsInfo({ 1, 3, 3, 1 }, DataType::Float32, 0.0f, 0, true); + const TensorInfo biasesInfo ({ 1 }, DataType::Float32, 0.0f, 0, true); + + TransposeConvolution2dDescriptor descriptor; + descriptor.m_PadLeft = 1; + descriptor.m_PadRight = 1; + descriptor.m_PadTop = 1; + descriptor.m_PadBottom = 1; + descriptor.m_StrideX = 1; + descriptor.m_StrideY = 1; + descriptor.m_BiasEnabled = true; + descriptor.m_DataLayout = DataLayout::NHWC; + + TosaSerializationBasicBlock* basicBlock = GetTosaMapping(nullptr, + LayerType::TransposeConvolution2d, + {&inputInfo, &weightsInfo, &biasesInfo}, + {&outputInfo}, + descriptor); + + CHECK(basicBlock->GetInputs().size() == 3); + CHECK(basicBlock->GetOutputs().size() == 1); + CHECK(basicBlock->GetOperators().size() == 3); + CHECK(basicBlock->GetTensors().size() == 4); + + CHECK(basicBlock->GetInputs()[0].find("input0_") != std::string::npos); + CHECK(basicBlock->GetInputs()[1].find("constant_") != std::string::npos); + CHECK(basicBlock->GetInputs()[2].find("constant_") != std::string::npos); + CHECK(basicBlock->GetOutputs()[0].find("output0_") != std::string::npos); + + VerifyTosaAttribute(descriptor, + basicBlock->GetOperators().at(2)->GetAttribute(), + {}, + {}, + LayerType::TransposeConvolution2d); +} + +TEST_CASE("GetTosaMappingFromLayer_TransposeConv2dLayer") +{ + IRuntime::CreationOptions options; + IRuntimePtr runtime(IRuntime::Create(options)); + + // Builds up the structure of the network. + INetworkPtr net(INetwork::Create()); + + const TensorInfo inputInfo ({ 1, 7, 7, 1 }, DataType::Float32); + const TensorInfo outputInfo({ 1, 9, 9, 1 }, DataType::Float32); + const TensorInfo weightsInfo({ 1, 3, 3, 1 }, DataType::Float32, 0.0f, 0, true); + const TensorInfo biasesInfo ({ 1 }, DataType::Float32, 0.0f, 0, true); + + std::vector<float> weightsData = GenerateRandomData<float>(weightsInfo.GetNumElements()); + ConstTensor weights(weightsInfo, weightsData); + + std::vector<float> biasesData = GenerateRandomData<float>(biasesInfo.GetNumElements()); + ConstTensor biases(biasesInfo, biasesData); + + TransposeConvolution2dDescriptor descriptor; + descriptor.m_PadLeft = 1; + descriptor.m_PadRight = 1; + descriptor.m_PadTop = 1; + descriptor.m_PadBottom = 1; + descriptor.m_StrideX = 1; + descriptor.m_StrideY = 1; + descriptor.m_BiasEnabled = true; + descriptor.m_DataLayout = DataLayout::NHWC; + + IConnectableLayer* const inputLayer = net->AddInputLayer(0); + IConnectableLayer* const convLayer = + net->AddTransposeConvolution2dLayer(descriptor, + weights, + Optional<ConstTensor>(biases), + "transposeConvolution2d"); + IConnectableLayer* const outputLayer = net->AddOutputLayer(0); + + inputLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(0)); + convLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo); + convLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + TosaSerializationBasicBlock* basicBlock = GetTosaMappingFromLayer(PolymorphicDowncast<Layer*>(convLayer)); + + CHECK(basicBlock->GetInputs().size() == 3); + CHECK(basicBlock->GetOutputs().size() == 1); + CHECK(basicBlock->GetOperators().size() == 3); + CHECK(basicBlock->GetTensors().size() == 4); + + CHECK(basicBlock->GetInputs()[0].find("input0_") != std::string::npos); + CHECK(basicBlock->GetInputs()[1].find("constant_") != std::string::npos); + CHECK(basicBlock->GetInputs()[2].find("constant_") != std::string::npos); + CHECK(basicBlock->GetOutputs()[0].find("output0_") != std::string::npos); + + VerifyTosaAttribute(descriptor, + basicBlock->GetOperators().at(2)->GetAttribute(), + {}, + {}, + LayerType::TransposeConvolution2d); +} + TEST_CASE("GetTosaMapping_Unimplemented") { TosaSerializationBasicBlock* basicBlock = |