diff options
Diffstat (limited to 'driver_library/python/src')
3 files changed, 46 insertions, 31 deletions
diff --git a/driver_library/python/src/ethosu_driver/_utilities/driver_utilities.py b/driver_library/python/src/ethosu_driver/_utilities/driver_utilities.py index fcea91f..ca39751 100644 --- a/driver_library/python/src/ethosu_driver/_utilities/driver_utilities.py +++ b/driver_library/python/src/ethosu_driver/_utilities/driver_utilities.py @@ -30,8 +30,7 @@ def load_model(device: Device, model: str) -> Network: `Network`: Return the object that represent the neural __network file descriptor received from the Ethos-U device. """ logging.info("Creating network") - network_buffer = Buffer(device, model) - return Network(device, network_buffer) + return Network(device, model) def populate_buffers(input_data: List[bytearray], buffers: List[Buffer]): diff --git a/driver_library/python/src/ethosu_driver/swig/driver.i b/driver_library/python/src/ethosu_driver/swig/driver.i index 3e4e384..6e0ad25 100644 --- a/driver_library/python/src/ethosu_driver/swig/driver.i +++ b/driver_library/python/src/ethosu_driver/swig/driver.i @@ -293,12 +293,12 @@ public: buffer: data to be copied to the mapped memory. ") from_buffer; - %mutable_buffer(char* buffer, size_t size); + %buffer_in(char* buffer, size_t size, BUFFER_FLAG_RW); void from_buffer(char* buffer, size_t size) { char* data = $self->data(); std::memcpy(data, buffer, size); } - %clear_mutable_buffer(char* buffer, size_t size); + %clear_buffer_in(char* buffer, size_t size); } %feature("docstring", @@ -329,15 +329,6 @@ public: %feature("docstring", " - Returns associated memory buffer. - - Returns: - `Buffer`: buffer object used during initialisation. - ") getBuffer; - std::shared_ptr<Buffer> getBuffer(); - - %feature("docstring", - " Returns saved sizes of the neural network model input feature maps. Returns: @@ -374,21 +365,41 @@ public: }; %extend Network { - Network(const Device &device, std::shared_ptr<Buffer> &buffer) + + Network(const Device &device, const std::string& filename) { - if(buffer == nullptr){ - throw EthosU::Exception(std::string("Failed to create the network, buffer is nullptr.").c_str()); + std::ifstream stream(filename, std::ios::binary); + if (!stream.is_open()) { + throw EthosU::Exception(std::string("Failed to open file: ").append(filename).c_str()); } - auto network = new EthosU::Network(device, buffer); - return network; + + stream.seekg(0, std::ios_base::end); + size_t size = stream.tellg(); + stream.seekg(0, std::ios_base::beg); + + std::unique_ptr<unsigned char[]> buffer = std::make_unique<unsigned char[]>(size); + stream.read(reinterpret_cast<char*>(buffer.get()), size); + return new EthosU::Network(device, buffer.get(), size); } -} -%extend Network { + %buffer_in(const unsigned char* networkData, size_t networkSize, BUFFER_FLAG_RO); + Network(const Device &device, const unsigned char* networkData, size_t networkSize) + { + if(networkData == nullptr){ + throw EthosU::Exception(std::string("Failed to create the network, networkData is nullptr.").c_str()); + } + + if(networkSize == 0U){ + throw EthosU::Exception(std::string("Failed to create the network, networkSize is zero.").c_str()); + } + + return new EthosU::Network(device, networkData, networkSize); + } + %clear_buffer_in(const unsigned char* networkData, size_t networkSize); + Network(const Device &device, const unsigned int index) { - auto network = new EthosU::Network(device, index); - return network; + return new EthosU::Network(device, index); } } diff --git a/driver_library/python/src/ethosu_driver/swig/typemaps/buffer.i b/driver_library/python/src/ethosu_driver/swig/typemaps/buffer.i index 13b7909..bb4627c 100644 --- a/driver_library/python/src/ethosu_driver/swig/typemaps/buffer.i +++ b/driver_library/python/src/ethosu_driver/swig/typemaps/buffer.i @@ -1,19 +1,25 @@ // -// SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> +// SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com> // SPDX-License-Identifier: Apache-2.0 // -%define %mutable_buffer(TYPEMAP, SIZE) + +%define BUFFER_FLAG_RO 0 %enddef +%define BUFFER_FLAG_RW PyBUF_WRITABLE %enddef + +%define %buffer_in(TYPEMAP, SIZE, FLAG) %typemap(in) (TYPEMAP, SIZE) { - int res; void *buf = 0; size_t size = 0; Py_buffer view; - res = PyObject_GetBuffer($input, &view, PyBUF_WRITABLE); - buf = view.buf; - size = view.len; - PyBuffer_Release(&view); + + int res = PyObject_GetBuffer($input, &view, FLAG); if (res < 0) { PyErr_Clear(); %argument_fail(res, "(TYPEMAP, SIZE)", $symname, $argnum); } + + void *buf = view.buf; + size_t size = view.len; + PyBuffer_Release(&view); + $1 = ($1_ltype) buf; $2 = ($2_ltype) size; } @@ -23,12 +29,11 @@ } %enddef -%define %clear_mutable_buffer(TYPEMAP, SIZE) +%define %clear_buffer_in(TYPEMAP, SIZE) %typemap(in) (TYPEMAP, SIZE); %typemap(typecheck) (TYPEMAP, SIZE); %enddef - %define %driver_buffer_out %typemap(out) (char*) { auto size = arg1->size(); |