aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/GraphTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/GraphTests.cpp')
-rw-r--r--src/armnn/test/GraphTests.cpp25
1 files changed, 25 insertions, 0 deletions
diff --git a/src/armnn/test/GraphTests.cpp b/src/armnn/test/GraphTests.cpp
index 6b3e611017..d3dd499850 100644
--- a/src/armnn/test/GraphTests.cpp
+++ b/src/armnn/test/GraphTests.cpp
@@ -614,4 +614,29 @@ TEST_CASE("CheckGraphConstTensorSharing")
CHECK(*sharedWeightPtr == 1);
}
+TEST_CASE("IConnectableLayerConstantTensorsByRef")
+{
+ using namespace armnn;
+ INetworkPtr net(INetwork::Create());
+
+ std::vector<uint8_t> falseData = {3};
+ ConstTensor falseTensor(TensorInfo({1}, DataType::Boolean, 0.0f, 0, true), falseData);
+ IConnectableLayer* constLayer = net->AddConstantLayer(falseTensor, "const");
+ constLayer->GetOutputSlot(0).SetTensorInfo(TensorInfo({1, 1, 1, 1}, DataType::Boolean));
+
+ const TensorInfo& constInfo = constLayer->GetOutputSlot(0).GetTensorInfo();
+
+ const void* weightData = constLayer->GetConstantTensorsByRef()[0].get()->GetConstTensor<void>();
+ auto weightValue = reinterpret_cast<const uint8_t*>(weightData);
+ CHECK(weightValue[0] == 3);
+ TensorInfo weightsInfo = constInfo;
+ ConstTensor weights(weightsInfo, weightData);
+ DepthwiseConvolution2dDescriptor desc;
+ const auto depthwiseLayer = net->AddDepthwiseConvolution2dLayer(desc, weights, EmptyOptional(), "Depthwise");
+
+ const void* resultData = depthwiseLayer->GetConstantTensorsByRef()[0].get()->GetConstTensor<void>();
+ auto resultValue = reinterpret_cast<const uint8_t*>(resultData);
+ CHECK(resultValue[0] == 3);
+}
+
}