Support for parsing nested models
Created by: vloncar
Description
Currently, models we parse have to have a flat hierarchy while many more advanced models (like (V)AEs or GNNs) use nested* hierarchy. In this PR we add support for representing groups of layers and for parsing Keras models that have other models embedded inside. The parsed model will temporarily have a LayerGroup
"layer" that will then be expanded (via an optimizer) into a flat hierarchy. In the future we can consider not expanding and instead generating multiple IPs etc.
This will serve as a basis for the equivalent functionality of the PyTorch parser, that will be built on top of #723.
- Note that "nested" here doesn't mean custom. Subclassing
tf.keras.Model
is not supported (for example like this, since the Keras serialized format that we parse won't include information about that is needed. We may revisit this in the future once a few other building blocks are in place.
Type of change
-
New feature (non-breaking change which adds functionality)
Tests
I added test_keras_nested_model.py
with some tests of the functionality.
Checklist
-
I have read the guidelines for contributing. -
I have commented my code, particularly in hard-to-understand areas. -
I have made corresponding changes to the documentation. -
My changes generate no new warnings. -
I have installed and run pre-commit
on the files I edited or added. -
I have added tests that prove my fix is effective or that my feature works.