Support non-contiguous flash ELF/UF2, fix problem reading strings near end of file (#68)

Co-authored-by: Jan Niehusmann <jan@gondor.com>
diff --git a/main.cpp b/main.cpp
index 9d867c0..d0c00dc 100644
--- a/main.cpp
+++ b/main.cpp
@@ -201,7 +201,7 @@
         }
         f--;
         assert(p >= f->first);
-        if (p > f->second.first) {
+        if (p >= f->second.first) {
             throw not_mapped_exception();
         }
         return std::make_pair(mapping(p - f->first, f->second.first - f->first), f->second.second);
@@ -701,7 +701,11 @@
 }
 
 struct memory_access {
-    virtual void read(uint32_t, uint8_t *buffer, uint size) = 0;
+    virtual void read(uint32_t p, uint8_t *buffer, uint size) {
+        read(p, buffer, size, false);
+    }
+
+    virtual void read(uint32_t, uint8_t *buffer, uint size, bool zero_fill) = 0;
 
     virtual bool is_device() { return false; }
 
@@ -728,10 +732,10 @@
     }
 
     // read a vector of types that have a raw_type_mapping
-    template <typename T> vector<T> read_vector(uint32_t addr, uint count) {
+    template <typename T> vector<T> read_vector(uint32_t addr, uint count, bool zero_fill = false) {
         assert(count);
         vector<typename raw_type_mapping<T>::access_type> buffer(count);
-        read(addr, (uint8_t *)buffer.data(), count * sizeof(typename raw_type_mapping<T>::access_type));
+        read(addr, (uint8_t *)buffer.data(), count * sizeof(typename raw_type_mapping<T>::access_type), zero_fill);
         vector<T> v;
         v.reserve(count);
         for(const auto &e : buffer) {
@@ -740,9 +744,9 @@
         return v;
     }
 
-    template <typename T> void read_into_vector(uint32_t addr, uint count, vector<T> &v) {
+    template <typename T> void read_into_vector(uint32_t addr, uint count, vector<T> &v, bool zero_fill = false) {
         vector<typename raw_type_mapping<T>::access_type> buffer(count);
-        if (count) read(addr, (uint8_t *)buffer.data(), count * sizeof(typename raw_type_mapping<T>::access_type));
+        if (count) read(addr, (uint8_t *)buffer.data(), count * sizeof(typename raw_type_mapping<T>::access_type), zero_fill);
         v.clear();
         v.reserve(count);
         for(const auto &e : buffer) {
@@ -786,7 +790,7 @@
         return FLASH_START;
     }
 
-    void read(uint32_t address, uint8_t *buffer, uint size) override {
+    void read(uint32_t address, uint8_t *buffer, uint size, __unused bool zero_fill) override {
         if (flash == get_memory_type(address)) {
             connection.exit_xip();
         }
@@ -854,13 +858,25 @@
         return binary_start;
     }
 
-    void read(uint32_t address, uint8_t *buffer, uint32_t size) override {
+    void read(uint32_t address, uint8_t *buffer, uint32_t size, bool zero_fill) override {
         while (size) {
-            auto result = rmap.get(address);
-            uint this_size = std::min(size, result.first.max_offset - result.first.offset);
-            assert( this_size);
-            fseek(file, result.second + result.first.offset, SEEK_SET);
-            fread(buffer, this_size, 1, file);
+            uint this_size;
+            try {
+                auto result = rmap.get(address);
+                this_size = std::min(size, result.first.max_offset - result.first.offset);
+                assert(this_size);
+                fseek(file, result.second + result.first.offset, SEEK_SET);
+                fread(buffer, this_size, 1, file);
+            } catch (not_mapped_exception &e) {
+                if (zero_fill) {
+                    // address is not in a range, so fill up to next range with zeros
+                    this_size = rmap.next(address) - address;
+                    this_size = std::min(this_size, size);
+                    memset(buffer, 0, this_size);
+                } else {
+                    throw e;
+                }
+            }
             buffer += this_size;
             address += this_size;
             size -= this_size;
@@ -883,12 +899,12 @@
 struct remapped_memory_access : public memory_access {
     remapped_memory_access(memory_access &wrap, range_map<uint32_t> rmap) : wrap(wrap), rmap(rmap) {}
 
-    void read(uint32_t address, uint8_t *buffer, uint size) override {
+    void read(uint32_t address, uint8_t *buffer, uint size, bool zero_fill) override {
         while (size) {
             auto result = get_remapped(address);
             uint this_size = std::min(size, result.first.max_offset - result.first.offset);
             assert( this_size);
-            wrap.read(result.second + result.first.offset, buffer, this_size);
+            wrap.read(result.second + result.first.offset, buffer, this_size, zero_fill);
             buffer += this_size;
             address += this_size;
             size -= this_size;
@@ -993,20 +1009,15 @@
 }
 
 string read_string(memory_access &access, uint32_t addr) {
-    // note this implementation is still wrong, it just tries a bit harder to not try to read off the end of the image (which causes
-    // an assertion failure)
-    uint max_length;
-    for(max_length = 8; max_length <= 1024; max_length *=2 ) {
-        auto v = access.read_vector<char>(addr, max_length);
-        uint length;
-        for (length = 0; length < max_length; length++) {
-            if (!v[length]) {
-                return string(v.data(), length);
-            }
+    const uint max_length = 512;
+    auto v = access.read_vector<char>(addr, max_length, true); // zero fill
+    uint length;
+    for (length = 0; length < max_length; length++) {
+        if (!v[length]) {
+            break;
         }
     }
-    return "<failed to read string>";
-}
+    return string(v.data(), length);}
 
 struct bi_visitor_base {
     void visit(memory_access& access, const binary_info_header& hdr) {
@@ -1780,7 +1791,7 @@
     return false;
 }
 
-vector<range> get_colaesced_ranges(file_memory_access &file_access) {
+vector<range> get_coalesced_ranges(file_memory_access &file_access) {
     auto rmap = file_access.get_rmap();
     auto ranges = rmap.ranges();
     std::sort(ranges.begin(), ranges.end(), [](const range& a, const range &b) {
@@ -1789,7 +1800,14 @@
     // coalesce all the contiguous ranges
     for(auto i = ranges.begin(); i < ranges.end(); ) {
         if (i != ranges.end() - 1) {
-            if (i->to == (i+1)->from) {
+            uint32_t erase_size;
+            // we want to coalesce flash sectors together (this ends up creating ranges that may have holes)
+            if( get_memory_type(i->from) == flash ) {
+                erase_size = FLASH_SECTOR_ERASE_SIZE;
+            } else {
+                erase_size = 1;
+            }
+            if (i->to / erase_size == (i+1)->from / erase_size) {
                 i->to = (i+1)->to;
                 i = ranges.erase(i+1) - 1;
                 continue;
@@ -1826,7 +1844,7 @@
             visitor.visit(access, hdr);
         }
     }
-    auto ranges = get_colaesced_ranges(file_access);
+    auto ranges = get_coalesced_ranges(file_access);
     for (auto mem_range : ranges) {
         enum memory_type t1 = get_memory_type(mem_range.from);
         enum memory_type t2 = get_memory_type(mem_range.to);
@@ -1854,14 +1872,15 @@
             bool ok = true;
             vector<uint8_t> file_buf;
             vector<uint8_t> device_buf;
-            for (uint32_t base = mem_range.from; base < mem_range.to && ok; ) {
+            for (uint32_t base = mem_range.from; base < mem_range.to && ok;) {
                 uint32_t this_batch = std::min(mem_range.to - base, batch_size);
                 if (type == flash) {
                     // we have to erase an entire page, so then fill with zeros
-                    range aligned_range(base & ~(FLASH_SECTOR_ERASE_SIZE - 1), (base & ~(FLASH_SECTOR_ERASE_SIZE - 1)) + FLASH_SECTOR_ERASE_SIZE);
+                    range aligned_range(base & ~(FLASH_SECTOR_ERASE_SIZE - 1),
+                                        (base & ~(FLASH_SECTOR_ERASE_SIZE - 1)) + FLASH_SECTOR_ERASE_SIZE);
                     range read_range(base, base + this_batch);
                     read_range.intersect(aligned_range);
-                    file_access.read_into_vector(read_range.from, read_range.to - read_range.from, file_buf);
+                    file_access.read_into_vector(read_range.from, read_range.to - read_range.from, file_buf, true); // zero fill to cope with holes
                     // zero padding up to FLASH_SECTOR_ERASE_SIZE
                     file_buf.insert(file_buf.begin(), read_range.from - aligned_range.from, 0);
                     file_buf.insert(file_buf.end(), aligned_range.to - read_range.to, 0);
@@ -1869,14 +1888,14 @@
 
                     bool skip = false;
                     if (settings.load.update) {
-                      vector<uint8_t> read_device_buf;
-                      raw_access.read_into_vector(aligned_range.from, batch_size, read_device_buf);
-                      skip = file_buf == read_device_buf;
+                        vector<uint8_t> read_device_buf;
+                        raw_access.read_into_vector(aligned_range.from, batch_size, read_device_buf);
+                        skip = file_buf == read_device_buf;
                     }
                     if (!skip) {
-                      con.exit_xip();
-                      con.flash_erase(aligned_range.from, FLASH_SECTOR_ERASE_SIZE);
-                      raw_access.write_vector(aligned_range.from, file_buf);
+                        con.exit_xip();
+                        con.flash_erase(aligned_range.from, FLASH_SECTOR_ERASE_SIZE);
+                        raw_access.write_vector(aligned_range.from, file_buf);
                     }
                     base = read_range.to; // about to add batch_size
                 } else {
@@ -1887,6 +1906,9 @@
                 bar.progress(base - mem_range.from, mem_range.to - mem_range.from);
             }
         }
+    }
+    for (auto mem_range : ranges) {
+        enum memory_type type = get_memory_type(mem_range.from);
         if (settings.load.verify) {
             bool ok = true;
             {
@@ -1897,7 +1919,10 @@
                 uint32_t pos = mem_range.from;
                 for (uint32_t base = mem_range.from; base < mem_range.to && ok; base += batch_size) {
                     uint32_t this_batch = std::min(mem_range.to - base, batch_size);
-                    file_access.read_into_vector(base, this_batch, file_buf);
+                    // note we pass zero_fill = true in case the file has holes, but this does
+                    // mean that the verification will fail if those holes are not filled with zeros
+                    // on the device
+                    file_access.read_into_vector(base, this_batch, file_buf, true);
                     raw_access.read_into_vector(base, this_batch, device_buf);
                     assert(file_buf.size() == device_buf.size());
                     for (uint i = 0; i < this_batch; i++) {
@@ -1940,7 +1965,7 @@
     auto file_access = get_file_memory_access();
     auto con = get_single_bootsel_device_connection(devices);
     picoboot_memory_access raw_access(con);
-    auto ranges = get_colaesced_ranges(file_access);
+    auto ranges = get_coalesced_ranges(file_access);
     if (settings.range_set) {
         range filter(settings.from, settings.to);
         for(auto& range : ranges) {
@@ -1966,7 +1991,10 @@
                     uint32_t batch_size = 1024;
                     for(uint32_t base = mem_range.from; base < mem_range.to && ok; base += batch_size) {
                         uint32_t this_batch = std::min(mem_range.to - base, batch_size);
-                        file_access.read_into_vector(base, this_batch, file_buf);
+                        // note we pass zero_fill = true in case the file has holes, but this does
+                        // mean that the verification will fail if those holes are not filled with zeros
+                        // on the device
+                        file_access.read_into_vector(base, this_batch, file_buf, true);
                         raw_access.read_into_vector(base, this_batch, device_buf);
                         assert(file_buf.size() == device_buf.size());
                         for(uint i=0;i<this_batch;i++) {