Add RNN support for Pytorch

Javier Duarte requested to merge github/fork/JanFSchulte/GRUv1 into main

Created by: JanFSchulte

Adds support for RNN layers (GRU, LSTM, RNN) to the pytorch parser.

Caveat: We currently lack an implementation for getitem operations, so we cannot return the hidden state after the calculation.

Caveat 2: We currently support only a single recurrent layer, whereas PyTorch supports multiple layers within the same RNN instance.

Caveat 3: We currently don't support passing non-zero initial values for the hidden states to the RNN.
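
For reference, the configuration the parser currently handles is a single recurrent layer with a zero initial hidden state, returning only the output sequence. A minimal numpy sketch of one GRU time step in PyTorch's gate ordering (reset, update, new), with illustrative weight names, shows the computation involved:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h, W_ih, W_hh, b_ih, b_hh):
    """One GRU time step following PyTorch's gate ordering (reset, update, new).

    W_ih: (3*H, I) input-to-hidden weights; W_hh: (3*H, H) hidden-to-hidden
    weights; b_ih, b_hh: (3*H,) biases. Names are illustrative only.
    """
    H = h.shape[0]
    gi = W_ih @ x + b_ih          # input contributions to all three gates
    gh = W_hh @ h + b_hh          # hidden contributions to all three gates
    i_r, i_z, i_n = gi[:H], gi[H:2*H], gi[2*H:]
    h_r, h_z, h_n = gh[:H], gh[H:2*H], gh[2*H:]
    r = sigmoid(i_r + h_r)        # reset gate
    z = sigmoid(i_z + h_z)        # update gate
    n = np.tanh(i_n + r * h_n)    # candidate hidden state
    return (1.0 - z) * n + z * h  # new hidden state

# Run a short sequence with a zero initial hidden state (caveat 3); only
# the per-step outputs are kept, not the final hidden state (caveat 1).
rng = np.random.default_rng(0)
I, H, T = 4, 3, 5                 # input size, hidden size, sequence length
W_ih, W_hh = rng.normal(size=(3*H, I)), rng.normal(size=(3*H, H))
b_ih, b_hh = np.zeros(3*H), np.zeros(3*H)
h = np.zeros(H)
outputs = []
for t in range(T):
    h = gru_step(rng.normal(size=I), h, W_ih, W_hh, b_ih, b_hh)
    outputs.append(h)
outputs = np.stack(outputs)       # shape (T, H)
```

This mirrors the single-layer, zero-initial-state case; models that stack layers inside one `nn.GRU` or that consume the returned hidden state fall outside the current support.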

So this implementation is slightly hacky at the moment, but it might serve as a starting point for discussion and can be used by interested parties if they can live with the current limitations.

Also, this contains parts of https://github.com/fastmachinelearning/hls4ml/pull/848 because I was inattentive.

Type of change

For a new feature or function, please create an issue first to discuss it with us before submitting a pull request.

Note: Please delete options that are not relevant.

  • New feature (non-breaking change which adds functionality)

Tests

Added pytests to confirm that the layers work.

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works.