diff --git a/README.md b/README.md index 057f0f72a84d932a51e1e476dc05737e51fd9065..72b1c500e5a08e4009b6c57ff29dc5df3ffbe884 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,8 @@ First create a new conda environment with unifrac `conda activate aam` +`conda install -c conda-forge gxx_linux-64 hdf5 mkl-include lz4 hdf5-static libcblas liblapacke make` + ## GPU Support Install CUDA 11.8 diff --git a/aam/attention_cli.py b/aam/attention_cli.py index 17c63613210d455e127813e22e4281f775c54b4e..9356eaff93f64b5c5a3bedbd05f017a990197471 100644 --- a/aam/attention_cli.py +++ b/aam/attention_cli.py @@ -2,7 +2,7 @@ from __future__ import annotations import datetime import os -from typing import Optional +from typing import Union import click import numpy as np @@ -56,10 +56,28 @@ def validate_metadata(table, metadata, missing_samples_flag): @cli.command() @click.option("--i-table", required=True, type=click.Path(exists=True), help=TABLE_DESC) @click.option("--i-tree", required=True, type=click.Path(exists=True)) +@click.option( + "--m-metadata-file", + required=True, + help="Metadata description", + type=click.Path(exists=True), +) +@click.option( + "--m-metadata-column", + required=True, + type=str, + help="Numeric metadata column to use as prediction target.", +) +@click.option( + "--p-missing-samples", + default="error", + type=click.Choice(["error", "ignore"], case_sensitive=False), + help=MISSING_SAMP_DESC, +) @click.option("--p-batch-size", default=8, show_default=True, required=False, type=int) -@click.option("--p-max-bp", required=True, type=int) @click.option("--p-epochs", default=1000, show_default=True, type=int) @click.option("--p-dropout", default=0.0, show_default=True, type=float) +@click.option("--p-asv-dropout", default=0.0, show_default=True, type=float) @click.option("--p-patience", default=10, show_default=True, type=int) @click.option("--p-early-stop-warmup", default=50, show_default=True, type=int) @click.option("--i-model", default=None, required=False, type=str) @@ -70,30 +88,55 @@ def validate_metadata(table, metadata, missing_samples_flag): @click.option( "--p-intermediate-activation", default="relu", show_default=True, type=str ) -@click.option("--p-asv-limit", default=512, show_default=True, type=int) +@click.option("--p-asv-limit", default=1024, show_default=True, type=int) +@click.option("--p-gen-new-table", default=True, show_default=True, type=bool) +@click.option("--p-lr", default=1e-4, show_default=True, type=float) +@click.option("--p-warmup-steps", default=10000, show_default=True, type=int) +@click.option("--p-max-bp", default=150, show_default=True, type=int) @click.option("--output-dir", required=True) +@click.option("--p-add-token", default=False, required=False, type=bool) +@click.option("--p-gotu", default=False, required=False, type=bool) +@click.option("--p-is-categorical", default=False, required=False, type=bool) +@click.option("--p-rarefy-depth", default=5000, required=False, type=int) +@click.option("--p-weight-decay", default=0.004, show_default=True, type=float) +@click.option("--p-accumulation-steps", default=1, required=False, type=int) +@click.option("--p-unifrac-metric", default="unifrac", required=False, type=str) def fit_unifrac_regressor( i_table: str, i_tree: str, + m_metadata_file: str, + m_metadata_column: str, + p_missing_samples: bool, p_batch_size: int, - p_max_bp: int, p_epochs: int, p_dropout: float, + p_asv_dropout: float, p_patience: int, p_early_stop_warmup: int, - i_model: Optional[str], + i_model: Union[None, str], 
p_embedding_dim: int, p_attention_heads: int, p_attention_layers: int, p_intermediate_size: int, p_intermediate_activation: str, p_asv_limit: int, + p_gen_new_table: bool, + p_lr: float, + p_warmup_steps: int, + p_max_bp: int, output_dir: str, + p_add_token: bool, + p_gotu: bool, + p_is_categorical: bool, + p_rarefy_depth: int, + p_weight_decay: float, + p_accumulation_steps, + p_unifrac_metric: str, ): from biom import load_table from aam.data_handlers import UniFracGenerator - from aam.models import UniFracEncoder + from aam.models import SequenceEncoder from aam.models.utils import cos_decay_with_warmup if not os.path.exists(output_dir): @@ -103,20 +146,39 @@ def fit_unifrac_regressor( if not os.path.exists(figure_path): os.makedirs(figure_path) + output_dim = p_embedding_dim + if p_unifrac_metric == "faith_pd": + output_dim = 1 + if i_model is not None: - model: tf.keras.Model = tf.keras.models.load_model(i_model) + model = tf.keras.models.load_model(i_model) else: - model: tf.keras.Model = UniFracEncoder( + model: tf.keras.Model = SequenceEncoder( + output_dim, p_asv_limit, - embedding_dim=p_embedding_dim, + p_unifrac_metric, dropout_rate=p_dropout, + embedding_dim=p_embedding_dim, attention_heads=p_attention_heads, attention_layers=p_attention_layers, intermediate_size=p_intermediate_size, intermediate_activation=p_intermediate_activation, + max_bp=p_max_bp, + is_16S=True, + add_token=p_add_token, + asv_dropout_rate=p_asv_dropout, + accumulation_steps=p_accumulation_steps, ) - optimizer = tf.keras.optimizers.AdamW(cos_decay_with_warmup(), beta_2=0.95) + optimizer = tf.keras.optimizers.AdamW( + cos_decay_with_warmup(p_lr, p_warmup_steps), + beta_2=0.98, + weight_decay=p_weight_decay, + # use_ema=True, + # ema_momentum=0.999, + # ema_overwrite_frequency=500, + # global_clipnorm=1.0, + ) token_shape = tf.TensorShape([None, None, 150]) count_shape = tf.TensorShape([None, None, 1]) model.build([token_shape, count_shape]) @@ -127,10 +189,14 @@ def fit_unifrac_regressor( model.summary() table = load_table(i_table) - ids = table.ids() + df = pd.read_csv(m_metadata_file, sep="\t", index_col=0, dtype={0: str})[ + [m_metadata_column] + ] + ids, table, df = validate_metadata(table, df, p_missing_samples) indices = np.arange(len(ids), dtype=np.int32) + np.random.shuffle(indices) - train_size = int(len(ids) * 0.9) + train_size = int(len(ids) * 0.8) train_indices = indices[:train_size] train_ids = ids[train_indices] @@ -140,21 +206,36 @@ def fit_unifrac_regressor( val_ids = ids[val_indices] val_table = table.filter(val_ids, inplace=False) + common_kwargs = { + "metadata_column": m_metadata_column, + "max_token_per_sample": p_asv_limit, + "rarefy_depth": p_rarefy_depth, + "batch_size": p_batch_size, + "is_16S": True, + "is_categorical": p_is_categorical, + "max_bp": p_max_bp, + "epochs": p_epochs, + "tree_path": i_tree, + "metadata": df, + "unifrac_metric": p_unifrac_metric, + } train_gen = UniFracGenerator( table=train_table, - tree_path=i_tree, - max_token_per_sample=p_asv_limit, shuffle=True, - gen_new_tables=True, - batch_size=p_batch_size, + shift=0.0, + scale=1.0, + gen_new_tables=p_gen_new_table, + **common_kwargs, ) train_data = train_gen.get_data() val_gen = UniFracGenerator( table=val_table, - tree_path=i_tree, - max_token_per_sample=p_asv_limit, shuffle=False, + shift=0.0, + scale=1.0, + gen_new_tables=False, + **common_kwargs, ) val_data = val_gen.get_data() @@ -186,10 +267,29 @@ def fit_unifrac_regressor( @cli.command() @click.option("--i-table", required=True, 
type=click.Path(exists=True), help=TABLE_DESC) @click.option("--i-taxonomy", required=True, type=click.Path(exists=True)) -@click.option("--p-tax-level", default=7, type=int) -@click.option("--p-max-bp", required=True, type=int) +@click.option("--i-tax-level", default=7, type=int) +@click.option( + "--m-metadata-file", + required=True, + help="Metadata description", + type=click.Path(exists=True), +) +@click.option( + "--m-metadata-column", + required=True, + type=str, + help="Numeric metadata column to use as prediction target.", +) +@click.option( + "--p-missing-samples", + default="error", + type=click.Choice(["error", "ignore"], case_sensitive=False), + help=MISSING_SAMP_DESC, +) +@click.option("--p-batch-size", default=8, show_default=True, required=False, type=int) @click.option("--p-epochs", default=1000, show_default=True, type=int) @click.option("--p-dropout", default=0.1, show_default=True, type=float) +@click.option("--p-asv-dropout", default=0.0, show_default=True, type=float) @click.option("--p-patience", default=10, show_default=True, type=int) @click.option("--p-early-stop-warmup", default=50, show_default=True, type=int) @click.option("--i-model", default=None, required=False, type=str) @@ -201,29 +301,53 @@ def fit_unifrac_regressor( "--p-intermediate-activation", default="relu", show_default=True, type=str ) @click.option("--p-asv-limit", default=512, show_default=True, type=int) +@click.option("--p-gen-new-table", default=True, show_default=True, type=bool) +@click.option("--p-lr", default=1e-4, show_default=True, type=float) +@click.option("--p-warmup-steps", default=10000, show_default=True, type=int) +@click.option("--p-max-bp", required=True, type=int) @click.option("--output-dir", required=True) +@click.option("--p-add-token", default=False, required=False, type=bool) +@click.option("--p-gotu", default=False, required=False, type=bool) +@click.option("--p-is-categorical", default=False, required=False, type=bool) +@click.option("--p-rarefy-depth", default=5000, required=False, type=int) +@click.option("--p-weight-decay", default=0.004, show_default=True, type=float) +@click.option("--p-accumulation-steps", default=1, required=False, type=int) def fit_taxonomy_regressor( i_table: str, i_taxonomy: str, - p_tax_level: int, - p_max_bp: int, + i_tax_level: int, + m_metadata_file: str, + m_metadata_column: str, + p_missing_samples: bool, + p_batch_size: int, p_epochs: int, p_dropout: float, + p_asv_dropout: float, p_patience: int, p_early_stop_warmup: int, - i_model: Optional[str], + i_model: Union[None, str], p_embedding_dim: int, p_attention_heads: int, p_attention_layers: int, p_intermediate_size: int, p_intermediate_activation: str, p_asv_limit: int, + p_gen_new_table: bool, + p_lr: float, + p_warmup_steps: int, + p_max_bp: int, output_dir: str, + p_add_token: bool, + p_gotu: bool, + p_is_categorical: bool, + p_rarefy_depth: int, + p_weight_decay: float, + p_accumulation_steps, ): from biom import load_table from aam.data_handlers import TaxonomyGenerator - from aam.models import TaxonomyEncoder + from aam.models import SequenceEncoder from aam.models.utils import cos_decay_with_warmup if not os.path.exists(output_dir): @@ -234,10 +358,14 @@ def fit_taxonomy_regressor( os.makedirs(figure_path) table = load_table(i_table) - ids = table.ids() + df = pd.read_csv(m_metadata_file, sep="\t", index_col=0, dtype={0: str})[ + [m_metadata_column] + ] + ids, table, df = validate_metadata(table, df, p_missing_samples) indices = np.arange(len(ids), dtype=np.int32) + 
np.random.shuffle(indices) - train_size = int(len(ids) * 0.9) + train_size = int(len(ids) * 0.8) train_indices = indices[:train_size] train_ids = ids[train_indices] @@ -247,22 +375,36 @@ def fit_taxonomy_regressor( val_ids = ids[val_indices] val_table = table.filter(val_ids, inplace=False) + common_kwargs = { + "metadata_column": m_metadata_column, + "max_token_per_sample": p_asv_limit, + "rarefy_depth": p_rarefy_depth, + "batch_size": p_batch_size, + "is_16S": True, + "is_categorical": p_is_categorical, + "max_bp": p_max_bp, + "epochs": p_epochs, + "taxonomy": i_taxonomy, + "tax_level": i_tax_level, + "metadata": df, + } train_gen = TaxonomyGenerator( table=train_table, - taxonomy=i_taxonomy, - tax_level=p_tax_level, - max_token_per_sample=p_asv_limit, shuffle=True, - gen_new_tables=True, + shift=0.0, + scale=1.0, + gen_new_tables=p_gen_new_table, + **common_kwargs, ) train_data = train_gen.get_data() val_gen = TaxonomyGenerator( table=val_table, - taxonomy=i_taxonomy, - tax_level=p_tax_level, - max_token_per_sample=p_asv_limit, shuffle=False, + shift=0.0, + scale=1.0, + gen_new_tables=False, + **common_kwargs, ) val_data = val_gen.get_data() @@ -283,18 +425,28 @@ def fit_taxonomy_regressor( if i_model is not None: model: tf.keras.Model = tf.keras.models.load_model(i_model) else: - model: tf.keras.Model = TaxonomyEncoder( + model: tf.keras.Model = SequenceEncoder( train_gen.num_tokens, p_asv_limit, - embedding_dim=p_embedding_dim, + "taxonomy", dropout_rate=p_dropout, + embedding_dim=p_embedding_dim, attention_heads=p_attention_heads, attention_layers=p_attention_layers, intermediate_size=p_intermediate_size, intermediate_activation=p_intermediate_activation, + max_bp=p_max_bp, + is_16S=True, + add_token=p_add_token, + asv_dropout_rate=p_asv_dropout, + accumulation_steps=p_accumulation_steps, + ) + optimizer = tf.keras.optimizers.AdamW( + cos_decay_with_warmup(p_lr, p_warmup_steps), + beta_2=0.98, + weight_decay=p_weight_decay, + # global_clipnorm=1.0, ) - - optimizer = tf.keras.optimizers.AdamW(cos_decay_with_warmup(), beta_2=0.95) token_shape = tf.TensorShape([None, None, 150]) count_shape = tf.TensorShape([None, None, 1]) model.build([token_shape, count_shape]) @@ -361,6 +513,7 @@ def fit_taxonomy_regressor( @click.option("--p-early-stop-warmup", default=50, show_default=True, type=int) @click.option("--p-batch-size", default=8, show_default=True, required=False, type=int) @click.option("--p-dropout", default=0.1, show_default=True, type=float) +@click.option("--p-asv-dropout", default=0.0, show_default=True, type=float) @click.option("--p-report-back", default=5, show_default=True, type=int) @click.option("--p-asv-limit", default=1024, show_default=True, type=int) @click.option("--p-penalty", default=1.0, show_default=True, type=float) @@ -380,6 +533,14 @@ def fit_taxonomy_regressor( @click.option("--p-warmup-steps", default=10000, show_default=True, type=int) @click.option("--p-max-bp", default=150, show_default=True, type=int) @click.option("--output-dir", required=True, type=click.Path(exists=False)) +@click.option("--p-output-dim", default=1, required=False, type=int) +@click.option("--p-add-token", default=False, required=False, type=bool) +@click.option("--p-gotu", default=False, required=False, type=bool) +@click.option("--p-is-categorical", default=False, required=False, type=bool) +@click.option("--p-rarefy-depth", default=5000, required=False, type=int) +@click.option("--p-weight-decay", default=0.004, show_default=True, type=float) +@click.option("--p-accumulation-steps", 
default=1, required=False, type=int) +@click.option("--p-unifrac-metric", default="unifrac", required=False, type=str) def fit_sample_regressor( i_table: str, i_base_model_path: str, @@ -394,6 +555,7 @@ def fit_sample_regressor( p_early_stop_warmup: int, p_batch_size: int, p_dropout: float, + p_asv_dropout: float, p_report_back: int, p_asv_limit: int, p_penalty: float, @@ -411,11 +573,22 @@ def fit_sample_regressor( p_warmup_steps, p_max_bp: int, output_dir: str, + p_output_dim: int, + p_add_token: bool, + p_gotu: bool, + p_is_categorical: bool, + p_rarefy_depth: int, + p_weight_decay: float, + p_accumulation_steps: int, + p_unifrac_metric: str, ): - from aam.callbacks import MeanAbsoluteError - from aam.data_handlers import TaxonomyGenerator, UniFracGenerator - from aam.models import SequenceRegressor, TaxonomyEncoder, UniFracEncoder + from aam.callbacks import ConfusionMatrx, MeanAbsoluteError + from aam.data_handlers import CombinedGenerator, TaxonomyGenerator, UniFracGenerator + from aam.models import SequenceRegressor + # p_is_16S = False + is_16S = not p_gotu + # p_is_categorical = True if not os.path.exists(output_dir): os.makedirs(output_dir) @@ -428,11 +601,14 @@ def fit_sample_regressor( os.makedirs(model_path) table = load_table(i_table) - df = pd.read_csv(m_metadata_file, sep="\t", index_col=0)[[m_metadata_column]] + df = pd.read_csv(m_metadata_file, sep="\t", index_col=0, dtype={0: str})[ + [m_metadata_column] + ] ids, table, df = validate_metadata(table, df, p_missing_samples) num_ids = len(ids) fold_indices = np.arange(num_ids) + np.random.shuffle(fold_indices) if p_test_size > 0: test_size = int(num_ids * p_test_size) train_size = num_ids - test_size @@ -443,10 +619,11 @@ def fit_sample_regressor( common_kwargs = { "metadata_column": m_metadata_column, - "is_categorical": False, "max_token_per_sample": p_asv_limit, - "rarefy_depth": 5000, + "rarefy_depth": p_rarefy_depth, "batch_size": p_batch_size, + "is_16S": is_16S, + "is_categorical": p_is_categorical, } def tax_gen(table, df, shuffle, shift, scale, epochs, gen_new_tables): @@ -475,28 +652,48 @@ def fit_sample_regressor( epochs=epochs, gen_new_tables=gen_new_tables, max_bp=p_max_bp, + unifrac_metric=p_unifrac_metric, **common_kwargs, ) - if p_taxonomy is not None and p_tree is None: + def combine_gen(table, df, shuffle, shift, scale, epochs, gen_new_tables): + return CombinedGenerator( + table=table, + metadata=df, + tree_path=p_tree, + taxonomy=p_taxonomy, + tax_level=p_taxonomy_level, + shuffle=shuffle, + shift=shift, + scale=scale, + epochs=epochs, + gen_new_tables=gen_new_tables, + max_bp=p_max_bp, + **common_kwargs, + ) + + if p_unifrac_metric == "combined": + base_model = "combined" + generator = combine_gen + elif p_taxonomy is not None and p_tree is None: base_model = "taxonomy" generator = tax_gen elif p_taxonomy is None and p_tree is not None: - base_model = "unifrac" + base_model = p_unifrac_metric generator = unifrac_gen else: raise Exception("Only taxonomy or UniFrac is supported.") if i_base_model_path is not None: base_model = tf.keras.models.load_model(i_base_model_path, compile=False) - if isinstance(base_model, TaxonomyEncoder): + base_type = base_model.encoder_type + if base_type == "taxonomy": generator = tax_gen - elif isinstance(base_model, UniFracEncoder): - generator = unifrac_gen else: - raise Exception(f"Unsupported base model {type(base_model)}") + generator = unifrac_gen + if not p_no_freeze_base_weights: - raise Warning("base_model's weights are set to trainable.") + print("base_model's 
weights are set to trainable.") def _get_fold( indices, @@ -521,15 +718,24 @@ def fit_sample_regressor( data["num_tokens"] = None return data - kfolds = KFold(p_cv) + if not p_is_categorical: + print("non-stratified folds") + kfolds = KFold(p_cv) + splits = kfolds.split(fold_indices) + else: + print("stratified folds...") + kfolds = StratifiedKFold(p_cv) + train_ids = ids[fold_indices] + train_classes = df.loc[df.index.isin(train_ids), m_metadata_column].values + splits = kfolds.split(fold_indices, train_classes) models = [] - for i, (train_ind, val_ind) in enumerate(kfolds.split(fold_indices)): + for i, (train_ind, val_ind) in enumerate(splits): train_data = _get_fold( train_ind, shuffle=True, shift=0.0, - scale=100.0, + scale=1.0, gen_new_tables=p_gen_new_table, ) val_data = _get_fold( @@ -542,30 +748,74 @@ def fit_sample_regressor( with open(os.path.join(model_path, f"f{i}_val_ids.txt"), "w") as f: for id in ids[val_ind]: f.write(id + "\n") + vocab_size = 6 if not p_is_categorical else 2000 + + if base_model == "combined": + base_output_dim = [p_embedding_dim, 1, train_data["num_tokens"]] + elif base_model == "unifrac": + base_output_dim = p_embedding_dim + elif base_model == "faith_pd": + base_output_dim = 1 + else: + base_output_dim = train_data["num_tokens"] model = SequenceRegressor( token_limit=p_asv_limit, + base_output_dim=base_output_dim, + shift=train_data["shift"], + scale=train_data["scale"], + dropout_rate=p_dropout, embedding_dim=p_embedding_dim, attention_heads=p_attention_heads, attention_layers=p_attention_layers, intermediate_size=p_intermediate_size, intermediate_activation=p_intermediate_activation, - shift=train_data["shift"], - scale=train_data["scale"], - dropout_rate=p_dropout, base_model=base_model, freeze_base=p_no_freeze_base_weights, - num_tax_levels=train_data["num_tokens"], penalty=p_penalty, nuc_penalty=p_nuc_penalty, max_bp=p_max_bp, + is_16S=is_16S, + vocab_size=vocab_size, + out_dim=p_output_dim, + classifier=p_is_categorical, + add_token=p_add_token, + class_weights=train_data["class_weights"], + accumulation_steps=p_accumulation_steps, ) token_shape = tf.TensorShape([None, None, p_max_bp]) count_shape = tf.TensorShape([None, None, 1]) model.build([token_shape, count_shape]) model.summary() - loss = tf.keras.losses.MeanSquaredError(reduction="none") + fold_label = i + 1 + if not p_is_categorical: + loss = tf.keras.losses.MeanSquaredError(reduction="none") + callbacks = [ + MeanAbsoluteError( + monitor="val_mae", + dataset=val_data["dataset"], + output_dir=os.path.join( + figure_path, f"model_f{fold_label}-val.png" + ), + report_back=p_report_back, + ) + ] + else: + loss = tf.keras.losses.CategoricalFocalCrossentropy( + from_logits=False, reduction="none" + ) + # loss = tf.keras.losses.CategoricalHinge(reduction="none") + callbacks = [ + ConfusionMatrx( + monitor="val_target_loss", + dataset=val_data["dataset"], + output_dir=os.path.join( + figure_path, f"model_f{fold_label}-val.png" + ), + report_back=p_report_back, + ) + ] model_cv = CVModel( model, train_data, @@ -573,25 +823,18 @@ def fit_sample_regressor( output_dir, fold_label, ) + metric = "mae" if not p_is_categorical else "target_loss" model_cv.fit_fold( loss, p_epochs, os.path.join(model_path, f"model_f{fold_label}.keras"), - metric="mae", + metric=metric, patience=p_patience, early_stop_warmup=p_early_stop_warmup, - callbacks=[ - MeanAbsoluteError( - monitor="val_mae", - dataset=val_data["dataset"], - output_dir=os.path.join( - figure_path, f"model_f{fold_label}-val.png" - ), - 
report_back=p_report_back, - ) - ], + callbacks=[*callbacks], lr=p_lr, warmup_steps=p_warmup_steps, + weight_decay=p_weight_decay, ) models.append(model_cv) print(f"Fold {i+1} mae: {model_cv.metric_value}") diff --git a/aam/callbacks.py b/aam/callbacks.py index 915cc4051881d138996366f7e82c8dcbcb90fbc7..f0e1a9f166d0ee034d6ec56a445c046a7744b299 100644 --- a/aam/callbacks.py +++ b/aam/callbacks.py @@ -37,17 +37,17 @@ def _mean_absolute_error(pred_val, true_val, fname, labels=None): def _confusion_matrix(pred_val, true_val, fname, cat_labels=None): cf_matrix = tf.math.confusion_matrix(true_val, pred_val).numpy() group_counts = ["{0:0.0f}".format(value) for value in cf_matrix.flatten()] - group_percentages = [ - "{0:.2%}".format(value) for value in cf_matrix.flatten() / np.sum(cf_matrix) - ] + + cf_matrix = cf_matrix / np.sum(cf_matrix, axis=-1, keepdims=True) + group_percentages = ["{0:.2%}".format(value) for value in cf_matrix.flatten()] labels = [f"{v1}\n{v2}" for v1, v2 in zip(group_counts, group_percentages)] labels = np.asarray(labels).reshape(cf_matrix.shape) fig, ax = plt.subplots(figsize=(10, 10)) ax = sns.heatmap( cf_matrix, annot=labels, - xticklabels=cat_labels, - yticklabels=cat_labels, + # xticklabels=cat_labels, + # yticklabels=cat_labels, fmt="", ) import textwrap @@ -85,16 +85,30 @@ class MeanAbsoluteError(tf.keras.callbacks.Callback): class ConfusionMatrx(tf.keras.callbacks.Callback): - def __init__(self, dataset, output_dir, report_back, labels, **kwargs): + def __init__( + self, + dataset, + output_dir, + report_back, + labels=None, + monitor="val_loss", + **kwargs, + ): super().__init__(**kwargs) self.output_dir = output_dir self.report_back = report_back self.dataset = dataset self.labels = labels + self.best_metric = None + self.monitor = monitor def on_epoch_end(self, epoch, logs=None): - y_pred, y_true = self.model.predict(self.dataset) - _confusion_matrix(y_pred, y_true, self.output_dir, self.labels) + metric = logs[self.monitor] + print(self.best_metric, metric) + if self.best_metric is None or self.best_metric > metric: + y_pred, y_true = self.model.predict(self.dataset) + _confusion_matrix(y_pred, y_true, self.output_dir, self.labels) + self.best_metric = metric class SaveModel(tf.keras.callbacks.Callback): @@ -107,11 +121,8 @@ class SaveModel(tf.keras.callbacks.Callback): self.monitor = monitor def on_epoch_end(self, epoch, logs=None): - learning_rate = float( - tf.keras.backend.get_value(self.model.optimizer.learning_rate) - ) - # Add the learning rate to the logs dictionary - logs["learning_rate"] = learning_rate + iterations = float(tf.keras.backend.get_value(self.model.optimizer.iterations)) + logs["iteration"] = iterations metric = logs[self.monitor] if self.best_weights is None or self.best_metric > metric: diff --git a/aam/cv_utils.py b/aam/cv_utils.py index eb22607851e2df1670a002ab97c638ed3ca5048e..5ecf2708ecab54dab01b1c076180b6c7b9171c82 100644 --- a/aam/cv_utils.py +++ b/aam/cv_utils.py @@ -30,33 +30,41 @@ class CVModel: def fit_fold( self, - loss, - epochs, - model_save_path, - metric="mae", - patience=10, - early_stop_warmup=50, - callbacks=[], - lr=1e-4, - warmup_steps=10000, + loss: tf.keras.losses.Loss, + epochs: int, + model_save_path: str, + metric: str = "loss", + patience: int = 10, + early_stop_warmup: int = 50, + callbacks: list[tf.keras.callbacks.Callback] = [], + lr: float = 1e-4, + warmup_steps: int = 10000, + weight_decay: float = 0.004, ): if not os.path.exists(self.log_dir): os.makedirs(self.log_dir) - optimizer = 
tf.keras.optimizers.AdamW( - cos_decay_with_warmup(lr, warmup_steps), beta_2=0.95 + print(f"weight decay: {weight_decay}") + optimizer = tf.keras.optimizers.Adam( + cos_decay_with_warmup(lr, warmup_steps), + beta_2=0.98, + # weight_decay=weight_decay, + # global_clipnorm=1.0, + # use_ema=True, + # ema_momentum=0.999, + # ema_overwrite_frequency=500, ) model_saver = SaveModel(model_save_path, 10, f"val_{metric}") core_callbacks = [ tf.keras.callbacks.TensorBoard( - log_dir=self.log_dir, - histogram_freq=0, + log_dir=self.log_dir, histogram_freq=0, write_graph=False ), tf.keras.callbacks.EarlyStopping( "val_loss", patience=patience, start_from_epoch=early_stop_warmup ), model_saver, ] - self.model.compile(optimizer=optimizer, loss=loss, run_eagerly=False) + self.model.compile(optimizer=optimizer, loss=loss) + # Set up the summary writer self.model.fit( self.train_data["dataset"], validation_data=self.val_data["dataset"], diff --git a/aam/data_handlers/__init__.py b/aam/data_handlers/__init__.py index 57a7ded36556f1939af53b3d754f2def90981eb9..8650d764315c3585c9b9828632404ae76b2723a4 100644 --- a/aam/data_handlers/__init__.py +++ b/aam/data_handlers/__init__.py @@ -1,11 +1,13 @@ from __future__ import annotations +from .combined_generator import CombinedGenerator from .generator_dataset import GeneratorDataset from .sequence_dataset import SequenceDataset from .taxonomy_generator import TaxonomyGenerator from .unifrac_generator import UniFracGenerator __all__ = [ + "CombinedGenerator", "SequenceDataset", "TaxonomyGenerator", "UniFracGenerator", diff --git a/aam/data_handlers/combined_generator.py b/aam/data_handlers/combined_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..39b0becb45fcb3f4b0928f003edfb697d16d4796 --- /dev/null +++ b/aam/data_handlers/combined_generator.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +import os +from typing import Iterable, Union + +import numpy as np +import pandas as pd +import tensorflow as tf +from biom import Table +from biom.util import biom_open +from skbio import DistanceMatrix +from sklearn import preprocessing +from unifrac import faith_pd, unweighted + +from aam.data_handlers.generator_dataset import GeneratorDataset, add_lock + + +class CombinedGenerator(GeneratorDataset): + taxon_field = "Taxon" + levels = [f"Level {i}" for i in range(1, 8)] + taxonomy_fn = "taxonomy.tsv" + + def __init__(self, tree_path: str, taxonomy, tax_level, **kwargs): + super().__init__(**kwargs) + self.tree_path = tree_path + if self.batch_size % 2 != 0: + raise Exception("Batch size must be multiple of 2") + + self.tax_level = f"Level {tax_level}" + self.taxonomy = taxonomy + + self.encoder_target = self._create_encoder_target(self.rarefy_table) + self.encoder_dtype = (np.float32, np.float32, np.float32, np.int32) + self.encoder_output_type = ( + tf.TensorSpec(shape=[self.batch_size, self.batch_size], dtype=tf.float32), + tf.TensorSpec(shape=[self.batch_size, 1], dtype=tf.float32), + tf.TensorSpec(shape=[self.batch_size, None], dtype=tf.int32), + ) + + def _create_encoder_target(self, table: Table) -> DistanceMatrix: + if not hasattr(self, "tree_path"): + return None + + random = np.random.random(1)[0] + temp_path = f"/tmp/temp{random}.biom" + with biom_open(temp_path, "w") as f: + table.to_hdf5(f, "aam") + uni = unweighted(temp_path, self.tree_path) + faith = faith_pd(temp_path, self.tree_path) + os.remove(temp_path) + + obs = table.ids(axis="observation") + tax = self.taxonomy.loc[obs, "token"] + return (uni, faith, tax) + + 
def _encoder_output( + self, + encoder_target: Iterable[object], + sample_ids: Iterable[str], + obs_ids: list[str], + ) -> np.ndarray[float]: + uni, faith, tax = encoder_target + uni = uni.filter(sample_ids).data + faith = faith.loc[sample_ids].to_numpy().reshape((-1, 1)) + + tax_tokens = [tax.loc[obs] for obs in obs_ids] + max_len = max([len(tokens) for tokens in tax_tokens]) + tax_tokens = np.array([np.pad(t, [[0, max_len - len(t)]]) for t in tax_tokens]) + return (uni, faith, tax_tokens) + + @property + def taxonomy(self) -> pd.DataFrame: + return self._taxonomy + + @taxonomy.setter + @add_lock + def taxonomy(self, taxonomy: Union[str, pd.DataFrame]): + if hasattr(self, "_taxon_set"): + raise Exception("Taxon already set") + if taxonomy is None: + self._taxonomy = taxonomy + return + + if isinstance(taxonomy, str): + taxonomy = pd.read_csv(taxonomy, sep="\t", index_col=0) + if self.taxon_field not in taxonomy.columns: + raise Exception("Invalid taxonomy: missing 'Taxon' field") + + taxonomy[self.levels] = taxonomy[self.taxon_field].str.split("; ", expand=True) + taxonomy = taxonomy.loc[self._table.ids(axis="observation")] + + if self.tax_level not in self.levels: + raise Exception(f"Invalid level: {self.tax_level}") + + level_index = self.levels.index(self.tax_level) + levels = self.levels[: level_index + 1] + taxonomy = taxonomy.loc[:, levels] + taxonomy.loc[:, "class"] = taxonomy.loc[:, levels].agg("; ".join, axis=1) + + le = preprocessing.LabelEncoder() + taxonomy.loc[:, "token"] = le.fit_transform(taxonomy["class"]) + taxonomy.loc[:, "token"] += 1 # shifts tokens to be between 1 and n + print( + "min token:", min(taxonomy["token"]), "max token:", max(taxonomy["token"]) + ) + self.num_tokens = ( + max(taxonomy["token"]) + 1 + ) # still need to add 1 to account for shift + self._taxonomy = taxonomy + + +if __name__ == "__main__": + import numpy as np + + from aam.data_handlers import CombinedGenerator + + ug = CombinedGenerator( + table="/home/kalen/aam-research-exam/research-exam/healty-age-regression/agp-no-duplicate-host-bloom-filtered-5000-small-stool-only-very-small.biom", + tree_path="/home/kalen/aam-research-exam/research-exam/agp/data/agp-aligned.nwk", + taxonomy="/home/kalen/aam-research-exam/research-exam/healty-age-regression/taxonomy.tsv", + tax_level=7, + metadata="/home/kalen/aam-research-exam/research-exam/healty-age-regression/agp-healthy.txt", + metadata_column="host_age", + shift=0.0, + scale=100.0, + gen_new_tables=True, + ) + data = ug.get_data() + for i, (x, y) in enumerate(data["dataset"]): + print(y) + break + + # data = ug.get_data_by_id(ug.rarefy_tables.ids()[:16]) + # for x, y in data["dataset"]: + # print(y) + # break diff --git a/aam/data_handlers/generator_dataset.py b/aam/data_handlers/generator_dataset.py index 9ff4f0da8bfe89878b3e470a7c6b813c3ea15dd2..1c776bc19146a3483fda1380df24ce70ca4318f5 100644 --- a/aam/data_handlers/generator_dataset.py +++ b/aam/data_handlers/generator_dataset.py @@ -55,7 +55,6 @@ class GeneratorDataset: table: Union[str, Table], metadata: Optional[Union[str, pd.DataFrame]] = None, metadata_column: Optional[str] = None, - is_categorical: Optional[bool] = None, shift: Optional[Union[str, float]] = None, scale: Union[str, float] = "minmax", max_token_per_sample: int = 1024, @@ -65,12 +64,15 @@ class GeneratorDataset: gen_new_tables: bool = False, batch_size: int = 8, max_bp: int = 150, + is_16S: bool = True, + is_categorical: Optional[bool] = None, ): table, metadata self.table = table self.metadata_column = metadata_column 
self.is_categorical = is_categorical + self.include_sample_weight = is_categorical self.shift = shift self.scale = scale self.metadata = metadata @@ -84,6 +86,7 @@ class GeneratorDataset: self.batch_size = batch_size self.max_bp = max_bp + self.is_16S = is_16S self.preprocessed_table = self.table self.obs_ids = self.preprocessed_table.ids(axis="observation") @@ -142,7 +145,7 @@ class GeneratorDataset: return self._validate_dataframe(metadata) if isinstance(metadata, str): - metadata = pd.read_csv(metadata, sep="\t", index_col=0) + metadata = pd.read_csv(metadata, sep="\t", index_col=0, dtype={0: str}) if self.metadata_column not in metadata.columns: raise Exception(f"Invalid metadata column {self.metadata_column}") @@ -200,6 +203,9 @@ class GeneratorDataset: if not (y_data, pd.Series): raise Exception(f"Invalid y_data object: {type(y_data)}") + if not self.is_categorical: + return y_data.loc[sample_ids].to_numpy().reshape(-1, 1) + return y_data.loc[sample_ids].to_numpy().reshape(-1, 1) def _create_table_data( @@ -208,10 +214,12 @@ class GeneratorDataset: obs_ids = table.ids(axis="observation") sample_ids = table.ids() - obs_encodings = tf.cast( - tf.strings.unicode_decode(obs_ids, "UTF-8"), dtype=tf.int64 - ) - obs_encodings = self.lookup_table.lookup(obs_encodings).numpy() + obs_encodings = None + if self.is_16S: + obs_encodings = tf.cast( + tf.strings.unicode_decode(obs_ids, "UTF-8"), dtype=tf.int64 + ) + obs_encodings = self.lookup_table.lookup(obs_encodings).numpy() table_data, row, col, shape = self._table_data(table) @@ -233,7 +241,11 @@ class GeneratorDataset: s_mask = row == s s_counts = counts[s_mask] s_obs_indices = col[s_mask] - s_tokens = obs_encodings[s_obs_indices] + if self.is_16S: + s_tokens = obs_encodings[s_obs_indices] + else: + s_tokens = np.reshape(s_obs_indices, newshape=(-1, 1)) + 1 + sorted_order = np.argsort(s_counts) sorted_order = sorted_order[::-1] s_counts = s_counts[sorted_order].reshape(-1, 1) @@ -248,7 +260,7 @@ class GeneratorDataset: if s_max_token > self.max_token_per_sample: print(f"\tskipping group due to exceeding token limit {s_max_token}...") - return None, None, None, None + return None, None, None, None, None s_ids = [sample_ids[s] for s in samples] @@ -260,7 +272,7 @@ class GeneratorDataset: encoder_target = self.encoder_target encoder_output = self._encoder_output(encoder_target, s_ids, s_obj_ids) - return s_counts, s_tokens, y_output, encoder_output + return s_counts, s_tokens, y_output, encoder_output, s_obj_ids def _epoch_complete(self, processed): if processed < self.steps_per_epoch: @@ -298,7 +310,7 @@ class GeneratorDataset: return table_data, y_data, encoder_target, sample_indices - def _create_epoch_generator(self): + def _create_epoch_generator(self, include_seq_id): def generator(): processed = 0 table_data = self.table_data @@ -322,7 +334,9 @@ class GeneratorDataset: ) while not self._epoch_complete(processed): - counts, tokens, y_output, encoder_out = sample_data(minibatch) + counts, tokens, y_output, encoder_out, ob_ids = sample_data( + minibatch + ) if counts is not None: max_len = max([len(c) for c in counts]) @@ -332,6 +346,12 @@ class GeneratorDataset: padded_tokens = np.array( [np.pad(t, [[0, max_len - len(t)], [0, 0]]) for t in tokens] ) + padded_ob_ids = np.array( + [ + np.pad(o, [[0, max_len - len(o)]], constant_values="") + for o in ob_ids + ] + ) processed += 1 table_output = ( @@ -344,14 +364,29 @@ class GeneratorDataset: output = y_output.astype(np.float32) if encoder_out is not None: - if output is None: - output = 
encoder_out.astype(self.encoder_dtype) + if isinstance(encoder_out, tuple): + encoder_out = tuple( + [ + o.astype(t) + for o, t in zip(encoder_out, self.encoder_dtype) + ] + ) else: + encoder_out = encoder_out.astype(self.encoder_dtype) + + if output is not None: output = ( output, - encoder_out.astype(self.encoder_dtype), + encoder_out, ) + else: + output = encoder_out + if include_seq_id: + output = ( + *output, + padded_ob_ids, + ) if output is not None: yield (table_output, output) else: @@ -386,8 +421,8 @@ class GeneratorDataset: return generator - def get_data(self): - generator = self._create_epoch_generator() + def get_data(self, include_seq_id=False): + generator = self._create_epoch_generator(include_seq_id) output_sig = ( tf.TensorSpec(shape=[self.batch_size, None, self.max_bp], dtype=tf.int32), tf.TensorSpec(shape=[self.batch_size, None, 1], dtype=tf.int32), @@ -404,7 +439,23 @@ class GeneratorDataset: y_output_sig = (y_output_sig, self.encoder_output_type) if y_output_sig is not None: + if include_seq_id: + y_output_sig = ( + *y_output_sig, + tf.TensorSpec( + shape=(self.batch_size, None), dtype=tf.string, name=None + ), + ) output_sig = (output_sig, y_output_sig) + class_weights = None + if self.is_categorical: + counts = self._metadata.value_counts().to_dict() + counts[0.0] = 0 + class_weights = [0] * len(counts) + for k, v in counts.items(): + class_weights[int(k)] = 1 / np.sqrt(v) + class_weights[0] = 0 + print(class_weights) dataset: tf.data.Dataset = tf.data.Dataset.from_generator( generator, output_signature=output_sig, @@ -417,6 +468,7 @@ class GeneratorDataset: "scale": self.scale, "size": self.size, "steps_pre_epoch": self.steps_per_epoch, + "class_weights": class_weights, } return data_obj diff --git a/aam/data_handlers/tests/generator_dataset_test.py b/aam/data_handlers/tests/generator_dataset_test.py new file mode 100644 index 0000000000000000000000000000000000000000..33a2b509294294b3c86d514d087d4a59e8af52b9 --- /dev/null +++ b/aam/data_handlers/tests/generator_dataset_test.py @@ -0,0 +1,45 @@ +import numpy as np + +from aam.data_handlers.generator_dataset import GeneratorDataset + + +def test_generator_dataset(): + """Generator should output a single + tensor for y. Both iteration should also + return same value. 
+ """ + + table = "/home/kalen/aam-research-exam/research-exam/healty-age-regression/agp-no-duplicate-host-bloom-filtered-5000-small-stool-only-very-small.biom" + metadata = "/home/kalen/aam-research-exam/research-exam/healty-age-regression/agp-healthy.txt" + generator = GeneratorDataset( + table=table, + metadata=metadata, + metadata_column="host_age", + gen_new_tables=False, + shuffle=False, + epochs=1, + ) + + table = generator.table + metadata = generator.metadata + assert table.shape[1] == metadata.shape[0] + assert np.all(np.equal(table.ids(), metadata.index)) + + batch_size = 8 + s_ids = table.ids()[: 5 * batch_size] + y_true = metadata.loc[s_ids] + + data1 = generator.get_data() + data2 = generator.get_data() + + data_ys = [] + for ((token1, count1), y1), ((token2, count2), y2) in zip( + data1["dataset"].take(5), data2["dataset"].take(5) + ): + assert np.all(np.equal(token1.numpy(), token2.numpy())) + assert np.all(np.equal(count1.numpy(), count2.numpy())) + assert np.all(np.equal(y1.numpy(), y2.numpy())) + data_ys.append(y1) + + data_ys = np.concatenate(data_ys) + assert np.all(np.equal(y_true.to_numpy().reshape(-1), data_ys.reshape(-1))) diff --git a/aam/data_handlers/unifrac_generator.py b/aam/data_handlers/unifrac_generator.py index ef4e48ac74a55604dd705c325348fa9d33569f7a..3e6dc2a840dbb5544e54a81d286a45c5a5572961 100644 --- a/aam/data_handlers/unifrac_generator.py +++ b/aam/data_handlers/unifrac_generator.py @@ -8,22 +8,29 @@ import tensorflow as tf from biom import Table from biom.util import biom_open from skbio import DistanceMatrix -from unifrac import unweighted +from unifrac import faith_pd, unweighted from aam.data_handlers.generator_dataset import GeneratorDataset class UniFracGenerator(GeneratorDataset): - def __init__(self, tree_path: str, **kwargs): + def __init__(self, tree_path: str, unifrac_metric="unifrac", **kwargs): super().__init__(**kwargs) self.tree_path = tree_path + self.unifrac_metric = unifrac_metric + if self.batch_size % 2 != 0: raise Exception("Batch size must be multiple of 2") self.encoder_target = self._create_encoder_target(self.rarefy_table) self.encoder_dtype = np.float32 - self.encoder_output_type = tf.TensorSpec( - shape=[self.batch_size, self.batch_size], dtype=tf.float32 - ) + if self.unifrac_metric == "unifrac": + self.encoder_output_type = tf.TensorSpec( + shape=[self.batch_size, self.batch_size], dtype=tf.float32 + ) + else: + self.encoder_output_type = tf.TensorSpec( + shape=[self.batch_size, 1], dtype=tf.float32 + ) def _create_encoder_target(self, table: Table) -> DistanceMatrix: if not hasattr(self, "tree_path"): @@ -33,7 +40,10 @@ class UniFracGenerator(GeneratorDataset): temp_path = f"/tmp/temp{random}.biom" with biom_open(temp_path, "w") as f: table.to_hdf5(f, "aam") - distances = unweighted(temp_path, self.tree_path) + if self.unifrac_metric == "unifrac": + distances = unweighted(temp_path, self.tree_path) + else: + distances = faith_pd(temp_path, self.tree_path) os.remove(temp_path) return distances @@ -43,10 +53,15 @@ class UniFracGenerator(GeneratorDataset): sample_ids: Iterable[str], ob_ids: list[str], ) -> np.ndarray[float]: - return encoder_target.filter(sample_ids).data + if self.unifrac_metric == "unifrac": + return encoder_target.filter(sample_ids).data + else: + return encoder_target.loc[sample_ids].to_numpy().reshape((-1, 1)) if __name__ == "__main__": + import numpy as np + from aam.data_handlers import UniFracGenerator ug = UniFracGenerator( @@ -60,7 +75,7 @@ if __name__ == "__main__": ) data = ug.get_data() for i, 
(x, y) in enumerate(data["dataset"]): - print(y) + print(np.mean(y[1])) break # data = ug.get_data_by_id(ug.rarefy_tables.ids()[:16]) diff --git a/aam/layers.py b/aam/layers.py index 36848394ee2c21d46228fb5097054262fc23f70d..94d92053f82d99fa986922fe79cda07114c30ba5 100644 --- a/aam/layers.py +++ b/aam/layers.py @@ -38,6 +38,7 @@ class ASVEncoder(tf.keras.layers.Layer): dropout_rate, intermediate_ff, intermediate_activation="gelu", + add_token=True, **kwargs, ): super(ASVEncoder, self).__init__(**kwargs) @@ -47,13 +48,23 @@ class ASVEncoder(tf.keras.layers.Layer): self.dropout_rate = dropout_rate self.intermediate_ff = intermediate_ff self.intermediate_activation = intermediate_activation + self.add_token = add_token self.base_tokens = 6 self.num_tokens = self.base_tokens * self.max_bp + 2 + + self.asv_token = self.num_tokens - 1 + self.nucleotide_position = tf.range( + 0, self.base_tokens * self.max_bp, self.base_tokens, dtype=tf.int32 + ) + + def build(self, input_shape): self.emb_layer = tf.keras.layers.Embedding( self.num_tokens, - 32, + 128, input_length=self.max_bp, - embeddings_initializer=tf.keras.initializers.GlorotNormal(), + embeddings_initializer=tf.keras.initializers.RandomNormal( + mean=0, stddev=128**0.5 + ), ) self.avs_attention = NucleotideAttention( @@ -61,20 +72,18 @@ class ASVEncoder(tf.keras.layers.Layer): num_heads=self.attention_heads, num_layers=self.attention_layers, dropout=self.dropout_rate, - intermediate_ff=intermediate_ff, + intermediate_ff=self.intermediate_ff, intermediate_activation=self.intermediate_activation, ) - self.asv_token = self.num_tokens - 1 - self.nucleotide_position = tf.range( - 0, self.base_tokens * self.max_bp, self.base_tokens, dtype=tf.int32 - ) + super(ASVEncoder, self).build(input_shape) def call(self, inputs, training=False): seq = inputs seq = seq + self.nucleotide_position # add <ASV> token - seq = tf.pad(seq, [[0, 0], [0, 0], [0, 1]], constant_values=self.asv_token) + if self.add_token: + seq = tf.pad(seq, [[0, 0], [0, 0], [0, 1]], constant_values=self.asv_token) output = self.emb_layer(seq) output = self.avs_attention(output, training=training) @@ -97,6 +106,7 @@ class ASVEncoder(tf.keras.layers.Layer): "dropout_rate": self.dropout_rate, "intermediate_ff": self.intermediate_ff, "intermediate_activation": self.intermediate_activation, + "add_token": self.add_token, } ) return config @@ -196,8 +206,13 @@ class NucleotideAttention(tf.keras.layers.Layer): self.epsilon = 1e-6 self.intermediate_ff = intermediate_ff self.intermediate_activation = intermediate_activation + + def build(self, input_shape): self.pos_emb = tfm.nlp.layers.PositionEmbedding( - self.max_bp + 1, seq_axis=2, name="nuc_pos" + self.max_bp + 1, + seq_axis=2, + initializer=tf.keras.initializers.RandomNormal(mean=0, stddev=128**0.5), + name="nuc_pos", ) self.attention_layers = [] for i in range(self.num_layers): @@ -206,7 +221,7 @@ class NucleotideAttention(tf.keras.layers.Layer): num_heads=self.num_heads, dropout=self.dropout, epsilon=self.epsilon, - intermediate_ff=intermediate_ff, + intermediate_ff=self.intermediate_ff, intermediate_activation=self.intermediate_activation, name=("layer_%d" % i), ) @@ -214,15 +229,17 @@ class NucleotideAttention(tf.keras.layers.Layer): self.output_normalization = tf.keras.layers.LayerNormalization( epsilon=self.epsilon, dtype=tf.float32 ) + super(NucleotideAttention, self).build(input_shape) def call(self, attention_input, attention_mask=None, training=False): attention_input = attention_input + self.pos_emb(attention_input) + 
attention_input = attention_input * (9 * 3) ** (-0.25) for layer_idx in range(self.num_layers): attention_input = self.attention_layers[layer_idx]( attention_input, training=training ) - output = self.output_normalization(attention_input) - return output + # output = self.output_normalization(attention_input) + return attention_input def get_config(self): config = super(NucleotideAttention, self).get_config() @@ -255,6 +272,14 @@ class NucleotideAttentionBlock(tf.keras.layers.Layer): self.dropout = dropout self.epsilon = epsilon self.intermediate_ff = intermediate_ff + self.intermediate_activation = intermediate_activation + + def build(self, input_shape): + self._shape = input_shape + self.nucleotides = input_shape[2] + self.hidden_dim = input_shape[3] + self.head_size = tf.cast(self.hidden_dim / self.num_heads, dtype=tf.int32) + self.attention_norm = tf.keras.layers.LayerNormalization( epsilon=self.epsilon, dtype=tf.float32 ) @@ -263,19 +288,23 @@ class NucleotideAttentionBlock(tf.keras.layers.Layer): self.ff_norm = tf.keras.layers.LayerNormalization( epsilon=self.epsilon, dtype=tf.float32 ) - self.intermediate_activation = intermediate_activation - - def build(self, input_shape): - self._shape = input_shape - self.nucleotides = input_shape[2] - self.hidden_dim = input_shape[3] - self.head_size = tf.cast(self.hidden_dim / self.num_heads, dtype=tf.int32) + self.nuc_alpha = self.add_weight( + name="nuc_alpha", + initializer=tf.keras.initializers.Zeros(), + trainable=True, + dtype=tf.float32, + ) wi_shape = [1, 1, self.num_heads, self.hidden_dim, self.head_size] self.w_qi = self.add_weight("w_qi", wi_shape, trainable=True, dtype=tf.float32) self.w_ki = self.add_weight("w_ki", wi_shape, trainable=True, dtype=tf.float32) self.w_vi = self.add_weight("w_kv", wi_shape, trainable=True, dtype=tf.float32) - self.o_dense = tf.keras.layers.Dense(self.hidden_dim, use_bias=False) + + wo_shape = [1, 1, self.hidden_dim, self.hidden_dim] + self.o_dense = self.add_weight( + "w_o", wo_shape, trainable=True, dtype=tf.float32 + ) + # self.o_dense = tf.keras.layers.Dense(self.hidden_dim, use_bias=False) self.scale_dot_factor = tf.math.sqrt( tf.cast(self.head_size, dtype=self.compute_dtype) @@ -303,7 +332,7 @@ class NucleotideAttentionBlock(tf.keras.layers.Layer): def scaled_dot_attention(self, attention_input): wq_tensor = self.compute_wi(attention_input, self.w_qi) wk_tensor = self.compute_wi(attention_input, self.w_ki) - wv_tensor = self.compute_wi(attention_input, self.w_vi) + wv_tensor = self.compute_wi(attention_input, self.w_vi * (0.67 * 3) ** -0.25) # (multihead) scaled dot product attention sublayer # [B, A, H, N, S] => [B, A, H, N, N] @@ -332,20 +361,23 @@ class NucleotideAttentionBlock(tf.keras.layers.Layer): attention_output, shape=[batch_size, num_asv, self.nucleotides, self.hidden_dim], ) - attention_output = self.o_dense(attention_output) + attention_output = tf.matmul( + attention_output, self.o_dense * (0.67 * 3) ** (-0.25) + ) + # attention_output = self.o_dense(attention_output) attention_output = tf.ensure_shape(attention_output, self._shape) return attention_output def call(self, attention_input, training=False): - # scaled dot product attention sublayer - attention_input = self.attention_norm(attention_input) + # # scaled dot product attention sublayer + # attention_input = self.attention_norm(attention_input) # cast for mixed precision _attention_input = tf.cast(attention_input, dtype=self.compute_dtype) # cast back to float32 _attention_output = 
self.scaled_dot_attention(_attention_input) - attention_output = tf.cast(_attention_output, dtype=tf.float32) + attention_output = tf.cast(_attention_output, dtype=tf.float32) * self.nuc_alpha # residual connection attention_output = tf.add(attention_input, attention_output) @@ -353,13 +385,14 @@ class NucleotideAttentionBlock(tf.keras.layers.Layer): attention_output = self.attention_dropout(attention_output, training=training) # cast for mixed precision - ff_input = self.ff_norm(attention_output) + # ff_input = self.ff_norm(attention_output) + ff_input = attention_output # self.ff_norm(attention_output) _ff_input = tf.cast(ff_input, dtype=self.compute_dtype) _ff_output = self.inter_ff(_ff_input) _ff_output = self.outer_ff(_ff_output) - # cast back to float32 - ff_output = tf.cast(_ff_output, dtype=tf.float32) + # cast back to float32, residual connection + ff_output = tf.cast(_ff_output, dtype=tf.float32) * self.nuc_alpha ff_output = tf.add(ff_input, ff_output) ff_output = tf.ensure_shape(ff_output, self._shape) diff --git a/aam/losses.py b/aam/losses.py index 575a370877cda3dc89939efc24cfea5593f35b66..68318b275434a35ba6470917aa54fceaef57eafa 100644 --- a/aam/losses.py +++ b/aam/losses.py @@ -61,7 +61,7 @@ def _pairwise_distances(embeddings, squared=False): # (ex: on the diagonal) # we need to add a small epsilon where distances == 0.0 mask = tf.cast(tf.equal(distances, 0.0), tf.float32) - distances = distances + mask * 1e-07 + distances = distances + mask * 1e-12 distances = tf.sqrt(distances) @@ -79,7 +79,6 @@ class PairwiseLoss(tf.keras.losses.Loss): def call(self, y_true, y_pred): y_pred_dist = _pairwise_distances(y_pred, squared=False) differences = tf.math.square(y_pred_dist - y_true) - differences = tf.linalg.band_part(differences, 0, -1) return differences diff --git a/aam/models/__init__.py b/aam/models/__init__.py index 2b5c6973a26dda5d88485624abf36df0d1925b48..e5da9c11cddd75f216ba25ab7cec8e42eb642ac3 100644 --- a/aam/models/__init__.py +++ b/aam/models/__init__.py @@ -1,12 +1,14 @@ from __future__ import annotations from .base_sequence_encoder import BaseSequenceEncoder +from .sequence_encoder import SequenceEncoder from .sequence_regressor import SequenceRegressor from .taxonomy_encoder import TaxonomyEncoder from .unifrac_encoder import UniFracEncoder __all__ = [ "BaseSequenceEncoder", + "SequenceEncoder", "SequenceRegressor", "TaxonomyEncoder", "UniFracEncoder", diff --git a/aam/models/base_sequence_encoder.py b/aam/models/base_sequence_encoder.py index 9cd901a9e73d951118dc6eb92e8b81f7bd4469fb..2af54efc1ad8bf5ec24204020f180ec7759d582b 100644 --- a/aam/models/base_sequence_encoder.py +++ b/aam/models/base_sequence_encoder.py @@ -7,7 +7,7 @@ from aam.layers import ( ASVEncoder, ) from aam.models.transformers import TransformerEncoder -from aam.utils import float_mask, masked_loss +from aam.utils import float_mask @tf.keras.saving.register_keras_serializable(package="BaseSequenceEncoder") @@ -25,6 +25,9 @@ class BaseSequenceEncoder(tf.keras.layers.Layer): nuc_attention_layers: int = 4, nuc_intermediate_size: int = 1024, intermediate_activation: str = "gelu", + is_16S: bool = True, + vocab_size: int = 6, + add_token: bool = True, **kwargs, ): super(BaseSequenceEncoder, self).__init__(**kwargs) @@ -39,29 +42,45 @@ class BaseSequenceEncoder(tf.keras.layers.Layer): self.nuc_attention_layers = nuc_attention_layers self.nuc_intermediate_size = nuc_intermediate_size self.intermediate_activation = intermediate_activation - self.nuc_loss = tf.keras.losses.SparseCategoricalCrossentropy( - 
ignore_class=0, from_logits=False, reduction="none" - ) - self.nuc_entropy = tf.keras.metrics.Mean() + self.is_16S = is_16S + self.vocab_size = vocab_size + self.add_token = add_token # layers used in model - self.asv_encoder = ASVEncoder( - max_bp, - nuc_attention_heads, - nuc_attention_layers, - 0.0, - nuc_intermediate_size, - intermediate_activation=self.intermediate_activation, - name="asv_encoder", - ) - self.nuc_logits = tf.keras.layers.Dense( - 6, use_bias=False, name="nuc_logits", dtype=tf.float32, activation="softmax" + if self.is_16S: + self.asv_encoder = ASVEncoder( + max_bp, + nuc_attention_heads, + nuc_attention_layers, + 0.0, + nuc_intermediate_size, + intermediate_activation=self.intermediate_activation, + add_token=self.add_token, + name="asv_encoder", + ) + else: + self.asv_embeddings = tf.keras.layers.Embedding( + self.vocab_size, + output_dim=self.embedding_dim, + embeddings_initializer=tf.keras.initializers.RandomNormal( + mean=0, stddev=self.embedding_dim**0.5 + ), + ) + self.asv_encoder = TransformerEncoder( + num_layers=self.sample_attention_layers, + num_attention_heads=self.sample_attention_heads, + intermediate_size=self.sample_intermediate_size, + activation=self.intermediate_activation, + dropout_rate=self.dropout_rate, + ) + + self.asv_pos = tfm.nlp.layers.PositionEmbedding( + self.token_limit + 5, + initializer=tf.keras.initializers.RandomNormal( + mean=0, stddev=self.embedding_dim**0.5 + ), ) - self.asv_scale = tf.keras.layers.Dense(self.embedding_dim, use_bias=False) - self.asv_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6) - self.asv_pos = tfm.nlp.layers.PositionEmbedding(self.token_limit + 5) - self.sample_encoder = TransformerEncoder( num_layers=self.sample_attention_layers, num_attention_heads=self.sample_attention_heads, @@ -69,25 +88,14 @@ class BaseSequenceEncoder(tf.keras.layers.Layer): activation=self.intermediate_activation, dropout_rate=self.dropout_rate, ) - self.sample_token = self.add_weight( - "sample_token", - [1, 1, self.embedding_dim], - dtype=tf.float32, - initializer=tf.keras.initializers.Zeros(), - trainable=True, - ) - self._base_alpha = self.add_weight( - name="base_alpha", - initializer=tf.keras.initializers.Zeros(), - trainable=True, - dtype=tf.float32, - ) - - self.linear_activation = tf.keras.layers.Activation("linear", dtype=tf.float32) - - @masked_loss(sparse_cat=True) - def _compute_nuc_loss(self, tokens: tf.Tensor, pred_tokens: tf.Tensor) -> tf.Tensor: - return self.nuc_loss(tokens, pred_tokens) + if self.add_token: + self.sample_token = self.add_weight( + "sample_token", + [1, 1, self.embedding_dim], + dtype=tf.float32, + initializer="glorot_uniform", + trainable=True, + ) def _add_sample_token(self, tensor: tf.Tensor) -> tf.Tensor: # add <SAMPLE> token empbedding @@ -102,45 +110,108 @@ class BaseSequenceEncoder(tf.keras.layers.Layer): return embeddings def _split_asvs(self, embeddings): - nuc_embeddings = embeddings[:, :, :-1, :] - nucleotides = self.nuc_logits(nuc_embeddings) + asv_embeddings = embeddings + if self.is_16S: + if self.add_token: + asv_embeddings = asv_embeddings[:, :, 0, :] + else: + asv_embeddings = tf.reduce_mean(asv_embeddings, axis=2) + else: + asv_embeddings = asv_embeddings[:, :, 0, :] - asv_embeddings = embeddings[:, :, 0, :] - asv_embeddings = self.asv_norm(asv_embeddings) asv_embeddings = asv_embeddings + self.asv_pos(asv_embeddings) - - return asv_embeddings, nucleotides + return asv_embeddings def call( + self, inputs: tf.Tensor, random_mask: bool = None, training: bool = False + ) -> 
tuple[tf.Tensor, tf.Tensor]:
+        # need to cast inputs to int32 to avoid error
+        # because keras converts all inputs
+        # to float when calling build()
+        asv_input = tf.cast(inputs, dtype=tf.int32)
+        asv_mask = float_mask(tf.reduce_sum(inputs, axis=-1, keepdims=True))
+
+        if training and random_mask is not None:
+            asv_input = asv_input * tf.cast(random_mask, dtype=tf.int32)
+
+        if self.is_16S:
+            embeddings = self.asv_encoder(asv_input, training=training)
+        else:
+            embeddings = self.asv_embeddings(asv_input)
+        asv_embeddings = self._split_asvs(embeddings)
+
+        if self.add_token:
+            asv_mask = tf.pad(asv_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+            sample_embeddings = self._add_sample_token(asv_embeddings)
+        else:
+            sample_embeddings = asv_embeddings
+        return sample_embeddings
+
+    # def base_embeddings(
+    #     self, inputs: tf.Tensor, training: bool = False
+    # ) -> tuple[tf.Tensor, tf.Tensor]:
+    #     # need to cast inputs to int32 to avoid error
+    #     # because keras converts all inputs
+    #     # to float when calling build()
+    #     asv_input = tf.cast(inputs, dtype=tf.int32)
+
+    #     embeddings = self.asv_encoder(asv_input, training=training)
+    #     embeddings = self.asv_scale(embeddings)
+    #     asv_embeddings, nucleotides = self._split_asvs(embeddings)
+
+    #     asv_mask = float_mask(tf.reduce_sum(inputs, axis=-1, keepdims=True))
+    #     padded_asv_mask = tf.pad(asv_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+
+    #     # padded embeddings are the skip connection
+    #     # normal asv embeddings continue through next block
+    #     padded_asv_embeddings = tf.pad(
+    #         asv_embeddings, [[0, 0], [1, 0], [0, 0]], constant_values=0
+    #     )
+
+    #     sample_gated_embeddings = self._add_sample_token(asv_embeddings)
+    #     sample_gated_embeddings = self.sample_encoder(
+    #         sample_gated_embeddings, mask=padded_asv_mask, training=training
+    #     )
+
+    #     sample_embeddings = (
+    #         padded_asv_embeddings + sample_gated_embeddings * self._base_alpha
+    #     )
+    #     return sample_embeddings
+
+    def get_asv_embeddings(
         self, inputs: tf.Tensor, training: bool = False
     ) -> tuple[tf.Tensor, tf.Tensor]:
         # need to cast inputs to int32 to avoid error
         # because keras converts all inputs
         # to float when calling build()
         asv_input = tf.cast(inputs, dtype=tf.int32)
+        asv_mask = float_mask(tf.reduce_sum(inputs, axis=-1, keepdims=True))
-        embeddings = self.asv_encoder(asv_input, training=training)
-        embeddings = self.asv_scale(embeddings)
+        if self.is_16S:
+            embeddings = self.asv_encoder(asv_input, training=training)
+            embeddings = self.asv_scale(embeddings)
+        else:
+            embeddings = self.asv_embeddings(asv_input)
         asv_embeddings, nucleotides = self._split_asvs(embeddings)
+        return asv_embeddings
+    def asv_gradient(
+        self, inputs: tf.Tensor, asv_embeddings
+    ) -> tuple[tf.Tensor, tf.Tensor]:
         asv_mask = float_mask(tf.reduce_sum(inputs, axis=-1, keepdims=True))
-        padded_asv_mask = tf.pad(asv_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
-        # padded embeddings are the skip connection
-        # normal asv embeddings continue through next block
-        padded_asv_embeddings = tf.pad(
-            asv_embeddings, [[0, 0], [1, 0], [0, 0]], constant_values=0
-        )
+        if self.add_token:
+            asv_mask = tf.pad(asv_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+            sample_embeddings = self._add_sample_token(asv_embeddings)
+        else:
+            sample_embeddings = asv_embeddings
-        sample_gated_embeddings = self._add_sample_token(asv_embeddings)
         sample_gated_embeddings = self.sample_encoder(
-
padded_asv_embeddings + sample_gated_embeddings * self._base_alpha + sample_embeddings, mask=asv_mask, training=False ) - return sample_embeddings, nucleotides + sample_embeddings = sample_embeddings + sample_gated_embeddings + return sample_embeddings def get_config(self): config = super(BaseSequenceEncoder, self).get_config() @@ -156,6 +227,9 @@ class BaseSequenceEncoder(tf.keras.layers.Layer): "nuc_attention_heads": self.nuc_attention_heads, "nuc_attention_layers": self.nuc_attention_layers, "nuc_intermediate_size": self.nuc_intermediate_size, + "is_16S": self.is_16S, + "vocab_size": self.vocab_size, + "add_token": self.add_token, } ) return config diff --git a/aam/models/conv_block.py b/aam/models/conv_block.py new file mode 100644 index 0000000000000000000000000000000000000000..915879c64bbf16b11f9024345923c640d4395b0b --- /dev/null +++ b/aam/models/conv_block.py @@ -0,0 +1,124 @@ +import tensorflow as tf + + +@tf.keras.saving.register_keras_serializable(package="ConvBlock") +class ConvBlock(tf.keras.layers.Layer): + def __init__( + self, + activation="gelu", + use_bias=True, + norm_epsilon=1e-6, + dilation=1, + **kwargs, + ): + super(ConvBlock, self).__init__(**kwargs) + self._activation = activation + self._use_bias = use_bias + self._norm_epsilon = norm_epsilon + self._dilation = dilation + + self.conv_norm = tf.keras.layers.LayerNormalization(epsilon=self._norm_epsilon) + + self.ff_norm = tf.keras.layers.LayerNormalization(epsilon=self._norm_epsilon) + + def build(self, input_shape): + seq_dim = input_shape[1] + emb_dim = input_shape[2] + self.conv1d = tf.keras.layers.Conv1D( + seq_dim, + kernel_size=3, + strides=1, + padding="same", + activation=self._activation, + dilation_rate=self._dilation, + ) + self.ff = tf.keras.layers.Dense( + emb_dim, use_bias=self._use_bias, activation=self._activation + ) + + def get_config(self): + config = { + "activation": self._activation, + "use_bias": self._use_bias, + "norm_epsilon": self._norm_epsilon, + "dilation": self._dilation, + } + base_config = super(ConvBlock, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def call(self, encoder_inputs, training=False): + """Return the output of the encoder. + + Args: + encoder_inputs: A tensor with shape `(batch_size, input_length, + hidden_size)`. + attention_mask: A mask for the encoder self-attention layer with shape + `(batch_size, input_length, input_length)`. + + Returns: + Output of encoder which is a `float32` tensor with shape + `(batch_size, input_length, hidden_size)`. 
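+ Note: the input is transposed before the Conv1D, so the convolution runs along the embedding axis with sequence positions acting as channels; the conv branch and the feed-forward branch are each layer-normalized and added back to their inputs as residual connections.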
+ """ + conv_inputs = tf.transpose(encoder_inputs, perm=[0, 2, 1]) + conv_outputs = tf.transpose(self.conv1d(conv_inputs), perm=[0, 2, 1]) + conv_outputs = encoder_inputs + self.conv_norm(conv_outputs) + + ff_outputs = self.ff(conv_outputs) + output = conv_outputs + self.ff_norm(ff_outputs) + return output + + +@tf.keras.saving.register_keras_serializable(package="ConvModule") +class ConvModule(tf.keras.layers.Layer): + def __init__( + self, + layers, + activation="gelu", + use_bias=True, + norm_epsilon=1e-6, + **kwargs, + ): + super(ConvModule, self).__init__(**kwargs) + self.layers = layers + self._activation = activation + self._use_bias = use_bias + self._norm_epsilon = norm_epsilon + + self.conv_blocks = [] + dilation = 1 + for i in range(self.layers): + self.conv_blocks.append( + ConvBlock( + self._activation, + self._use_bias, + self._norm_epsilon, + dilation=(2**i) * dilation, + ) + ) + + def get_config(self): + config = { + "layers": self.layers, + "activation": self._activation, + "use_bias": self._use_bias, + "norm_epsilon": self._norm_epsilon, + } + base_config = super(ConvModule, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def call(self, encoder_inputs, training=False): + """Return the output of the encoder. + + Args: + encoder_inputs: A tensor with shape `(batch_size, input_length, + hidden_size)`. + attention_mask: A mask for the encoder self-attention layer with shape + `(batch_size, input_length, input_length)`. + + Returns: + Output of encoder which is a `float32` tensor with shape + `(batch_size, input_length, hidden_size)`. + """ + for layer_idx in range(self.layers): + encoder_inputs = self.conv_blocks[layer_idx](encoder_inputs) + return encoder_inputs diff --git a/aam/models/sequence_encoder.py b/aam/models/sequence_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..53af1d1f67d55a7c248096c9fe8b4b3985f5d81c --- /dev/null +++ b/aam/models/sequence_encoder.py @@ -0,0 +1,441 @@ +from __future__ import annotations + +from typing import Union + +import tensorflow as tf + +from aam.losses import PairwiseLoss +from aam.models.base_sequence_encoder import BaseSequenceEncoder +from aam.models.transformers import TransformerEncoder +from aam.optimizers.gradient_accumulator import GradientAccumulator +from aam.optimizers.loss_scaler import LossScaler +from aam.utils import apply_random_mask, float_mask + + +@tf.keras.saving.register_keras_serializable(package="SequenceEncoder") +class SequenceEncoder(tf.keras.Model): + def __init__( + self, + output_dim: int, + token_limit: int, + encoder_type: str, + dropout_rate: float = 0.0, + embedding_dim: int = 128, + attention_heads: int = 4, + attention_layers: int = 4, + intermediate_size: int = 1024, + intermediate_activation: str = "gelu", + max_bp: int = 150, + is_16S: bool = True, + vocab_size: int = 6, + add_token: bool = True, + asv_dropout_rate: float = 0.0, + accumulation_steps: int = 1, + **kwargs, + ): + super(SequenceEncoder, self).__init__(**kwargs) + self.output_dim = output_dim + self.token_limit = token_limit + self.encoder_type = encoder_type + self.dropout_rate = dropout_rate + self.embedding_dim = embedding_dim + self.attention_heads = attention_heads + self.attention_layers = attention_layers + self.intermediate_size = intermediate_size + self.intermediate_activation = intermediate_activation + self.max_bp = max_bp + self.is_16S = is_16S + self.vocab_size = vocab_size + self.add_token = add_token + self.asv_dropout_rate = asv_dropout_rate + 
self.accumulation_steps = accumulation_steps + + self._get_encoder_loss() + self.loss_tracker = tf.keras.metrics.Mean() + self.encoder_tracker = tf.keras.metrics.Mean() + + # layers used in model + self.base_encoder = BaseSequenceEncoder( + self.embedding_dim, + self.max_bp, + self.token_limit, + sample_attention_heads=self.attention_heads, + sample_attention_layers=self.attention_layers, + sample_intermediate_size=self.intermediate_size, + dropout_rate=self.dropout_rate, + nuc_attention_heads=2, + nuc_attention_layers=2, + nuc_intermediate_size=128, + intermediate_activation=self.intermediate_activation, + is_16S=self.is_16S, + vocab_size=self.vocab_size, + add_token=self.add_token, + name="base_encoder", + ) + + self.encoder = TransformerEncoder( + num_layers=self.attention_layers, + num_attention_heads=self.attention_heads, + intermediate_size=intermediate_size, + dropout_rate=self.dropout_rate, + activation=self.intermediate_activation, + name="encoder", + ) + + if self.encoder_type == "combined": + uni_out, faith_out, tax_out = self.output_dim + self.uni_ff = tf.keras.Sequential( + [ + tf.keras.layers.Dense( + self.embedding_dim, activation="gelu", dtype=tf.float32 + ), + tf.keras.layers.Dense(uni_out, dtype=tf.float32), + ] + ) + self.faith_ff = tf.keras.Sequential( + [ + tf.keras.layers.Dense( + self.embedding_dim, activation="gelu", dtype=tf.float32 + ), + tf.keras.layers.Dense(faith_out, dtype=tf.float32), + ] + ) + self.tax_ff = tf.keras.Sequential( + [ + tf.keras.layers.Dense( + self.embedding_dim, activation="gelu", dtype=tf.float32 + ), + tf.keras.layers.Dense(tax_out, dtype=tf.float32), + ] + ) + else: + self.encoder_ff = tf.keras.Sequential( + [ + tf.keras.layers.Dense( + self.embedding_dim, activation="gelu", dtype=tf.float32 + ), + tf.keras.layers.Dense(self.output_dim, dtype=tf.float32), + ] + ) + + self.gradient_accumulator = GradientAccumulator(self.accumulation_steps) + self.loss_scaler = LossScaler(self.gradient_accumulator.accum_steps) + + def _get_encoder_loss(self): + if self.encoder_type == "combined": + self._unifrac_loss = PairwiseLoss() + self._tax_loss = tf.keras.losses.CategoricalCrossentropy(reduction="none") + self.encoder_loss = self._compute_combined_loss + self.extract_encoder_pred = self._combined_embeddigns + elif self.encoder_type == "unifrac": + self._unifrac_loss = PairwiseLoss() + self.encoder_loss = self._compute_unifrac_loss + self.extract_encoder_pred = self._unifrac_embeddings + elif self.encoder_type == "faith_pd": + self._unifrac_loss = tf.keras.losses.MeanSquaredError(reduction="none") + self.encoder_loss = self._compute_unifrac_loss + self.extract_encoder_pred = self._unifrac_embeddings + elif self.encoder_type == "taxonomy": + self._tax_loss = tf.keras.losses.CategoricalCrossentropy(reduction="none") + self.encoder_loss = self._compute_tax_loss + self.extract_encoder_pred = self._taxonomy_embeddings + else: + raise Exception(f"invalid encoder encoder_type: {self.encoder_type}") + + def _combined_embeddigns(self, tensor, mask): + if self.add_token: + unifrac_pred = tensor[:, 0, :] + else: + mask = tf.cast(mask, dtype=tf.float32) + unifrac_pred = tf.reduce_sum(tensor * mask, axis=1) + unifrac_pred /= tf.reduce_sum(mask, axis=1) + + if self.add_token: + faith_pred = tensor[:, 0, :] + else: + mask = tf.cast(mask, dtype=tf.float32) + faith_pred = tf.reduce_sum(tensor * mask, axis=1) + faith_pred /= tf.reduce_sum(mask, axis=1) + + tax_pred = tensor + if self.add_token: + tax_pred = tax_pred[:, 1:, :] + + return [ + self.uni_ff(unifrac_pred), + 
self.faith_ff(faith_pred), + self.tax_ff(tax_pred), + ] + + def _unifrac_embeddings(self, tensor, mask): + if self.add_token: + encoder_pred = tensor[:, 0, :] + else: + mask = tf.cast(mask, dtype=tf.float32) + encoder_pred = tf.reduce_sum(tensor * mask, axis=1) + encoder_pred /= tf.reduce_sum(mask, axis=1) + encoder_pred = self.encoder_ff(encoder_pred) + return encoder_pred + + def _taxonomy_embeddings(self, tensor, mask): + tax_pred = tensor + if self.add_token: + tax_pred = tax_pred[:, 1:, :] + tax_pred = self.encoder_ff(tax_pred) + return tax_pred + + def _compute_combined_loss(self, y_true, preds): + uni_true, faith_true, tax_true = y_true + uni_pred, faith_pred, tax_pred = preds + + uni_loss = self._compute_unifrac_loss(uni_true, uni_pred) + faith_loss = tf.reduce_mean(tf.square(faith_true - faith_pred)) + tax_loss = self._compute_tax_loss(tax_true, tax_pred) + + return [uni_loss, faith_loss, tax_loss] + + def _compute_tax_loss( + self, + tax_tokens: tf.Tensor, + tax_pred: tf.Tensor, + ) -> tf.Tensor: + if isinstance(self.output_dim, (list, tuple)): + out_dim = self.output_dim[-1] + else: + out_dim = self.output_dim + y_true = tf.reshape(tax_tokens, [-1]) + y_pred = tf.reshape(tax_pred, [-1, out_dim]) + y_pred = tf.keras.activations.softmax(y_pred, axis=-1) + + mask = float_mask(y_true) > 0 + y_true = tf.one_hot(y_true, depth=out_dim) + + # smooth labels + y_true = y_true * 0.9 + 0.1 + + y_true = y_true[mask] + y_pred = y_pred[mask] + + loss = tf.reduce_mean(self._tax_loss(y_true, y_pred)) + return loss + + def _compute_unifrac_loss( + self, + y_true: tf.Tensor, + unifrac_embeddings: tf.Tensor, + ) -> tf.Tensor: + loss = self._unifrac_loss(y_true, unifrac_embeddings) + if self.encoder_type == "unifrac": + batch = tf.cast(tf.shape(y_true)[0], dtype=tf.float32) + loss = tf.reduce_sum(loss, axis=1, keepdims=True) / (batch - 1) + return tf.reduce_mean(loss) + + def _compute_encoder_loss( + self, + y_true: tf.Tensor, + encoder_embeddings: tf.Tensor, + ) -> tf.Tensor: + loss = self.encoder_loss(y_true, encoder_embeddings) + if self.encoder_type != "combined": + return tf.reduce_mean(loss) + else: + return loss + + def _compute_loss( + self, + model_inputs: tuple[tf.Tensor, tf.Tensor], + y_true: Union[tf.Tensor, tuple[tf.Tensor, tf.Tensor]], + outputs: tuple[tf.Tensor, tf.Tensor, tf.Tensor], + ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + nuc_tokens, counts = model_inputs + embeddings, encoder_embeddings = outputs + return self._compute_encoder_loss(y_true, encoder_embeddings) + + def predict_step( + self, + data: Union[ + tuple[tuple[tf.Tensor, tf.Tensor], tf.Tensor], + tuple[tuple[tf.Tensor, tf.Tensor], tuple[tf.Tensor, tf.Tensor]], + ], + ): + inputs, y = data + embeddings, encoder_embeddings = self(inputs, training=False) + + return encoder_embeddings + + def train_step( + self, + data: Union[ + tuple[tuple[tf.Tensor, tf.Tensor], tf.Tensor], + tuple[tuple[tf.Tensor, tf.Tensor], tuple[tf.Tensor, tf.Tensor]], + ], + ): + if not self.gradient_accumulator.built: + self.gradient_accumulator.build(self.optimizer, self) + + inputs, y = data + y_target, encoder_target = y + with tf.GradientTape() as tape: + outputs = self(inputs, training=True) + encoder_loss = self._compute_loss(inputs, encoder_target, outputs) + if self.encoder_type == "combined": + scaled_losses = self.loss_scaler(encoder_loss) + else: + scaled_losses = self.loss_scaler([encoder_loss]) + loss = tf.reduce_mean(tf.stack(scaled_losses, axis=0)) + + gradients = tape.gradient( + loss, + self.trainable_variables, + 
unconnected_gradients=tf.UnconnectedGradients.ZERO, + ) + self.gradient_accumulator.apply_gradients(gradients) + + self.loss_tracker.update_state(loss) + self.encoder_tracker.update_state(encoder_loss) + return { + "loss": self.loss_tracker.result(), + "encoder_loss": self.encoder_tracker.result(), + "learning_rate": self.optimizer.learning_rate, + } + + def test_step( + self, + data: Union[ + tuple[tuple[tf.Tensor, tf.Tensor], tf.Tensor], + tuple[tuple[tf.Tensor, tf.Tensor], tuple[tf.Tensor, tf.Tensor]], + ], + ): + inputs, y = data + y_target, encoder_target = y + outputs = self(inputs, training=False) + encoder_loss = self._compute_loss(inputs, encoder_target, outputs) + scaled_losses = self.loss_scaler([encoder_loss]) + loss = tf.reduce_mean(tf.stack(scaled_losses, axis=0)) + + self.loss_tracker.update_state(loss) + self.encoder_tracker.update_state(encoder_loss) + return { + "loss": self.loss_tracker.result(), + "encoder_loss": self.encoder_tracker.result(), + "learning_rate": self.optimizer.learning_rate, + } + + def call( + self, inputs: tuple[tf.Tensor, tf.Tensor], training: bool = False + ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + # keras cast all input to float so we need to manually cast to expected type + tokens, counts = inputs + tokens = tf.cast(tokens, dtype=tf.int32) + counts = tf.cast(counts, dtype=tf.int32) + + # account for <SAMPLE> token + count_mask = float_mask(counts, dtype=tf.int32) + random_mask = None + if training and self.asv_dropout_rate > 0: + random_mask = apply_random_mask(count_mask, self.asv_dropout_rate) + + sample_embeddings = self.base_encoder( + tokens, random_mask=random_mask, training=training + ) + + if self.add_token: + count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1) + count_attention_mask = count_mask + + encoder_gated_embeddings = self.encoder( + sample_embeddings, mask=count_attention_mask, training=training + ) + + encoder_pred = self.extract_encoder_pred(encoder_gated_embeddings, count_mask) + encoder_embeddings = sample_embeddings + encoder_gated_embeddings + return [encoder_embeddings, encoder_pred] + + def base_embeddings( + self, inputs: tuple[tf.Tensor, tf.Tensor] + ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + # keras cast all input to float so we need to manually cast to expected type + tokens, counts = inputs + tokens = tf.cast(tokens, dtype=tf.int32) + counts = tf.cast(counts, dtype=tf.int32) + + sample_embeddings = self.base_encoder.base_embeddings(tokens) + + # account for <SAMPLE> token + count_mask = float_mask(counts, dtype=tf.int32) + count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1) + count_attention_mask = count_mask + + unifrac_gated_embeddings = self.unifrac_encoder( + sample_embeddings, mask=count_attention_mask + ) + unifrac_pred = unifrac_gated_embeddings[:, 0, :] + unifrac_pred = self.unifrac_ff(unifrac_pred) + + unifrac_embeddings = ( + sample_embeddings + unifrac_gated_embeddings * self._unifrac_alpha + ) + + return unifrac_embeddings + + def asv_embeddings( + self, inputs: tuple[tf.Tensor, tf.Tensor], training: bool = False + ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + # keras cast all input to float so we need to manually cast to expected type + tokens, counts = inputs + tokens = tf.cast(tokens, dtype=tf.int32) + counts = tf.cast(counts, dtype=tf.int32) + + asv_embeddings = self.base_encoder.asv_embeddings(tokens, training=training) + + return asv_embeddings + + def asv_gradient( + self, inputs: tuple[tf.Tensor, tf.Tensor], asv_embeddings + ) -> 
tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + # keras cast all input to float so we need to manually cast to expected type + tokens, counts = inputs + tokens = tf.cast(tokens, dtype=tf.int32) + counts = tf.cast(counts, dtype=tf.int32) + + sample_embeddings = self.base_encoder.asv_gradient(tokens, asv_embeddings) + + # account for <SAMPLE> token + count_mask = float_mask(counts, dtype=tf.int32) + count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1) + count_attention_mask = count_mask + + unifrac_gated_embeddings = self.unifrac_encoder( + sample_embeddings, mask=count_attention_mask + ) + unifrac_pred = unifrac_gated_embeddings[:, 0, :] + unifrac_pred = self.unifrac_ff(unifrac_pred) + + unifrac_embeddings = ( + sample_embeddings + unifrac_gated_embeddings * self._unifrac_alpha + ) + + return unifrac_embeddings + + def get_config(self): + config = super(SequenceEncoder, self).get_config() + config.update( + { + "output_dim": self.output_dim, + "token_limit": self.token_limit, + "encoder_type": self.encoder_type, + "dropout_rate": self.dropout_rate, + "embedding_dim": self.embedding_dim, + "attention_heads": self.attention_heads, + "attention_layers": self.attention_layers, + "intermediate_size": self.intermediate_size, + "intermediate_activation": self.intermediate_activation, + "max_bp": self.max_bp, + "is_16S": self.is_16S, + "vocab_size": self.vocab_size, + "add_token": self.add_token, + "asv_dropout_rate": self.asv_dropout_rate, + "accumulation_steps": self.accumulation_steps, + } + ) + return config diff --git a/aam/models/sequence_regressor.py b/aam/models/sequence_regressor.py index 0e5d9140ab441eacec85603c076c70e667e193dc..7c22d1f7411397fbe73d0dfe834b02e726425306 100644 --- a/aam/models/sequence_regressor.py +++ b/aam/models/sequence_regressor.py @@ -5,10 +5,12 @@ from typing import Optional, Union import tensorflow as tf import tensorflow_models as tfm -from aam.models.taxonomy_encoder import TaxonomyEncoder +# from aam.models.unifrac_encoder import UniFracEncoder +from aam.models.sequence_encoder import SequenceEncoder from aam.models.transformers import TransformerEncoder -from aam.models.unifrac_encoder import UniFracEncoder -from aam.utils import float_mask, masked_loss +from aam.optimizers.gradient_accumulator import GradientAccumulator +from aam.optimizers.loss_scaler import LossScaler +from aam.utils import float_mask @tf.keras.saving.register_keras_serializable(package="SequenceRegressor") @@ -16,30 +18,36 @@ class SequenceRegressor(tf.keras.Model): def __init__( self, token_limit: int, - num_classes: Optional[int] = None, + base_output_dim: Optional[int] = None, shift: float = 0.0, scale: float = 1.0, dropout_rate: float = 0.0, - num_tax_levels: Optional[int] = None, embedding_dim: int = 128, attention_heads: int = 4, attention_layers: int = 4, intermediate_size: int = 1024, intermediate_activation: str = "relu", - base_model: Union[str, TaxonomyEncoder, UniFracEncoder] = "taxonomy", + base_model: str = "unifrac", freeze_base: bool = False, penalty: float = 1.0, nuc_penalty: float = 1.0, max_bp: int = 150, + is_16S: bool = True, + vocab_size: int = 6, + out_dim: int = 1, + classifier: bool = False, + add_token: bool = True, + class_weights: list = None, + asv_dropout_rate: float = 0.0, + accumulation_steps: int = 1, **kwargs, ): super(SequenceRegressor, self).__init__(**kwargs) self.token_limit = token_limit - self.num_classes = num_classes + self.base_output_dim = base_output_dim self.shift = shift self.scale = scale self.dropout_rate = dropout_rate - 
self.num_tax_levels = num_tax_levels self.embedding_dim = embedding_dim self.attention_heads = attention_heads self.attention_layers = attention_layers @@ -49,66 +57,56 @@ class SequenceRegressor(tf.keras.Model): self.penalty = penalty self.nuc_penalty = nuc_penalty self.max_bp = max_bp + self.is_16S = is_16S + self.vocab_size = vocab_size + self.out_dim = out_dim + self.classifier = classifier + self.add_token = add_token + self.class_weights = class_weights + self.asv_dropout_rate = asv_dropout_rate + self.accumulation_steps = accumulation_steps self.loss_tracker = tf.keras.metrics.Mean() # layers used in model + self.combined_base = False if isinstance(base_model, str): - if base_model == "taxonomy": - self.base_model = TaxonomyEncoder( - num_tax_levels=self.num_tax_levels, - token_limit=self.token_limit, - dropout_rate=self.dropout_rate, - embedding_dim=self.embedding_dim, - attention_heads=self.attention_heads, - attention_layers=self.attention_layers, - intermediate_size=self.intermediate_size, - intermediate_activation=self.intermediate_activation, - max_bp=self.max_bp, - ) - elif base_model == "unifrac": - self.base_model = UniFracEncoder( - self.token_limit, - dropout_rate=self.dropout_rate, - embedding_dim=self.embedding_dim, - attention_heads=self.attention_heads, - attention_layers=self.attention_layers, - intermediate_size=self.intermediate_size, - intermediate_activation=self.intermediate_activation, - max_bp=self.max_bp, - ) - else: - raise Exception("Invalid base model option.") + if base_model == "combined": + self.combined_base = True + + self.base_model = SequenceEncoder( + output_dim=self.base_output_dim, + token_limit=self.token_limit, + encoder_type=base_model, + dropout_rate=self.dropout_rate, + embedding_dim=self.embedding_dim, + attention_heads=self.attention_heads, + attention_layers=self.attention_layers, + intermediate_size=self.intermediate_size, + intermediate_activation=self.intermediate_activation, + max_bp=self.max_bp, + is_16S=self.is_16S, + vocab_size=self.vocab_size, + add_token=self.add_token, + asv_dropout_rate=self.asv_dropout_rate, + accumulation_steps=self.accumulation_steps, + ) else: - if not isinstance(base_model, (TaxonomyEncoder, UniFracEncoder)): - raise Exception(f"Unsupported base model of type {type(base_model)}") self.base_model = base_model - if isinstance(self.base_model, TaxonomyEncoder): - self.base_losses = {"base_loss": self.base_model._compute_tax_loss} + self.base_losses = {"base_loss": self.base_model._compute_encoder_loss} + if not self.combined_base: self.base_metrics = { - "base_loss": ("tax_entropy", self.base_model.tax_tracker) + "base_loss": ["encoder_loss", self.base_model.encoder_tracker] } else: - self.base_losses = {"base_loss": self.base_model._compute_unifrac_loss} - self.base_metrics = { - "base_loss": ["unifrac_mse", self.base_model.unifrac_tracker] - } - - self.base_losses.update({"nuc_entropy": self.base_model._compute_nuc_loss}) - self.base_metrics.update( - {"nuc_entropy": ["nuc_entropy", self.base_model.base_encoder.nuc_entropy]} - ) + self.uni_tracker = tf.keras.metrics.Mean() + # self.faith_tracker = tf.keras.metrics.Mean() + self.tax_tracker = tf.keras.metrics.Mean() if self.freeze_base: print("Freezing base model...") self.base_model.trainable = False - self._count_alpha = self.add_weight( - name="count_alpha", - initializer=tf.keras.initializers.Zeros(), - trainable=True, - dtype=tf.float32, - ) self.count_encoder = TransformerEncoder( num_layers=self.attention_layers, 
num_attention_heads=self.attention_heads, @@ -119,8 +117,14 @@ class SequenceRegressor(tf.keras.Model): self.count_pos = tfm.nlp.layers.PositionEmbedding( self.token_limit + 5, dtype=tf.float32 ) - self.count_out = tf.keras.layers.Dense(1, use_bias=False, dtype=tf.float32) - self.count_activation = tf.keras.layers.Activation("linear", dtype=tf.float32) + self.count_out = tf.keras.Sequential( + [ + tf.keras.layers.Dense( + self.embedding_dim, activation="gelu", dtype=tf.float32 + ), + tf.keras.layers.Dense(1, dtype=tf.float32), + ] + ) self.count_loss = tf.keras.losses.MeanSquaredError(reduction="none") self.count_tracker = tf.keras.metrics.Mean() @@ -131,15 +135,28 @@ class SequenceRegressor(tf.keras.Model): dropout_rate=self.dropout_rate, activation=self.intermediate_activation, ) + self.target_tracker = tf.keras.metrics.Mean() + if not self.classifier: + self.metric_tracker = tf.keras.metrics.MeanAbsoluteError() + self.metric_string = "mae" + else: + self.metric_tracker = tf.keras.metrics.SparseCategoricalAccuracy() + self.metric_string = "accuracy" + self.target_ff = tf.keras.Sequential( + [ + tf.keras.layers.Dense( + self.embedding_dim, activation="gelu", dtype=tf.float32 + ), + tf.keras.layers.Dense(self.out_dim, dtype=tf.float32), + ] + ) - self.target_ff = tf.keras.layers.Dense(1, use_bias=False, dtype=tf.float32) - self.metric_tracker = tf.keras.metrics.MeanAbsoluteError() - self.metric_string = "mae" - self.target_activation = tf.keras.layers.Activation("linear", dtype=tf.float32) self.loss_metrics = sorted( ["loss", "target_loss", "count_mse", self.metric_string] ) + self.gradient_accumulator = GradientAccumulator(self.accumulation_steps) + self.loss_scaler = LossScaler(self.gradient_accumulator.accum_steps) def evaluate_metric(self, dataset, metric, **kwargs): metric_index = self.loss_metrics.index(metric) @@ -147,15 +164,27 @@ class SequenceRegressor(tf.keras.Model): return evaluated_metrics[metric_index] def _compute_target_loss(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor: - return tf.reduce_mean(self.loss(y_true, y_pred)) + if not self.classifier: + loss = tf.square(y_true - y_pred) + return tf.reduce_mean(loss) + + y_true = tf.cast(y_true, dtype=tf.int32) + y_true = tf.reshape(y_true, shape=[-1]) + weights = tf.gather(self.class_weights, y_true) + y_true = tf.one_hot(y_true, self.out_dim) + y_pred = tf.reshape(y_pred, shape=[-1, self.out_dim]) + y_pred = tf.keras.activations.softmax(y_pred, axis=-1) + loss = self.loss(y_true, y_pred) # * weights + return tf.reduce_mean(loss) - @masked_loss(sparse_cat=False) def _compute_count_loss( self, counts: tf.Tensor, count_pred: tf.Tensor ) -> tf.Tensor: relative_counts = self._relative_abundance(counts) loss = tf.square(relative_counts - count_pred) - return tf.squeeze(loss, axis=-1) + mask = float_mask(counts) + loss = tf.reduce_sum(loss * mask, axis=1) / tf.reduce_sum(mask, axis=1) + return tf.reduce_mean(loss) def _compute_loss( self, @@ -165,36 +194,26 @@ class SequenceRegressor(tf.keras.Model): tuple[tf.Tensor, tf.Tensor, tf.Tensor], tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor], ], - sample_weights: Optional[tf.Tensor] = None, + train_step: bool = True, ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: nuc_tokens, counts = model_inputs y_target, base_target = y_true - target_embeddings, count_pred, y_pred, base_pred, nuc_pred = outputs - + target_embeddings, count_pred, y_pred, base_pred = outputs target_loss = self._compute_target_loss(y_target, y_pred) count_loss = 
self._compute_count_loss(counts, count_pred) - loss = target_loss + count_loss + base_loss = 0 if not self.freeze_base: - base_loss = ( - self.base_losses["base_loss"](base_target, base_pred) * self.penalty - ) - nuc_loss = ( - self.base_losses["nuc_entropy"](nuc_tokens, nuc_pred) * self.nuc_penalty - ) - loss = loss + nuc_loss + base_loss - else: - base_loss = 0 - nuc_loss = 0 + if self.combined_base: + uni_loss, faith_loss, tax_loss = self.base_losses["base_loss"]( + base_target, base_pred + ) + return (target_loss, count_loss, uni_loss, faith_loss, tax_loss) + else: + base_loss = self.base_losses["base_loss"](base_target, base_pred) - return ( - loss, - target_loss, - count_loss, - base_loss, - nuc_loss, - ) + return (target_loss, count_loss, base_loss) def _compute_metric( self, @@ -206,10 +225,15 @@ class SequenceRegressor(tf.keras.Model): ): y_true, base_target = y_true - target_embeddings, count_pred, y_pred, base_pred, nuc_pred = outputs - y_true = y_true * self.scale + self.shift - y_pred = y_pred * self.scale + self.shift - self.metric_tracker.update_state(y_true, y_pred) + target_embeddings, count_pred, y_pred, base_pred = outputs + if not self.classifier: + y_true = y_true * self.scale + self.shift + y_pred = y_pred * self.scale + self.shift + self.metric_tracker.update_state(y_true, y_pred) + else: + y_true = tf.cast(y_true, dtype=tf.int32) + y_pred = tf.keras.activations.softmax(y_pred) + self.metric_tracker.update_state(y_true, y_pred) def predict_step( self, @@ -219,12 +243,14 @@ class SequenceRegressor(tf.keras.Model): ], ): inputs, (y_true, _) = data - target_embeddings, count_pred, y_pred, base_pred, nuc_pred = self( - inputs, training=False - ) + target_embeddings, count_pred, y_pred, base_pred = self(inputs, training=False) - y_true = y_true * self.scale + self.shift - y_pred = y_pred * self.scale + self.shift + if not self.classifier: + y_true = y_true * self.scale + self.shift + y_pred = y_pred * self.scale + self.shift + else: + y_true = tf.cast(y_true, dtype=tf.int32) + y_pred = tf.argmax(tf.keras.activations.softmax(y_pred), axis=-1) return y_pred, y_true def train_step( @@ -234,33 +260,64 @@ class SequenceRegressor(tf.keras.Model): tuple[tuple[tf.Tensor, tf.Tensor], tuple[tf.Tensor, tf.Tensor]], ], ): + if not self.gradient_accumulator.built: + self.gradient_accumulator.build(self.optimizer, self) + inputs, y = data with tf.GradientTape() as tape: outputs = self(inputs, training=True) - loss, target_loss, count_mse, base_loss, nuc_loss = self._compute_loss( - inputs, y, outputs - ) - - gradients = tape.gradient(loss, self.trainable_variables) - self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) + if self.combined_base: + target_loss, count_mse, uni_loss, faith_loss, tax_loss = ( + self._compute_loss(inputs, y, outputs, train_step=True) + ) + scaled_losses = self.loss_scaler( + [target_loss, count_mse, uni_loss, tax_loss] + ) + else: + target_loss, count_mse, base_loss = self._compute_loss( + inputs, y, outputs, train_step=True + ) + scaled_losses = self.loss_scaler([target_loss, count_mse, base_loss]) + loss = tf.reduce_mean(tf.stack(scaled_losses, axis=0)) + gradients = tape.gradient( + loss, + self.trainable_variables, + unconnected_gradients=tf.UnconnectedGradients.ZERO, + ) + self.gradient_accumulator.apply_gradients(gradients) self.loss_tracker.update_state(loss) self.target_tracker.update_state(target_loss) self.count_tracker.update_state(count_mse) - base_loss_key, base_loss_metric = self.base_metrics["base_loss"] - 
base_loss_metric.update_state(base_loss) - nuc_entropy_key, nuc_entropy_metric = self.base_metrics["nuc_entropy"] - nuc_entropy_metric.update_state(nuc_loss) + self._compute_metric(y, outputs) - return { - "loss": self.loss_tracker.result(), - "target_loss": self.target_tracker.result(), - "count_mse": self.count_tracker.result(), - base_loss_key: base_loss_metric.result(), - nuc_entropy_key: nuc_entropy_metric.result(), - self.metric_string: self.metric_tracker.result(), - } + if self.combined_base: + self.uni_tracker.update_state(uni_loss) + # self.faith_tracker.update_state(faith_loss) + self.tax_tracker.update_state(tax_loss) + return { + "loss": self.loss_tracker.result(), + "target_loss": self.target_tracker.result(), + "count_mse": self.count_tracker.result(), + # base_loss_key: base_loss_metric.result(), + "uni_loss": self.uni_tracker.result(), + # "faith_loss": self.faith_tracker.result(), + "tax_loss": self.tax_tracker.result(), + self.metric_string: self.metric_tracker.result(), + "learning_rate": self.optimizer.learning_rate, + } + else: + base_loss_key, base_loss_metric = self.base_metrics["base_loss"] + base_loss_metric.update_state(base_loss) + return { + "loss": self.loss_tracker.result(), + "target_loss": self.target_tracker.result(), + "count_mse": self.count_tracker.result(), + base_loss_key: base_loss_metric.result(), + self.metric_string: self.metric_tracker.result(), + "learning_rate": self.optimizer.learning_rate, + } def test_step( self, @@ -271,27 +328,52 @@ class SequenceRegressor(tf.keras.Model): ): inputs, y = data - outputs = self(inputs, training=False) - loss, target_loss, count_mse, base_loss, nuc_loss = self._compute_loss( - inputs, y, outputs - ) + outputs = self(inputs, training=True) + if self.combined_base: + target_loss, count_mse, uni_loss, faith_loss, tax_loss = self._compute_loss( + inputs, y, outputs, train_step=True + ) + scaled_losses = self.loss_scaler( + [target_loss, count_mse, uni_loss, tax_loss] + ) + else: + target_loss, count_mse, base_loss = self._compute_loss( + inputs, y, outputs, train_step=True + ) + scaled_losses = self.loss_scaler([target_loss, count_mse, base_loss]) + loss = tf.reduce_mean(tf.stack(scaled_losses, axis=0)) self.loss_tracker.update_state(loss) self.target_tracker.update_state(target_loss) self.count_tracker.update_state(count_mse) - base_loss_key, base_loss_metric = self.base_metrics["base_loss"] - base_loss_metric.update_state(base_loss) - nuc_entropy_key, nuc_entropy_metric = self.base_metrics["nuc_entropy"] - nuc_entropy_metric.update_state(nuc_loss) + self._compute_metric(y, outputs) - return { - "loss": self.loss_tracker.result(), - "target_loss": self.target_tracker.result(), - "count_mse": self.count_tracker.result(), - base_loss_key: base_loss_metric.result(), - nuc_entropy_key: nuc_entropy_metric.result(), - self.metric_string: self.metric_tracker.result(), - } + if self.combined_base: + self.uni_tracker.update_state(uni_loss) + # self.faith_tracker.update_state(faith_loss) + self.tax_tracker.update_state(tax_loss) + return { + "loss": self.loss_tracker.result(), + "target_loss": self.target_tracker.result(), + "count_mse": self.count_tracker.result(), + # base_loss_key: base_loss_metric.result(), + "uni_loss": self.uni_tracker.result(), + # "faith_loss": self.faith_tracker.result(), + "tax_loss": self.tax_tracker.result(), + self.metric_string: self.metric_tracker.result(), + "learning_rate": self.optimizer.learning_rate, + } + else: + base_loss_key, base_loss_metric = self.base_metrics["base_loss"] + 
base_loss_metric.update_state(base_loss) + return { + "loss": self.loss_tracker.result(), + "target_loss": self.target_tracker.result(), + "count_mse": self.count_tracker.result(), + base_loss_key: base_loss_metric.result(), + self.metric_string: self.metric_tracker.result(), + "learning_rate": self.optimizer.learning_rate, + } def _relative_abundance(self, counts: tf.Tensor) -> tf.Tensor: counts = tf.cast(counts, dtype=tf.float32) @@ -306,11 +388,14 @@ class SequenceRegressor(tf.keras.Model): attention_mask: Optional[tf.Tensor] = None, training: bool = False, ) -> tf.Tensor: - count_embeddings = tensor + self.count_pos(tensor) * relative_abundances + count_embeddings = (tensor + self.count_pos(tensor)) * relative_abundances count_embeddings = self.count_encoder( count_embeddings, mask=attention_mask, training=training ) - count_pred = count_embeddings[:, 1:, :] + count_pred = count_embeddings + if self.add_token: + count_pred = count_pred[:, 1:, :] + count_pred = self.count_out(count_pred) return count_embeddings, count_pred @@ -323,7 +408,13 @@ class SequenceRegressor(tf.keras.Model): target_embeddings = self.target_encoder( tensor, mask=attention_mask, training=training ) - target_out = target_embeddings[:, 0, :] + if self.add_token: + target_out = target_embeddings[:, 0, :] + else: + mask = tf.cast(attention_mask, dtype=tf.float32) + target_out = tf.reduce_sum(target_embeddings * mask, axis=1) + target_out /= tf.reduce_sum(mask, axis=1) + target_out = self.target_ff(target_out) return target_embeddings, target_out @@ -342,12 +433,13 @@ class SequenceRegressor(tf.keras.Model): rel_abundance = self._relative_abundance(counts) # account for <SAMPLE> token - count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1) - rel_abundance = tf.pad( - rel_abundance, [[0, 0], [1, 0], [0, 0]], constant_values=1 - ) + if self.add_token: + count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1) + rel_abundance = tf.pad( + rel_abundance, [[0, 0], [1, 0], [0, 0]], constant_values=1 + ) count_attention_mask = count_mask - base_embeddings, base_pred, nuc_embeddings = self.base_model( + base_embeddings, base_pred = self.base_model( (tokens, counts), training=training ) @@ -357,40 +449,161 @@ class SequenceRegressor(tf.keras.Model): attention_mask=count_attention_mask, training=training, ) - count_embeddings = base_embeddings + count_gated_embeddings * self._count_alpha + count_embeddings = base_embeddings + count_gated_embeddings + # count_embeddings = count_gated_embeddings target_embeddings, target_out = self._compute_target_embeddings( count_embeddings, attention_mask=count_attention_mask, training=training ) - return ( - target_embeddings, - self.count_activation(count_pred), - self.target_activation(target_out), - base_pred, - nuc_embeddings, + return (target_embeddings, count_pred, target_out, base_pred) + + def base_embeddings( + self, inputs: tuple[tf.Tensor, tf.Tensor] + ) -> Union[ + tuple[tf.Tensor, tf.Tensor, tf.Tensor], + tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor], + ]: + # keras cast all input to float so we need to manually cast to expected type + tokens, counts = inputs + tokens = tf.cast(tokens, dtype=tf.int32) + counts = tf.cast(counts, dtype=tf.int32) + + count_mask = float_mask(counts, dtype=tf.int32) + rel_abundance = self._relative_abundance(counts) + + # account for <SAMPLE> token + count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1) + rel_abundance = tf.pad( + rel_abundance, [[0, 0], [1, 0], [0, 0]], 
constant_values=1 + ) + base_embeddings = self.base_model.base_embeddings((tokens, counts)) + + return base_embeddings + + def base_gradient( + self, inputs: tuple[tf.Tensor, tf.Tensor], base_embeddings + ) -> Union[ + tuple[tf.Tensor, tf.Tensor, tf.Tensor], + tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor], + ]: + # keras cast all input to float so we need to manually cast to expected type + tokens, counts = inputs + tokens = tf.cast(tokens, dtype=tf.int32) + counts = tf.cast(counts, dtype=tf.int32) + + count_mask = float_mask(counts, dtype=tf.int32) + rel_abundance = self._relative_abundance(counts) + + # account for <SAMPLE> token + count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1) + rel_abundance = tf.pad( + rel_abundance, [[0, 0], [1, 0], [0, 0]], constant_values=1 + ) + count_attention_mask = count_mask + + count_gated_embeddings, count_pred = self._compute_count_embeddings( + base_embeddings, + rel_abundance, + attention_mask=count_attention_mask, + ) + # count_embeddings = base_embeddings + count_gated_embeddings * self._count_alpha + count_embeddings = count_gated_embeddings + + target_embeddings, target_out = self._compute_target_embeddings( + count_embeddings, attention_mask=count_attention_mask + ) + + return self.target_activation(target_out) + + def asv_embeddings( + self, inputs: tuple[tf.Tensor, tf.Tensor], training: bool = False + ) -> Union[ + tuple[tf.Tensor, tf.Tensor, tf.Tensor], + tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor], + ]: + # keras cast all input to float so we need to manually cast to expected type + tokens, counts = inputs + tokens = tf.cast(tokens, dtype=tf.int32) + counts = tf.cast(counts, dtype=tf.int32) + + count_mask = float_mask(counts, dtype=tf.int32) + rel_abundance = self._relative_abundance(counts) + + # account for <SAMPLE> token + if self.add_token: + count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1) + rel_abundance = tf.pad( + rel_abundance, [[0, 0], [1, 0], [0, 0]], constant_values=1 + ) + asv_embeddings = self.base_model.asv_embeddings( + (tokens, counts), training=False + ) + + return asv_embeddings + + def asv_gradient(self, inputs, asv_embeddings): + # keras cast all input to float so we need to manually cast to expected type + tokens, counts = inputs + tokens = tf.cast(tokens, dtype=tf.int32) + counts = tf.cast(counts, dtype=tf.int32) + + count_mask = float_mask(counts, dtype=tf.int32) + rel_abundance = self._relative_abundance(counts) + + # account for <SAMPLE> token + if self.add_token: + count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1) + rel_abundance = tf.pad( + rel_abundance, [[0, 0], [1, 0], [0, 0]], constant_values=1 + ) + count_attention_mask = count_mask + base_embeddings = self.base_model.asv_gradient( + (tokens, counts), asv_embeddings=asv_embeddings + ) + + count_gated_embeddings, count_pred = self._compute_count_embeddings( + base_embeddings, + rel_abundance, + attention_mask=count_attention_mask, + training=False, + ) + # count_embeddings = base_embeddings + count_gated_embeddings + count_embeddings = count_gated_embeddings + + target_embeddings, target_out = self._compute_target_embeddings( + count_embeddings, attention_mask=count_attention_mask, training=False ) + return self.target_activation(target_out) + def get_config(self): config = super(SequenceRegressor, self).get_config() config.update( { "token_limit": self.token_limit, - "base_model": tf.keras.saving.serialize_keras_object(self.base_model), - "num_classes": self.num_classes, 
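+ # new constructor arguments must also appear in this config dict so tf.keras.models.load_model can rebuild the regressor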
+ "base_output_dim": self.base_output_dim, "shift": self.shift, "scale": self.scale, "dropout_rate": self.dropout_rate, - "num_tax_levels": self.num_tax_levels, "embedding_dim": self.embedding_dim, "attention_heads": self.attention_heads, "attention_layers": self.attention_layers, "intermediate_size": self.intermediate_size, "intermediate_activation": self.intermediate_activation, + "base_model": tf.keras.saving.serialize_keras_object(self.base_model), "freeze_base": self.freeze_base, "penalty": self.penalty, "nuc_penalty": self.nuc_penalty, "max_bp": self.max_bp, + "is_16S": self.is_16S, + "vocab_size": self.vocab_size, + "out_dim": self.out_dim, + "classifier": self.classifier, + "add_token": self.add_token, + "class_weights": self.class_weights, + "asv_dropout_rate": self.asv_dropout_rate, + "accumulation_steps": self.accumulation_steps, } ) return config diff --git a/aam/models/taxonomy_encoder.py b/aam/models/taxonomy_encoder.py index a83cfc265af9fab40bac125d18008b4833e9de2c..1ba4d27d206e6db8349b8eac48375f605eafce94 100644 --- a/aam/models/taxonomy_encoder.py +++ b/aam/models/taxonomy_encoder.py @@ -6,7 +6,9 @@ import tensorflow as tf from aam.models.base_sequence_encoder import BaseSequenceEncoder from aam.models.transformers import TransformerEncoder -from aam.utils import float_mask, masked_loss +from aam.optimizers.gradient_accumulator import GradientAccumulator +from aam.optimizers.loss_scaler import LossScaler +from aam.utils import apply_random_mask, float_mask @tf.keras.saving.register_keras_serializable(package="TaxonomyEncoder") @@ -22,6 +24,12 @@ class TaxonomyEncoder(tf.keras.Model): intermediate_size: int = 1024, intermediate_activation: str = "gelu", max_bp: int = 150, + include_alpha: bool = True, + is_16S: bool = True, + vocab_size: int = 6, + add_token: bool = True, + asv_dropout_rate: float = 0.0, + accumulation_steps: int = 1, **kwargs, ): super(TaxonomyEncoder, self).__init__(**kwargs) @@ -35,12 +43,16 @@ class TaxonomyEncoder(tf.keras.Model): self.intermediate_size = intermediate_size self.intermediate_activation = intermediate_activation self.max_bp = max_bp - + self.include_alpha = include_alpha + self.is_16S = is_16S + self.vocab_size = vocab_size + self.add_token = add_token + self.asv_dropout_rate = asv_dropout_rate + self.accumulation_steps = accumulation_steps self.loss_tracker = tf.keras.metrics.Mean() - self.tax_loss = tf.keras.losses.SparseCategoricalCrossentropy( - ignore_class=0, from_logits=True, reduction="none" - ) - self.tax_tracker = tf.keras.metrics.Mean() + # self.tax_loss = tf.keras.losses.CategoricalFocalCrossentropy(reduction="none") + self.tax_loss = tf.keras.losses.CategoricalCrossentropy(reduction="none") + self.encoder_tracker = tf.keras.metrics.Mean() # layers used in model self.base_encoder = BaseSequenceEncoder( @@ -55,16 +67,12 @@ class TaxonomyEncoder(tf.keras.Model): nuc_attention_layers=3, nuc_intermediate_size=128, intermediate_activation=self.intermediate_activation, + is_16S=self.is_16S, + vocab_size=self.vocab_size, + add_token=self.add_token, name="base_encoder", ) - self._tax_alpha = self.add_weight( - name="tax_alpha", - initializer=tf.keras.initializers.Zeros(), - trainable=True, - dtype=tf.float32, - ) - self.tax_encoder = TransformerEncoder( num_layers=self.attention_layers, num_attention_heads=self.attention_heads, @@ -75,13 +83,12 @@ class TaxonomyEncoder(tf.keras.Model): ) self.tax_level_logits = tf.keras.layers.Dense( - self.num_tax_levels, - use_bias=False, - dtype=tf.float32, - name="tax_level_logits", + 
self.num_tax_levels, dtype=tf.float32 ) self.loss_metrics = sorted(["loss", "target_loss", "count_mse"]) + self.gradient_accumulator = GradientAccumulator(self.accumulation_steps) + self.loss_scaler = LossScaler() def evaluate_metric(self, dataset, metric, **kwargs): metric_index = self.loss_metrics.index(metric) @@ -91,15 +98,22 @@ class TaxonomyEncoder(tf.keras.Model): def _compute_nuc_loss(self, nuc_tokens, nuc_pred): return self.base_encoder._compute_nuc_loss(nuc_tokens, nuc_pred) - @masked_loss(sparse_cat=True) - def _compute_tax_loss( + def _compute_encoder_loss( self, tax_tokens: tf.Tensor, tax_pred: tf.Tensor, sample_weights: Optional[tf.Tensor] = None, ) -> tf.Tensor: - loss = self.tax_loss(tax_tokens, tax_pred) + y_true = tf.reshape(tax_tokens, [-1]) + y_pred = tf.reshape(tax_pred, [-1, self.num_tax_levels]) + y_pred = tf.keras.activations.softmax(y_pred, axis=-1) + mask = float_mask(y_true) > 0 + y_true = tf.one_hot(y_true, depth=self.num_tax_levels) + y_true = y_true[mask] + y_pred = y_pred[mask] + + loss = tf.reduce_mean(self.tax_loss(y_true, y_pred)) return loss def _compute_loss( @@ -111,14 +125,9 @@ class TaxonomyEncoder(tf.keras.Model): ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: nuc_tokens, counts = model_inputs taxonomy_embeddings, tax_pred, nuc_pred = outputs - tax_loss = self._compute_tax_loss( - y_true, tax_pred, sample_weights=sample_weights - ) - - nuc_loss = self._compute_nuc_loss(nuc_tokens, nuc_pred) - + tax_loss = self._compute_encoder_loss(y_true, tax_pred) + nuc_loss = self._compute_nuc_loss(nuc_tokens, nuc_pred) * 0.0 loss = tax_loss + nuc_loss - return (loss, tax_loss, nuc_loss) def predict_step( @@ -140,22 +149,32 @@ class TaxonomyEncoder(tf.keras.Model): tuple[tuple[tf.Tensor, tf.Tensor], tuple[tf.Tensor, tf.Tensor]], ], ): - inputs, y = data + if not self.gradient_accumulator.built: + self.gradient_accumulator.build(self.optimizer, self) + inputs, y = data + y_target, tax_target = y with tf.GradientTape() as tape: outputs = self(inputs, training=True) - loss, tax_loss, nuc_loss = self._compute_loss(inputs, y, outputs) - - gradients = tape.gradient(loss, self.trainable_variables) - self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) + loss, tax_loss, nuc_loss = self._compute_loss(inputs, tax_target, outputs) + scaled_losses = self.loss_scaler([tax_loss]) + loss = tf.reduce_sum(tf.stack(scaled_losses, axis=0)) + + gradients = tape.gradient( + loss, + self.trainable_variables, + unconnected_gradients=tf.UnconnectedGradients.ZERO, + ) + self.gradient_accumulator.apply_gradients(gradients) self.loss_tracker.update_state(loss) - self.tax_tracker.update_state(tax_loss) + self.encoder_tracker.update_state(tax_loss) self.base_encoder.nuc_entropy.update_state(nuc_loss) return { "loss": self.loss_tracker.result(), - "tax_entropy": self.tax_tracker.result(), + "tax_entropy": self.encoder_tracker.result(), "nuc_entropy": self.base_encoder.nuc_entropy.result(), + "learning_rate": self.optimizer.learning_rate, } def test_step( @@ -166,17 +185,18 @@ class TaxonomyEncoder(tf.keras.Model): ], ): inputs, y = data - + y_target, tax_target = y outputs = self(inputs, training=False) - loss, tax_loss, nuc_loss = self._compute_loss(inputs, y, outputs) + loss, tax_loss, nuc_loss = self._compute_loss(inputs, tax_target, outputs) self.loss_tracker.update_state(loss) - self.tax_tracker.update_state(tax_loss) + self.encoder_tracker.update_state(tax_loss) self.base_encoder.nuc_entropy.update_state(nuc_loss) return { "loss": self.loss_tracker.result(), - 
"tax_entropy": self.tax_tracker.result(), + "tax_entropy": self.encoder_tracker.result(), "nuc_entropy": self.base_encoder.nuc_entropy.result(), + "learning_rate": self.optimizer.learning_rate, } def call( @@ -187,23 +207,100 @@ class TaxonomyEncoder(tf.keras.Model): tokens = tf.cast(tokens, dtype=tf.int32) counts = tf.cast(counts, dtype=tf.int32) - sample_embeddings, nuc_embeddings = self.base_encoder(tokens, training=training) + # account for <SAMPLE> token + count_mask = float_mask(counts, dtype=tf.int32) + random_mask = None + if training and self.asv_dropout_rate > 0: + random_mask = apply_random_mask(count_mask, self.asv_dropout_rate) + + sample_embeddings, nuc_embeddings = self.base_encoder( + tokens, random_mask=random_mask, training=training + ) + + if self.add_token: + count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1) + count_attention_mask = count_mask + + tax_gated_embeddings = self.tax_encoder( + sample_embeddings, mask=count_attention_mask, training=training + ) + tax_pred = tax_gated_embeddings + if self.add_token: + tax_pred = tax_pred[:, 1:, :] + tax_pred = self.tax_level_logits(tax_pred) + + tax_embeddings = sample_embeddings + tax_gated_embeddings + return [tax_embeddings, tax_pred, nuc_embeddings] + + def base_embeddings( + self, inputs: tuple[tf.Tensor, tf.Tensor], training: bool = False + ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + # keras cast all input to float so we need to manually cast to expected type + tokens, counts = inputs + tokens = tf.cast(tokens, dtype=tf.int32) + counts = tf.cast(counts, dtype=tf.int32) + + sample_embeddings, nuc_embeddings, asv_embeddings = ( + self.base_encoder.base_embeddings(tokens, training=training) + ) # account for <SAMPLE> token count_mask = float_mask(counts, dtype=tf.int32) count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1) count_attention_mask = count_mask - tax_gated_embeddings = self.tax_encoder( + unifrac_gated_embeddings = self.unifrac_encoder( sample_embeddings, mask=count_attention_mask, training=training ) + unifrac_pred = unifrac_gated_embeddings[:, 0, :] + unifrac_pred = self.unifrac_ff(unifrac_pred) + + unifrac_embeddings = ( + sample_embeddings + unifrac_gated_embeddings * self._tax_alpha + ) + + return [unifrac_embeddings, unifrac_pred, nuc_embeddings, asv_embeddings] + + def asv_embeddings( + self, inputs: tuple[tf.Tensor, tf.Tensor], training: bool = False + ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + # keras cast all input to float so we need to manually cast to expected type + tokens, counts = inputs + tokens = tf.cast(tokens, dtype=tf.int32) + counts = tf.cast(counts, dtype=tf.int32) + + asv_embeddings = self.base_encoder.get_asv_embeddings(tokens, training=training) - tax_pred = tax_gated_embeddings[:, 1:, :] + return asv_embeddings + + def asv_gradient( + self, inputs: tuple[tf.Tensor, tf.Tensor], asv_embeddings + ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + # keras cast all input to float so we need to manually cast to expected type + tokens, counts = inputs + tokens = tf.cast(tokens, dtype=tf.int32) + counts = tf.cast(counts, dtype=tf.int32) + + sample_embeddings = self.base_encoder.asv_gradient(tokens, asv_embeddings) + + # account for <SAMPLE> token + count_mask = float_mask(counts, dtype=tf.int32) + if self.add_token: + count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1) + count_attention_mask = count_mask + + tax_gated_embeddings = self.tax_encoder( + sample_embeddings, mask=count_attention_mask, training=False + ) + 
tax_pred = tax_gated_embeddings + if self.add_token: + tax_pred = tax_pred[:, 1:, :] tax_pred = self.tax_level_logits(tax_pred) - tax_embeddings = sample_embeddings + tax_gated_embeddings * self._tax_alpha + # tax_embeddings = sample_embeddings + tax_gated_embeddings + tax_embeddings = tax_gated_embeddings - return [tax_embeddings, tax_pred, nuc_embeddings] + return tax_embeddings def get_config(self): config = super(TaxonomyEncoder, self).get_config() @@ -218,6 +315,11 @@ class TaxonomyEncoder(tf.keras.Model): "intermediate_size": self.intermediate_size, "intermediate_activation": self.intermediate_activation, "max_bp": self.max_bp, + "is_16S": self.is_16S, + "vocab_size": self.vocab_size, + "add_token": self.add_token, + "asv_dropout_rate": self.asv_dropout_rate, + "accumulation_steps": self.accumulation_steps, } ) return config diff --git a/aam/models/transformers.py b/aam/models/transformers.py index a48abf7402b123863c041cd44ac7bd26f7d3ddb2..29537684dcb81fbcc2e10ae33618b3948e43a020 100644 --- a/aam/models/transformers.py +++ b/aam/models/transformers.py @@ -14,6 +14,8 @@ class TransformerEncoder(tf.keras.layers.Layer): use_bias=False, norm_first=True, norm_epsilon=1e-6, + use_layer_norm=True, + share_rezero=True, **kwargs, ): super(TransformerEncoder, self).__init__(**kwargs) @@ -39,10 +41,12 @@ class TransformerEncoder(tf.keras.layers.Layer): inner_activation=self._activation, dropout_rate=self._dropout_rate, attention_dropout_rate=self._dropout_rate, + use_layer_norm=False, + share_rezero=True, name=("layer_%d" % i), ) ) - self.output_normalization = tf.keras.layers.LayerNormalization(epsilon=1e-6) + self.output_normalization = tf.keras.layers.LayerNormalization() super(TransformerEncoder, self).build(input_shape) def get_config(self): @@ -76,10 +80,10 @@ class TransformerEncoder(tf.keras.layers.Layer): attention_mask = mask if attention_mask is not None: attention_mask = tf.matmul(attention_mask, attention_mask, transpose_b=True) - + encoder_inputs = encoder_inputs for layer_idx in range(self.num_layers): encoder_inputs = self.encoder_layers[layer_idx]( [encoder_inputs, attention_mask], training=training ) - output_tensor = encoder_inputs - return output_tensor + # output_tensor = self.output_normalization(encoder_inputs) + return encoder_inputs diff --git a/aam/models/unifrac_encoder.py b/aam/models/unifrac_encoder.py index dad09195e1b81b46f1d3e02bcf14d0353f92c2fc..94b0b8c082155c35cef944e83ddf9db97b1de8cf 100644 --- a/aam/models/unifrac_encoder.py +++ b/aam/models/unifrac_encoder.py @@ -7,7 +7,9 @@ import tensorflow as tf from aam.losses import PairwiseLoss from aam.models.base_sequence_encoder import BaseSequenceEncoder from aam.models.transformers import TransformerEncoder -from aam.utils import float_mask +from aam.optimizers.gradient_accumulator import GradientAccumulator +from aam.optimizers.loss_scaler import LossScaler +from aam.utils import apply_random_mask, float_mask @tf.keras.saving.register_keras_serializable(package="UniFracEncoder") @@ -22,6 +24,13 @@ class UniFracEncoder(tf.keras.Model): intermediate_size: int = 1024, intermediate_activation: str = "gelu", max_bp: int = 150, + include_alpha: bool = True, + is_16S: bool = True, + vocab_size: int = 6, + add_token: bool = True, + asv_dropout_rate: float = 0.0, + accumulation_steps: int = 1, + unifrac_metric: str = "faith_pd", **kwargs, ): super(UniFracEncoder, self).__init__(**kwargs) @@ -34,10 +43,22 @@ class UniFracEncoder(tf.keras.Model): self.intermediate_size = intermediate_size self.intermediate_activation = 
intermediate_activation self.max_bp = max_bp + self.include_alpha = include_alpha + self.is_16S = is_16S + self.vocab_size = vocab_size + self.add_token = add_token + self.asv_dropout_rate = asv_dropout_rate + self.accumulation_steps = accumulation_steps + self.unifrac_metric = unifrac_metric self.loss_tracker = tf.keras.metrics.Mean() - self.unifrac_loss = PairwiseLoss() - self.unifrac_tracker = tf.keras.metrics.Mean() + if self.unifrac_metric == "unifrac": + self.unifrac_loss = PairwiseLoss() + self.unifrac_out_dim = self.embedding_dim + else: + self.unifrac_loss = tf.keras.losses.MeanSquaredError(reduction="none") + self.unifrac_out_dim = 1 + self.encoder_tracker = tf.keras.metrics.Mean() # layers used in model self.base_encoder = BaseSequenceEncoder( @@ -52,16 +73,12 @@ class UniFracEncoder(tf.keras.Model): nuc_attention_layers=3, nuc_intermediate_size=128, intermediate_activation=self.intermediate_activation, + is_16S=self.is_16S, + vocab_size=self.vocab_size, + add_token=self.add_token, name="base_encoder", ) - self._unifrac_alpha = self.add_weight( - name="unifrac_alpha", - initializer=tf.keras.initializers.Zeros(), - trainable=True, - dtype=tf.float32, - ) - self.unifrac_encoder = TransformerEncoder( num_layers=self.attention_layers, num_attention_heads=self.attention_heads, @@ -70,12 +87,11 @@ class UniFracEncoder(tf.keras.Model): activation=self.intermediate_activation, name="unifrac_encoder", ) - - self.unifrac_ff = tf.keras.layers.Dense( - self.embedding_dim, use_bias=False, dtype=tf.float32, name="unifrac_ff" - ) + self.unifrac_ff = tf.keras.layers.Dense(self.unifrac_out_dim, dtype=tf.float32) self.loss_metrics = sorted(["loss", "target_loss", "count_mse"]) + self.gradient_accumulator = GradientAccumulator(self.accumulation_steps) + self.loss_scaler = LossScaler() def evaluate_metric(self, dataset, metric, **kwargs): metric_index = self.loss_metrics.index(metric) @@ -85,14 +101,16 @@ class UniFracEncoder(tf.keras.Model): def _compute_nuc_loss(self, nuc_tokens, nuc_pred): return self.base_encoder._compute_nuc_loss(nuc_tokens, nuc_pred) - def _compute_unifrac_loss( + def _compute_encoder_loss( self, y_true: tf.Tensor, unifrac_embeddings: tf.Tensor, ) -> tf.Tensor: loss = self.unifrac_loss(y_true, unifrac_embeddings) - num_samples = tf.reduce_sum(float_mask(loss)) - return tf.math.divide_no_nan(tf.reduce_sum(loss), num_samples) + if self.unifrac_metric == "unifrac": + batch = tf.cast(tf.shape(y_true)[0], dtype=tf.float32) + loss = tf.reduce_sum(loss, axis=1, keepdims=True) / (batch - 1) + return tf.reduce_mean(loss) def _compute_loss( self, @@ -102,10 +120,10 @@ class UniFracEncoder(tf.keras.Model): ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: nuc_tokens, counts = model_inputs embeddings, unifrac_embeddings, nuc_pred = outputs - tax_loss = self._compute_unifrac_loss(y_true, unifrac_embeddings) - nuc_loss = self._compute_nuc_loss(nuc_tokens, nuc_pred) - loss = tax_loss + nuc_loss - return [loss, tax_loss, nuc_loss] + unifrac_loss = self._compute_encoder_loss(y_true, unifrac_embeddings) + # nuc_loss = self._compute_nuc_loss(nuc_tokens, nuc_pred) * 0.0 + loss = unifrac_loss # + nuc_loss + return [loss, unifrac_loss, 0] # , nuc_loss] def predict_step( self, @@ -114,10 +132,10 @@ class UniFracEncoder(tf.keras.Model): tuple[tuple[tf.Tensor, tf.Tensor], tuple[tf.Tensor, tf.Tensor]], ], ): - inputs, sample_ids = data - embeddings, _, _ = self(inputs, training=False) + inputs, y = data + embeddings, unifrac_embeddings, nuc_pred = self(inputs, training=False) - return embeddings, 
sample_ids + return unifrac_embeddings def train_step( self, @@ -126,22 +144,34 @@ class UniFracEncoder(tf.keras.Model): tuple[tuple[tf.Tensor, tf.Tensor], tuple[tf.Tensor, tf.Tensor]], ], ): - inputs, y = data + if not self.gradient_accumulator.built: + self.gradient_accumulator.build(self.optimizer, self) + inputs, y = data + y_target, unifrac_target = y with tf.GradientTape() as tape: outputs = self(inputs, training=True) - loss, tax_loss, nuc_loss = self._compute_loss(inputs, y, outputs) - - gradients = tape.gradient(loss, self.trainable_variables) - self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) + loss, unifrac_loss, nuc_loss = self._compute_loss( + inputs, unifrac_target, outputs + ) + scaled_losses = self.loss_scaler([unifrac_loss]) + loss = tf.reduce_sum(tf.stack(scaled_losses, axis=0)) + + gradients = tape.gradient( + loss, + self.trainable_variables, + unconnected_gradients=tf.UnconnectedGradients.ZERO, + ) + self.gradient_accumulator.apply_gradients(gradients) self.loss_tracker.update_state(loss) - self.unifrac_tracker.update_state(tax_loss) + self.encoder_tracker.update_state(unifrac_loss) self.base_encoder.nuc_entropy.update_state(nuc_loss) return { "loss": self.loss_tracker.result(), - "unifrac_mse": self.unifrac_tracker.result(), + "unifrac_mse": self.encoder_tracker.result(), "nuc_entropy": self.base_encoder.nuc_entropy.result(), + "learning_rate": self.optimizer.learning_rate, } def test_step( @@ -152,17 +182,20 @@ class UniFracEncoder(tf.keras.Model): ], ): inputs, y = data - + y_target, unifrac_target = y outputs = self(inputs, training=False) - loss, tax_loss, nuc_loss = self._compute_loss(inputs, y, outputs) + loss, unifrac_loss, nuc_loss = self._compute_loss( + inputs, unifrac_target, outputs + ) self.loss_tracker.update_state(loss) - self.unifrac_tracker.update_state(tax_loss) + self.encoder_tracker.update_state(unifrac_loss) self.base_encoder.nuc_entropy.update_state(nuc_loss) return { "loss": self.loss_tracker.result(), - "unifrac_mse": self.unifrac_tracker.result(), + "unifrac_mse": self.encoder_tracker.result(), "nuc_entropy": self.base_encoder.nuc_entropy.result(), + "learning_rate": self.optimizer.learning_rate, } def call( @@ -173,7 +206,45 @@ class UniFracEncoder(tf.keras.Model): tokens = tf.cast(tokens, dtype=tf.int32) counts = tf.cast(counts, dtype=tf.int32) - sample_embeddings, nuc_embeddings = self.base_encoder(tokens, training=training) + # account for <SAMPLE> token + count_mask = float_mask(counts, dtype=tf.int32) + random_mask = None + if training and self.asv_dropout_rate > 0: + random_mask = apply_random_mask(count_mask, self.asv_dropout_rate) + + sample_embeddings, nuc_embeddings = self.base_encoder( + tokens, random_mask=random_mask, training=training + ) + + if self.add_token: + count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1) + count_attention_mask = count_mask + + unifrac_gated_embeddings = self.unifrac_encoder( + sample_embeddings, mask=count_attention_mask, training=training + ) + + if self.add_token: + unifrac_pred = unifrac_gated_embeddings[:, 0, :] + else: + mask = tf.cast(count_mask, dtype=tf.float32) + unifrac_pred = tf.reduce_sum(unifrac_gated_embeddings * mask, axis=1) + unifrac_pred /= tf.reduce_sum(mask, axis=1) + + unifrac_pred = self.unifrac_ff(unifrac_pred) + unifrac_embeddings = sample_embeddings + unifrac_gated_embeddings + # unifrac_embeddings = unifrac_gated_embeddings + return [unifrac_embeddings, unifrac_pred, nuc_embeddings] + + def base_embeddings( + self, inputs: 
tuple[tf.Tensor, tf.Tensor] + ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + # keras cast all input to float so we need to manually cast to expected type + tokens, counts = inputs + tokens = tf.cast(tokens, dtype=tf.int32) + counts = tf.cast(counts, dtype=tf.int32) + + sample_embeddings = self.base_encoder.base_embeddings(tokens) # account for <SAMPLE> token count_mask = float_mask(counts, dtype=tf.int32) @@ -181,7 +252,7 @@ class UniFracEncoder(tf.keras.Model): count_attention_mask = count_mask unifrac_gated_embeddings = self.unifrac_encoder( - sample_embeddings, mask=count_attention_mask, training=training + sample_embeddings, mask=count_attention_mask ) unifrac_pred = unifrac_gated_embeddings[:, 0, :] unifrac_pred = self.unifrac_ff(unifrac_pred) @@ -190,7 +261,46 @@ class UniFracEncoder(tf.keras.Model): sample_embeddings + unifrac_gated_embeddings * self._unifrac_alpha ) - return [unifrac_embeddings, unifrac_pred, nuc_embeddings] + return unifrac_embeddings + + def asv_embeddings( + self, inputs: tuple[tf.Tensor, tf.Tensor], training: bool = False + ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + # keras cast all input to float so we need to manually cast to expected type + tokens, counts = inputs + tokens = tf.cast(tokens, dtype=tf.int32) + counts = tf.cast(counts, dtype=tf.int32) + + asv_embeddings = self.base_encoder.asv_embeddings(tokens, training=training) + + return asv_embeddings + + def asv_gradient( + self, inputs: tuple[tf.Tensor, tf.Tensor], asv_embeddings + ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + # keras cast all input to float so we need to manually cast to expected type + tokens, counts = inputs + tokens = tf.cast(tokens, dtype=tf.int32) + counts = tf.cast(counts, dtype=tf.int32) + + sample_embeddings = self.base_encoder.asv_gradient(tokens, asv_embeddings) + + # account for <SAMPLE> token + count_mask = float_mask(counts, dtype=tf.int32) + count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1) + count_attention_mask = count_mask + + unifrac_gated_embeddings = self.unifrac_encoder( + sample_embeddings, mask=count_attention_mask + ) + unifrac_pred = unifrac_gated_embeddings[:, 0, :] + unifrac_pred = self.unifrac_ff(unifrac_pred) + + unifrac_embeddings = ( + sample_embeddings + unifrac_gated_embeddings * self._unifrac_alpha + ) + + return unifrac_embeddings def get_config(self): config = super(UniFracEncoder, self).get_config() @@ -204,6 +314,13 @@ class UniFracEncoder(tf.keras.Model): "intermediate_size": self.intermediate_size, "intermediate_activation": self.intermediate_activation, "max_bp": self.max_bp, + "include_alpha": self.include_alpha, + "is_16S": self.is_16S, + "vocab_size": self.vocab_size, + "add_token": self.add_token, + "asv_dropout_rate": self.asv_dropout_rate, + "accumulation_steps": self.accumulation_steps, + "unifrac_metric": self.unifrac_metric, } ) return config diff --git a/aam/models/utils.py b/aam/models/utils.py index 324d31d8dc4b1a1deb062d82dec356193884ff6e..65bb17f42f263bb552f6d90c36ed8f00af52393e 100644 --- a/aam/models/utils.py +++ b/aam/models/utils.py @@ -5,12 +5,9 @@ import tensorflow as tf class TransformerLearningRateSchedule( tf.keras.optimizers.schedules.LearningRateSchedule ): - def __init__( - self, d_model, warmup_steps=100, decay_method="cosine", initial_lr=3e-4 - ): + def __init__(self, warmup_steps=100, decay_method="cosine", initial_lr=3e-4): super(TransformerLearningRateSchedule, self).__init__() - self.d_model = d_model self.warmup_steps = warmup_steps self.decay_method = decay_method self.initial_lr = 
initial_lr @@ -24,12 +21,9 @@ class TransformerLearningRateSchedule( if self.decay_method == "cosine": # Cosine decay after warmup - cosine_decay = tf.keras.optimizers.schedules.CosineDecayRestarts( + cosine_decay = tf.keras.optimizers.schedules.CosineDecay( initial_learning_rate=self.initial_lr, - first_decay_steps=1000, # Change according to your training steps - t_mul=2.0, # How quickly to increase the restart periods - m_mul=0.9, # Multiplier for reducing max learning rate after each restart - alpha=0.0, # Minimum learning rate + first_decay_steps=self.warmup_steps, ) learning_rate = tf.cond( step < self.warmup_steps, @@ -51,7 +45,6 @@ class TransformerLearningRateSchedule( config = {} config.update( { - "d_model": self.d_model, "warmup_steps": self.warmup_steps, "decay_method": self.decay_method, "initial_lr": self.initial_lr, @@ -61,9 +54,15 @@ class TransformerLearningRateSchedule( def cos_decay_with_warmup(lr, warmup_steps=5000): - # Learning rate schedule: Warmup followed by cosine decay - lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts( - initial_learning_rate=lr, first_decay_steps=warmup_steps + # # Learning rate schedule: Warmup followed by cosine decay + # lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts( + # initial_learning_rate=lr, first_decay_steps=warmup_steps + # ) + lr_schedule = tf.keras.optimizers.schedules.CosineDecay( + initial_learning_rate=0.0, + decay_steps=4000, + warmup_target=lr, + warmup_steps=warmup_steps, ) return lr_schedule @@ -73,19 +72,23 @@ if __name__ == "__main__": import numpy as np import tensorflow as tf - # Learning rate schedule: Warmup followed by cosine decay - lr_schedule = cos_decay_with_warmup() + cosine_decay = tf.keras.optimizers.schedules.CosineDecay( + initial_learning_rate=0.0, + decay_steps=4000, + warmup_target=0.0003, + warmup_steps=100, + ) - # Compute learning rates for each step - steps = np.arange(10000) - learning_rates = [lr_schedule(step).numpy() for step in steps] + # Generate learning rates for a range of steps + steps = np.arange(4000, dtype=np.float32) + learning_rates = [cosine_decay(step).numpy() for step in steps] # Plot the learning rate schedule plt.figure(figsize=(10, 6)) plt.plot(steps, learning_rates) - plt.title("Learning Rate Schedule: Warmup + Cosine Decay") - plt.xlabel("Training Steps") + plt.xlabel("Training Step") plt.ylabel("Learning Rate") + plt.title("Transformer Learning Rate Schedule") + plt.legend() plt.grid(True) - plt.savefig("test.png") - plt.close() + plt.show() diff --git a/aam/optimizers/gradient_accumulator.py b/aam/optimizers/gradient_accumulator.py new file mode 100644 index 0000000000000000000000000000000000000000..943f1bf68032038aa6d0ec5343a2655546b2d43b --- /dev/null +++ b/aam/optimizers/gradient_accumulator.py @@ -0,0 +1,92 @@ +import tensorflow as tf + + +class GradientAccumulator: + def __init__(self, accumulation_steps): + self.accum_steps = tf.constant( + accumulation_steps, dtype=tf.int32, name="accum_steps" + ) + + self.accum_step_counter = tf.Variable( + 0, + dtype=tf.int32, + trainable=False, + name="accum_counter", + ) + # self.log_steps = tf.constant(100, dtype=tf.int32, name="accum_steps") + # self.log_step_counter = tf.Variable( + # 0, + # dtype=tf.int32, + # trainable=False, + # name="accum_counter", + # ) + self.built = False + # self.file_writer = tf.summary.create_file_writer( + # "/home/kalen/aam-research-exam/research-exam/healty-age-regression/results/logs/gradients" + # ) + + def build(self, optimizer, model): + self.built = True + 
self.optimizer = optimizer + self.model = model + + # reinitialize gradient accumulator + self.gradient_accumulation = [ + tf.Variable( + tf.zeros_like(v, dtype=tf.float32), + trainable=False, + name="accum_" + str(i), + synchronization=tf.VariableSynchronization.ON_READ, + aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, + ) + for i, v in enumerate(self.model.trainable_variables) + ] + + def apply_gradients(self, gradients): + self.accum_step_counter.assign_add(1) + + # Accumulate batch gradients + for i in range(len(self.gradient_accumulation)): + self.gradient_accumulation[i].assign_add( + gradients[i], + read_value=False, + ) + + # apply accumulated gradients + tf.cond( + tf.equal(self.accum_step_counter, self.accum_steps), + true_fn=self.apply_accu_gradients, + false_fn=lambda: None, + ) + + def apply_accu_gradients(self): + """Performs gradient update and resets slots afterwards.""" + self.optimizer.apply_gradients( + zip(self.gradient_accumulation, self.model.trainable_variables) + ) + + # self.log_step_counter.assign_add(1) + # tf.cond( + # tf.equal(self.log_step_counter, self.log_steps), + # true_fn=self.log_gradients, + # false_fn=lambda: None, + # ) + + # reset + self.accum_step_counter.assign(0) + for i in range(len(self.gradient_accumulation)): + self.gradient_accumulation[i].assign( + tf.zeros_like(self.model.trainable_variables[i], dtype=tf.float32), + read_value=False, + ) + + # def log_gradients(self): + # self.log_step_counter.assign(0) + # with self.file_writer.as_default(): + # for grad, var in zip( + # self.gradient_accumulation, self.model.trainable_variables + # ): + # if grad is not None: + # tf.summary.histogram( + # var.name + "/gradient", grad, step=self.optimizer.iterations + # ) diff --git a/aam/optimizers/loss_scaler.py b/aam/optimizers/loss_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..8e6242e584bd8b2be17fdfb715256dce48095214 --- /dev/null +++ b/aam/optimizers/loss_scaler.py @@ -0,0 +1,48 @@ +import tensorflow as tf + + +class LossScaler: + def __init__(self, gradient_accum_steps): + self.gradient_accum_steps = tf.cast(gradient_accum_steps, dtype=tf.float32) + self.moving_avg = None + self.decay = 0.99 + self.accum_loss = tf.Variable( + initial_value=1, + trainable=False, + dtype=tf.int32, + name="accum_loss", + ) + + def __call__(self, losses): + if self.moving_avg is None: + self.scaled_loss = [ + tf.Variable( + initial_value=loss, + trainable=False, + dtype=tf.float32, + name=f"scaled_loss_{i}", + ) + for i, loss in enumerate(losses) + ] + self.moving_avg = [ + tf.Variable( + initial_value=loss, + trainable=False, + dtype=tf.float32, + name=f"avg_loss_{i}", + ) + for i, loss in enumerate(losses) + ] + + for i in range(len(self.moving_avg)): + self.moving_avg[i].assign( + self.moving_avg[i] * self.decay + (1 - self.decay) * losses[i], + read_value=False, + ) + + return [ + tf.math.divide_no_nan( + losses[i], self.moving_avg[i] * self.gradient_accum_steps + ) + for i in range(len(self.moving_avg)) + ] diff --git a/aam/utils.py b/aam/utils.py index 03fb8009c5243ebb543208b20c4fd3afe0363fd3..410a52343bc26bacfe7a7946497f2fdd76de04a4 100644 --- a/aam/utils.py +++ b/aam/utils.py @@ -49,7 +49,8 @@ def masked_loss(sparse_cat: bool = False): total = tf.cast(tf.reduce_sum(mask), dtype=tf.float32) loss = tf.reduce_sum(loss * mask) - return tf.math.divide_no_nan(loss, total) + # return tf.math.divide_no_nan(loss, total) + return loss return wrapper