Skip to content

Improve parsing of non-nn.Sequential PyTorch models

Javier Duarte requested to merge github/fork/vloncar/torch_input_map_fix into main

Created by: vloncar

Description

In case of skipped layers, like Flatten or Dropout, PyTorch converter will incorrectly parse the model inputs, we need to create an input map similar to how Keras handles it. This was the case in #839. Additionally, as observed in #838, parsing of BN weights was broken. These fixes are cherrypicked from my development branch for parsing GNNs, not fully tested standalone, so I'm making this a draft PR for now before I add proper tests.

Type of change

  • Bug fix (non-breaking change that fixes an issue)

Tests

Currently lacking. Will add something along the lines of code shared in #838 and #839

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works. <-- Will do in a follow-up commit

Merge request reports

Loading