aboutsummaryrefslogtreecommitdiff
path: root/driver_library/src/ethosu.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'driver_library/src/ethosu.cpp')
-rw-r--r--driver_library/src/ethosu.cpp121
1 files changed, 28 insertions, 93 deletions
diff --git a/driver_library/src/ethosu.cpp b/driver_library/src/ethosu.cpp
index f792399..01631b3 100644
--- a/driver_library/src/ethosu.cpp
+++ b/driver_library/src/ethosu.cpp
@@ -16,8 +16,6 @@
* limitations under the License.
*/
-#include "autogen/tflite_schema.hpp"
-
#include <ethosu.hpp>
#include <uapi/ethosu.h>
@@ -92,61 +90,6 @@ __attribute__((weak)) int emunmap(void *addr, size_t length) {
} // namespace EthosU
-/****************************************************************************
- * TFL micro helpers
- ****************************************************************************/
-namespace {
-size_t getShapeSize(const flatbuffers::Vector<int32_t> *shape) {
- size_t size = 1;
-
- if (shape == nullptr) {
- throw EthosU::Exception("getShapeSize(): nullptr arg");
- }
-
- for (auto it = shape->begin(); it != shape->end(); ++it) {
- size *= *it;
- }
-
- return size;
-}
-
-size_t getTensorTypeSize(const enum tflite::TensorType type) {
- switch (type) {
- case tflite::TensorType::TensorType_UINT8:
- case tflite::TensorType::TensorType_INT8:
- return 1;
- case tflite::TensorType::TensorType_INT16:
- return 2;
- case tflite::TensorType::TensorType_INT32:
- case tflite::TensorType::TensorType_FLOAT32:
- return 4;
- default:
- throw EthosU::Exception("Unsupported tensor type");
- }
-}
-
-vector<size_t> getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector<int32_t> *tensorMap) {
- vector<size_t> dims;
-
- if (subgraph == nullptr || tensorMap == nullptr) {
- throw EthosU::Exception("getSubGraphDims(): nullptr arg(s)");
- }
-
- for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index) {
- auto tensor = subgraph->tensors()->Get(*index);
- size_t size = getShapeSize(tensor->shape());
- size *= getTensorTypeSize(tensor->type());
-
- if (size > 0) {
- dims.push_back(size);
- }
- }
-
- return dims;
-}
-
-} // namespace
-
namespace EthosU {
/****************************************************************************
@@ -247,8 +190,13 @@ Buffer::Buffer(const Device &device, const size_t capacity) : fd(-1), dataPtr(nu
}
Buffer::~Buffer() {
- emunmap(dataPtr, dataCapacity);
- eclose(fd);
+ try {
+ emunmap(dataPtr, dataCapacity);
+ } catch (std::exception &e) {
+ try {
+ eclose(fd);
+ } catch (...) { std::throw_with_nested(e); }
+ }
}
size_t Buffer::capacity() const {
@@ -296,12 +244,12 @@ Network::Network(const Device &device, shared_ptr<Buffer> &buffer) : fd(-1), buf
uapi.type = ETHOSU_UAPI_NETWORK_BUFFER;
uapi.fd = buffer->getFd();
fd = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, static_cast<void *>(&uapi));
-
try {
- parseModel(buffer->data());
- } catch (...) {
- eclose(fd);
- throw;
+ collectNetworkInfo();
+ } catch (std::exception &e) {
+ try {
+ eclose(fd);
+ } catch (...) { std::throw_with_nested(e); }
}
}
@@ -311,21 +259,25 @@ Network::Network(const Device &device, const unsigned index) : fd(-1) {
uapi.type = ETHOSU_UAPI_NETWORK_INDEX;
uapi.index = index;
fd = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, static_cast<void *>(&uapi));
-
try {
- ethosu_uapi_network_info info;
- ioctl(ETHOSU_IOCTL_NETWORK_INFO, static_cast<void *>(&info));
+ collectNetworkInfo();
+ } catch (std::exception &e) {
+ try {
+ eclose(fd);
+ } catch (...) { std::throw_with_nested(e); }
+ }
+}
- for (uint32_t i = 0; i < info.ifm_count; i++) {
- ifmDims.push_back(info.ifm_size[i]);
- }
+void Network::collectNetworkInfo() {
+ ethosu_uapi_network_info info;
+ ioctl(ETHOSU_IOCTL_NETWORK_INFO, static_cast<void *>(&info));
- for (uint32_t i = 0; i < info.ofm_count; i++) {
- ofmDims.push_back(info.ofm_size[i]);
- }
- } catch (...) {
- eclose(fd);
- throw;
+ for (uint32_t i = 0; i < info.ifm_count; i++) {
+ ifmDims.push_back(info.ifm_size[i]);
+ }
+
+ for (uint32_t i = 0; i < info.ofm_count; i++) {
+ ofmDims.push_back(info.ofm_size[i]);
}
}
@@ -369,23 +321,6 @@ size_t Network::getOfmSize() const {
return size;
}
-void Network::parseModel(const char *data) {
- // Create model handle
- const tflite::Model *model = tflite::GetModel(reinterpret_cast<const void *>(data));
-
- if (model->subgraphs() == nullptr) {
- EthosU::Exception("Failed to get subgraphs: nullptr");
- }
-
- // Get input dimensions for first subgraph
- auto *subgraph = *model->subgraphs()->begin();
- ifmDims = getSubGraphDims(subgraph, subgraph->inputs());
-
- // Get output dimensions for last subgraph
- subgraph = *model->subgraphs()->rbegin();
- ofmDims = getSubGraphDims(subgraph, subgraph->outputs());
-}
-
/****************************************************************************
* Inference
****************************************************************************/