diff --git a/README.md b/README.md
index 057f0f72a84d932a51e1e476dc05737e51fd9065..72b1c500e5a08e4009b6c57ff29dc5df3ffbe884 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,8 @@ First create a new conda environment with unifrac
 
 `conda activate aam`
 
+`conda install -c conda-forge gxx_linux-64 hdf5 mkl-include lz4 hdf5-static libcblas liblapacke make`
+
 ## GPU Support 
 
 Install CUDA 11.8
diff --git a/aam/attention_cli.py b/aam/attention_cli.py
index 17c63613210d455e127813e22e4281f775c54b4e..9356eaff93f64b5c5a3bedbd05f017a990197471 100644
--- a/aam/attention_cli.py
+++ b/aam/attention_cli.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import datetime
 import os
-from typing import Optional
+from typing import Union
 
 import click
 import numpy as np
@@ -56,10 +56,28 @@ def validate_metadata(table, metadata, missing_samples_flag):
 @cli.command()
 @click.option("--i-table", required=True, type=click.Path(exists=True), help=TABLE_DESC)
 @click.option("--i-tree", required=True, type=click.Path(exists=True))
+@click.option(
+    "--m-metadata-file",
+    required=True,
+    help="Metadata description",
+    type=click.Path(exists=True),
+)
+@click.option(
+    "--m-metadata-column",
+    required=True,
+    type=str,
+    help="Numeric metadata column to use as prediction target.",
+)
+@click.option(
+    "--p-missing-samples",
+    default="error",
+    type=click.Choice(["error", "ignore"], case_sensitive=False),
+    help=MISSING_SAMP_DESC,
+)
 @click.option("--p-batch-size", default=8, show_default=True, required=False, type=int)
-@click.option("--p-max-bp", required=True, type=int)
 @click.option("--p-epochs", default=1000, show_default=True, type=int)
 @click.option("--p-dropout", default=0.0, show_default=True, type=float)
+@click.option("--p-asv-dropout", default=0.0, show_default=True, type=float)
 @click.option("--p-patience", default=10, show_default=True, type=int)
 @click.option("--p-early-stop-warmup", default=50, show_default=True, type=int)
 @click.option("--i-model", default=None, required=False, type=str)
@@ -70,30 +88,55 @@ def validate_metadata(table, metadata, missing_samples_flag):
 @click.option(
     "--p-intermediate-activation", default="relu", show_default=True, type=str
 )
-@click.option("--p-asv-limit", default=512, show_default=True, type=int)
+@click.option("--p-asv-limit", default=1024, show_default=True, type=int)
+@click.option("--p-gen-new-table", default=True, show_default=True, type=bool)
+@click.option("--p-lr", default=1e-4, show_default=True, type=float)
+@click.option("--p-warmup-steps", default=10000, show_default=True, type=int)
+@click.option("--p-max-bp", default=150, show_default=True, type=int)
 @click.option("--output-dir", required=True)
+@click.option("--p-add-token", default=False, required=False, type=bool)
+@click.option("--p-gotu", default=False, required=False, type=bool)
+@click.option("--p-is-categorical", default=False, required=False, type=bool)
+@click.option("--p-rarefy-depth", default=5000, required=False, type=int)
+@click.option("--p-weight-decay", default=0.004, show_default=True, type=float)
+@click.option("--p-accumulation-steps", default=1, required=False, type=int)
+@click.option("--p-unifrac-metric", default="unifrac", required=False, type=str)
 def fit_unifrac_regressor(
     i_table: str,
     i_tree: str,
+    m_metadata_file: str,
+    m_metadata_column: str,
+    p_missing_samples: str,
     p_batch_size: int,
-    p_max_bp: int,
     p_epochs: int,
     p_dropout: float,
+    p_asv_dropout: float,
     p_patience: int,
     p_early_stop_warmup: int,
-    i_model: Optional[str],
+    i_model: Union[None, str],
     p_embedding_dim: int,
     p_attention_heads: int,
     p_attention_layers: int,
     p_intermediate_size: int,
     p_intermediate_activation: str,
     p_asv_limit: int,
+    p_gen_new_table: bool,
+    p_lr: float,
+    p_warmup_steps: int,
+    p_max_bp: int,
     output_dir: str,
+    p_add_token: bool,
+    p_gotu: bool,
+    p_is_categorical: bool,
+    p_rarefy_depth: int,
+    p_weight_decay: float,
+    p_accumulation_steps: int,
+    p_unifrac_metric: str,
 ):
     from biom import load_table
 
     from aam.data_handlers import UniFracGenerator
-    from aam.models import UniFracEncoder
+    from aam.models import SequenceEncoder
     from aam.models.utils import cos_decay_with_warmup
 
     if not os.path.exists(output_dir):
@@ -103,20 +146,39 @@ def fit_unifrac_regressor(
     if not os.path.exists(figure_path):
         os.makedirs(figure_path)
 
+    output_dim = p_embedding_dim
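+    # Faith's PD is a single scalar per sample, so the encoder only needs one output dimension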
+    if p_unifrac_metric == "faith_pd":
+        output_dim = 1
+
     if i_model is not None:
-        model: tf.keras.Model = tf.keras.models.load_model(i_model)
+        model = tf.keras.models.load_model(i_model)
     else:
-        model: tf.keras.Model = UniFracEncoder(
+        model: tf.keras.Model = SequenceEncoder(
+            output_dim,
             p_asv_limit,
-            embedding_dim=p_embedding_dim,
+            p_unifrac_metric,
             dropout_rate=p_dropout,
+            embedding_dim=p_embedding_dim,
             attention_heads=p_attention_heads,
             attention_layers=p_attention_layers,
             intermediate_size=p_intermediate_size,
             intermediate_activation=p_intermediate_activation,
+            max_bp=p_max_bp,
+            is_16S=True,
+            add_token=p_add_token,
+            asv_dropout_rate=p_asv_dropout,
+            accumulation_steps=p_accumulation_steps,
         )
 
-        optimizer = tf.keras.optimizers.AdamW(cos_decay_with_warmup(), beta_2=0.95)
+        optimizer = tf.keras.optimizers.AdamW(
+            cos_decay_with_warmup(p_lr, p_warmup_steps),
+            beta_2=0.98,
+            weight_decay=p_weight_decay,
+            # use_ema=True,
+            # ema_momentum=0.999,
+            # ema_overwrite_frequency=500,
+            # global_clipnorm=1.0,
+        )
         token_shape = tf.TensorShape([None, None, 150])
         count_shape = tf.TensorShape([None, None, 1])
         model.build([token_shape, count_shape])
@@ -127,10 +189,14 @@ def fit_unifrac_regressor(
     model.summary()
 
     table = load_table(i_table)
-    ids = table.ids()
+    df = pd.read_csv(m_metadata_file, sep="\t", index_col=0, dtype={0: str})[
+        [m_metadata_column]
+    ]
+    ids, table, df = validate_metadata(table, df, p_missing_samples)
     indices = np.arange(len(ids), dtype=np.int32)
+
     np.random.shuffle(indices)
-    train_size = int(len(ids) * 0.9)
+    train_size = int(len(ids) * 0.8)
 
     train_indices = indices[:train_size]
     train_ids = ids[train_indices]
@@ -140,21 +206,36 @@ def fit_unifrac_regressor(
     val_ids = ids[val_indices]
     val_table = table.filter(val_ids, inplace=False)
 
+    common_kwargs = {
+        "metadata_column": m_metadata_column,
+        "max_token_per_sample": p_asv_limit,
+        "rarefy_depth": p_rarefy_depth,
+        "batch_size": p_batch_size,
+        "is_16S": True,
+        "is_categorical": p_is_categorical,
+        "max_bp": p_max_bp,
+        "epochs": p_epochs,
+        "tree_path": i_tree,
+        "metadata": df,
+        "unifrac_metric": p_unifrac_metric,
+    }
     train_gen = UniFracGenerator(
         table=train_table,
-        tree_path=i_tree,
-        max_token_per_sample=p_asv_limit,
         shuffle=True,
-        gen_new_tables=True,
-        batch_size=p_batch_size,
+        shift=0.0,
+        scale=1.0,
+        gen_new_tables=p_gen_new_table,
+        **common_kwargs,
     )
     train_data = train_gen.get_data()
 
     val_gen = UniFracGenerator(
         table=val_table,
-        tree_path=i_tree,
-        max_token_per_sample=p_asv_limit,
         shuffle=False,
+        shift=0.0,
+        scale=1.0,
+        gen_new_tables=False,
+        **common_kwargs,
     )
     val_data = val_gen.get_data()
 
@@ -186,10 +267,29 @@ def fit_unifrac_regressor(
 @cli.command()
 @click.option("--i-table", required=True, type=click.Path(exists=True), help=TABLE_DESC)
 @click.option("--i-taxonomy", required=True, type=click.Path(exists=True))
-@click.option("--p-tax-level", default=7, type=int)
-@click.option("--p-max-bp", required=True, type=int)
+@click.option("--i-tax-level", default=7, type=int)
+@click.option(
+    "--m-metadata-file",
+    required=True,
+    help="Metadata description",
+    type=click.Path(exists=True),
+)
+@click.option(
+    "--m-metadata-column",
+    required=True,
+    type=str,
+    help="Numeric metadata column to use as prediction target.",
+)
+@click.option(
+    "--p-missing-samples",
+    default="error",
+    type=click.Choice(["error", "ignore"], case_sensitive=False),
+    help=MISSING_SAMP_DESC,
+)
+@click.option("--p-batch-size", default=8, show_default=True, required=False, type=int)
 @click.option("--p-epochs", default=1000, show_default=True, type=int)
 @click.option("--p-dropout", default=0.1, show_default=True, type=float)
+@click.option("--p-asv-dropout", default=0.0, show_default=True, type=float)
 @click.option("--p-patience", default=10, show_default=True, type=int)
 @click.option("--p-early-stop-warmup", default=50, show_default=True, type=int)
 @click.option("--i-model", default=None, required=False, type=str)
@@ -201,29 +301,53 @@ def fit_unifrac_regressor(
     "--p-intermediate-activation", default="relu", show_default=True, type=str
 )
 @click.option("--p-asv-limit", default=512, show_default=True, type=int)
+@click.option("--p-gen-new-table", default=True, show_default=True, type=bool)
+@click.option("--p-lr", default=1e-4, show_default=True, type=float)
+@click.option("--p-warmup-steps", default=10000, show_default=True, type=int)
+@click.option("--p-max-bp", required=True, type=int)
 @click.option("--output-dir", required=True)
+@click.option("--p-add-token", default=False, required=False, type=bool)
+@click.option("--p-gotu", default=False, required=False, type=bool)
+@click.option("--p-is-categorical", default=False, required=False, type=bool)
+@click.option("--p-rarefy-depth", default=5000, required=False, type=int)
+@click.option("--p-weight-decay", default=0.004, show_default=True, type=float)
+@click.option("--p-accumulation-steps", default=1, required=False, type=int)
 def fit_taxonomy_regressor(
     i_table: str,
     i_taxonomy: str,
-    p_tax_level: int,
-    p_max_bp: int,
+    i_tax_level: int,
+    m_metadata_file: str,
+    m_metadata_column: str,
+    p_missing_samples: str,
+    p_batch_size: int,
     p_epochs: int,
     p_dropout: float,
+    p_asv_dropout: float,
     p_patience: int,
     p_early_stop_warmup: int,
-    i_model: Optional[str],
+    i_model: Union[None, str],
     p_embedding_dim: int,
     p_attention_heads: int,
     p_attention_layers: int,
     p_intermediate_size: int,
     p_intermediate_activation: str,
     p_asv_limit: int,
+    p_gen_new_table: bool,
+    p_lr: float,
+    p_warmup_steps: int,
+    p_max_bp: int,
     output_dir: str,
+    p_add_token: bool,
+    p_gotu: bool,
+    p_is_categorical: bool,
+    p_rarefy_depth: int,
+    p_weight_decay: float,
+    p_accumulation_steps: int,
 ):
     from biom import load_table
 
     from aam.data_handlers import TaxonomyGenerator
-    from aam.models import TaxonomyEncoder
+    from aam.models import SequenceEncoder
     from aam.models.utils import cos_decay_with_warmup
 
     if not os.path.exists(output_dir):
@@ -234,10 +358,14 @@ def fit_taxonomy_regressor(
         os.makedirs(figure_path)
 
     table = load_table(i_table)
-    ids = table.ids()
+    df = pd.read_csv(m_metadata_file, sep="\t", index_col=0, dtype={0: str})[
+        [m_metadata_column]
+    ]
+    ids, table, df = validate_metadata(table, df, p_missing_samples)
     indices = np.arange(len(ids), dtype=np.int32)
+
     np.random.shuffle(indices)
-    train_size = int(len(ids) * 0.9)
+    train_size = int(len(ids) * 0.8)
 
     train_indices = indices[:train_size]
     train_ids = ids[train_indices]
@@ -247,22 +375,36 @@ def fit_taxonomy_regressor(
     val_ids = ids[val_indices]
     val_table = table.filter(val_ids, inplace=False)
 
+    common_kwargs = {
+        "metadata_column": m_metadata_column,
+        "max_token_per_sample": p_asv_limit,
+        "rarefy_depth": p_rarefy_depth,
+        "batch_size": p_batch_size,
+        "is_16S": True,
+        "is_categorical": p_is_categorical,
+        "max_bp": p_max_bp,
+        "epochs": p_epochs,
+        "taxonomy": i_taxonomy,
+        "tax_level": i_tax_level,
+        "metadata": df,
+    }
     train_gen = TaxonomyGenerator(
         table=train_table,
-        taxonomy=i_taxonomy,
-        tax_level=p_tax_level,
-        max_token_per_sample=p_asv_limit,
         shuffle=True,
-        gen_new_tables=True,
+        shift=0.0,
+        scale=1.0,
+        gen_new_tables=p_gen_new_table,
+        **common_kwargs,
     )
     train_data = train_gen.get_data()
 
     val_gen = TaxonomyGenerator(
         table=val_table,
-        taxonomy=i_taxonomy,
-        tax_level=p_tax_level,
-        max_token_per_sample=p_asv_limit,
         shuffle=False,
+        shift=0.0,
+        scale=1.0,
+        gen_new_tables=False,
+        **common_kwargs,
     )
     val_data = val_gen.get_data()
 
@@ -283,18 +425,28 @@ def fit_taxonomy_regressor(
     if i_model is not None:
         model: tf.keras.Model = tf.keras.models.load_model(i_model)
     else:
-        model: tf.keras.Model = TaxonomyEncoder(
+        model: tf.keras.Model = SequenceEncoder(
             train_gen.num_tokens,
             p_asv_limit,
-            embedding_dim=p_embedding_dim,
+            "taxonomy",
             dropout_rate=p_dropout,
+            embedding_dim=p_embedding_dim,
             attention_heads=p_attention_heads,
             attention_layers=p_attention_layers,
             intermediate_size=p_intermediate_size,
             intermediate_activation=p_intermediate_activation,
+            max_bp=p_max_bp,
+            is_16S=True,
+            add_token=p_add_token,
+            asv_dropout_rate=p_asv_dropout,
+            accumulation_steps=p_accumulation_steps,
+        )
+        optimizer = tf.keras.optimizers.AdamW(
+            cos_decay_with_warmup(p_lr, p_warmup_steps),
+            beta_2=0.98,
+            weight_decay=p_weight_decay,
+            # global_clipnorm=1.0,
         )
-
-        optimizer = tf.keras.optimizers.AdamW(cos_decay_with_warmup(), beta_2=0.95)
         token_shape = tf.TensorShape([None, None, 150])
         count_shape = tf.TensorShape([None, None, 1])
         model.build([token_shape, count_shape])
@@ -361,6 +513,7 @@ def fit_taxonomy_regressor(
 @click.option("--p-early-stop-warmup", default=50, show_default=True, type=int)
 @click.option("--p-batch-size", default=8, show_default=True, required=False, type=int)
 @click.option("--p-dropout", default=0.1, show_default=True, type=float)
+@click.option("--p-asv-dropout", default=0.0, show_default=True, type=float)
 @click.option("--p-report-back", default=5, show_default=True, type=int)
 @click.option("--p-asv-limit", default=1024, show_default=True, type=int)
 @click.option("--p-penalty", default=1.0, show_default=True, type=float)
@@ -380,6 +533,14 @@ def fit_taxonomy_regressor(
 @click.option("--p-warmup-steps", default=10000, show_default=True, type=int)
 @click.option("--p-max-bp", default=150, show_default=True, type=int)
 @click.option("--output-dir", required=True, type=click.Path(exists=False))
+@click.option("--p-output-dim", default=1, required=False, type=int)
+@click.option("--p-add-token", default=False, required=False, type=bool)
+@click.option("--p-gotu", default=False, required=False, type=bool)
+@click.option("--p-is-categorical", default=False, required=False, type=bool)
+@click.option("--p-rarefy-depth", default=5000, required=False, type=int)
+@click.option("--p-weight-decay", default=0.004, show_default=True, type=float)
+@click.option("--p-accumulation-steps", default=1, required=False, type=int)
+@click.option("--p-unifrac-metric", default="unifrac", required=False, type=str)
 def fit_sample_regressor(
     i_table: str,
     i_base_model_path: str,
@@ -394,6 +555,7 @@ def fit_sample_regressor(
     p_early_stop_warmup: int,
     p_batch_size: int,
     p_dropout: float,
+    p_asv_dropout: float,
     p_report_back: int,
     p_asv_limit: int,
     p_penalty: float,
@@ -411,11 +573,22 @@ def fit_sample_regressor(
     p_warmup_steps,
     p_max_bp: int,
     output_dir: str,
+    p_output_dim: int,
+    p_add_token: bool,
+    p_gotu: bool,
+    p_is_categorical: bool,
+    p_rarefy_depth: int,
+    p_weight_decay: float,
+    p_accumulation_steps: int,
+    p_unifrac_metric: str,
 ):
-    from aam.callbacks import MeanAbsoluteError
-    from aam.data_handlers import TaxonomyGenerator, UniFracGenerator
-    from aam.models import SequenceRegressor, TaxonomyEncoder, UniFracEncoder
+    from aam.callbacks import ConfusionMatrx, MeanAbsoluteError
+    from aam.data_handlers import CombinedGenerator, TaxonomyGenerator, UniFracGenerator
+    from aam.models import SequenceRegressor
 
+    # p_is_16S = False
+    is_16S = not p_gotu
+    # p_is_categorical = True
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
 
@@ -428,11 +601,14 @@ def fit_sample_regressor(
         os.makedirs(model_path)
 
     table = load_table(i_table)
-    df = pd.read_csv(m_metadata_file, sep="\t", index_col=0)[[m_metadata_column]]
+    df = pd.read_csv(m_metadata_file, sep="\t", index_col=0, dtype={0: str})[
+        [m_metadata_column]
+    ]
     ids, table, df = validate_metadata(table, df, p_missing_samples)
     num_ids = len(ids)
 
     fold_indices = np.arange(num_ids)
+    np.random.shuffle(fold_indices)
     if p_test_size > 0:
         test_size = int(num_ids * p_test_size)
         train_size = num_ids - test_size
@@ -443,10 +619,11 @@ def fit_sample_regressor(
 
     common_kwargs = {
         "metadata_column": m_metadata_column,
-        "is_categorical": False,
         "max_token_per_sample": p_asv_limit,
-        "rarefy_depth": 5000,
+        "rarefy_depth": p_rarefy_depth,
         "batch_size": p_batch_size,
+        "is_16S": is_16S,
+        "is_categorical": p_is_categorical,
     }
 
     def tax_gen(table, df, shuffle, shift, scale, epochs, gen_new_tables):
@@ -475,28 +652,48 @@ def fit_sample_regressor(
             epochs=epochs,
             gen_new_tables=gen_new_tables,
             max_bp=p_max_bp,
+            unifrac_metric=p_unifrac_metric,
             **common_kwargs,
         )
 
-    if p_taxonomy is not None and p_tree is None:
+    def combine_gen(table, df, shuffle, shift, scale, epochs, gen_new_tables):
+        return CombinedGenerator(
+            table=table,
+            metadata=df,
+            tree_path=p_tree,
+            taxonomy=p_taxonomy,
+            tax_level=p_taxonomy_level,
+            shuffle=shuffle,
+            shift=shift,
+            scale=scale,
+            epochs=epochs,
+            gen_new_tables=gen_new_tables,
+            max_bp=p_max_bp,
+            **common_kwargs,
+        )
+
+    if p_unifrac_metric == "combined":
+        base_model = "combined"
+        generator = combine_gen
+    elif p_taxonomy is not None and p_tree is None:
         base_model = "taxonomy"
         generator = tax_gen
     elif p_taxonomy is None and p_tree is not None:
-        base_model = "unifrac"
+        base_model = p_unifrac_metric
         generator = unifrac_gen
     else:
         raise Exception("Only taxonomy or UniFrac is supported.")
 
     if i_base_model_path is not None:
         base_model = tf.keras.models.load_model(i_base_model_path, compile=False)
-        if isinstance(base_model, TaxonomyEncoder):
+        base_type = base_model.encoder_type
+        if base_type == "taxonomy":
             generator = tax_gen
-        elif isinstance(base_model, UniFracEncoder):
-            generator = unifrac_gen
         else:
-            raise Exception(f"Unsupported base model {type(base_model)}")
+            generator = unifrac_gen
+
         if not p_no_freeze_base_weights:
-            raise Warning("base_model's weights are set to trainable.")
+            print("base_model's weights are set to trainable.")
 
     def _get_fold(
         indices,
@@ -521,15 +718,24 @@ def fit_sample_regressor(
             data["num_tokens"] = None
         return data
 
-    kfolds = KFold(p_cv)
+    if not p_is_categorical:
+        print("non-stratified folds")
+        kfolds = KFold(p_cv)
+        splits = kfolds.split(fold_indices)
+    else:
+        print("stratified folds...")
+        kfolds = StratifiedKFold(p_cv)
+        train_ids = ids[fold_indices]
+        train_classes = df.loc[df.index.isin(train_ids), m_metadata_column].values
+        splits = kfolds.split(fold_indices, train_classes)
 
     models = []
-    for i, (train_ind, val_ind) in enumerate(kfolds.split(fold_indices)):
+    for i, (train_ind, val_ind) in enumerate(splits):
         train_data = _get_fold(
             train_ind,
             shuffle=True,
             shift=0.0,
-            scale=100.0,
+            scale=1.0,
             gen_new_tables=p_gen_new_table,
         )
         val_data = _get_fold(
@@ -542,30 +748,74 @@ def fit_sample_regressor(
         with open(os.path.join(model_path, f"f{i}_val_ids.txt"), "w") as f:
             for id in ids[val_ind]:
                 f.write(id + "\n")
+        vocab_size = 6 if not p_is_categorical else 2000
+
+        if base_model == "combined":
+            base_output_dim = [p_embedding_dim, 1, train_data["num_tokens"]]
+        elif base_model == "unifrac":
+            base_output_dim = p_embedding_dim
+        elif base_model == "faith_pd":
+            base_output_dim = 1
+        else:
+            base_output_dim = train_data["num_tokens"]
 
         model = SequenceRegressor(
             token_limit=p_asv_limit,
+            base_output_dim=base_output_dim,
+            shift=train_data["shift"],
+            scale=train_data["scale"],
+            dropout_rate=p_dropout,
             embedding_dim=p_embedding_dim,
             attention_heads=p_attention_heads,
             attention_layers=p_attention_layers,
             intermediate_size=p_intermediate_size,
             intermediate_activation=p_intermediate_activation,
-            shift=train_data["shift"],
-            scale=train_data["scale"],
-            dropout_rate=p_dropout,
             base_model=base_model,
             freeze_base=p_no_freeze_base_weights,
-            num_tax_levels=train_data["num_tokens"],
             penalty=p_penalty,
             nuc_penalty=p_nuc_penalty,
             max_bp=p_max_bp,
+            is_16S=is_16S,
+            vocab_size=vocab_size,
+            out_dim=p_output_dim,
+            classifier=p_is_categorical,
+            add_token=p_add_token,
+            class_weights=train_data["class_weights"],
+            accumulation_steps=p_accumulation_steps,
         )
         token_shape = tf.TensorShape([None, None, p_max_bp])
         count_shape = tf.TensorShape([None, None, 1])
         model.build([token_shape, count_shape])
         model.summary()
-        loss = tf.keras.losses.MeanSquaredError(reduction="none")
+
         fold_label = i + 1
+        if not p_is_categorical:
+            loss = tf.keras.losses.MeanSquaredError(reduction="none")
+            callbacks = [
+                MeanAbsoluteError(
+                    monitor="val_mae",
+                    dataset=val_data["dataset"],
+                    output_dir=os.path.join(
+                        figure_path, f"model_f{fold_label}-val.png"
+                    ),
+                    report_back=p_report_back,
+                )
+            ]
+        else:
+            loss = tf.keras.losses.CategoricalFocalCrossentropy(
+                from_logits=False, reduction="none"
+            )
+            # loss = tf.keras.losses.CategoricalHinge(reduction="none")
+            callbacks = [
+                ConfusionMatrx(
+                    monitor="val_target_loss",
+                    dataset=val_data["dataset"],
+                    output_dir=os.path.join(
+                        figure_path, f"model_f{fold_label}-val.png"
+                    ),
+                    report_back=p_report_back,
+                )
+            ]
         model_cv = CVModel(
             model,
             train_data,
@@ -573,25 +823,18 @@ def fit_sample_regressor(
             output_dir,
             fold_label,
         )
+        metric = "mae" if not p_is_categorical else "target_loss"
         model_cv.fit_fold(
             loss,
             p_epochs,
             os.path.join(model_path, f"model_f{fold_label}.keras"),
-            metric="mae",
+            metric=metric,
             patience=p_patience,
             early_stop_warmup=p_early_stop_warmup,
-            callbacks=[
-                MeanAbsoluteError(
-                    monitor="val_mae",
-                    dataset=val_data["dataset"],
-                    output_dir=os.path.join(
-                        figure_path, f"model_f{fold_label}-val.png"
-                    ),
-                    report_back=p_report_back,
-                )
-            ],
+            callbacks=[*callbacks],
             lr=p_lr,
             warmup_steps=p_warmup_steps,
+            weight_decay=p_weight_decay,
         )
         models.append(model_cv)
         print(f"Fold {i+1} mae: {model_cv.metric_value}")
diff --git a/aam/callbacks.py b/aam/callbacks.py
index 915cc4051881d138996366f7e82c8dcbcb90fbc7..f0e1a9f166d0ee034d6ec56a445c046a7744b299 100644
--- a/aam/callbacks.py
+++ b/aam/callbacks.py
@@ -37,17 +37,17 @@ def _mean_absolute_error(pred_val, true_val, fname, labels=None):
 def _confusion_matrix(pred_val, true_val, fname, cat_labels=None):
     cf_matrix = tf.math.confusion_matrix(true_val, pred_val).numpy()
     group_counts = ["{0:0.0f}".format(value) for value in cf_matrix.flatten()]
-    group_percentages = [
-        "{0:.2%}".format(value) for value in cf_matrix.flatten() / np.sum(cf_matrix)
-    ]
+
+    cf_matrix = cf_matrix / np.sum(cf_matrix, axis=-1, keepdims=True)
+    group_percentages = ["{0:.2%}".format(value) for value in cf_matrix.flatten()]
     labels = [f"{v1}\n{v2}" for v1, v2 in zip(group_counts, group_percentages)]
     labels = np.asarray(labels).reshape(cf_matrix.shape)
     fig, ax = plt.subplots(figsize=(10, 10))
     ax = sns.heatmap(
         cf_matrix,
         annot=labels,
-        xticklabels=cat_labels,
-        yticklabels=cat_labels,
+        # xticklabels=cat_labels,
+        # yticklabels=cat_labels,
         fmt="",
     )
     import textwrap
@@ -85,16 +85,30 @@ class MeanAbsoluteError(tf.keras.callbacks.Callback):
 
 
 class ConfusionMatrx(tf.keras.callbacks.Callback):
-    def __init__(self, dataset, output_dir, report_back, labels, **kwargs):
+    def __init__(
+        self,
+        dataset,
+        output_dir,
+        report_back,
+        labels=None,
+        monitor="val_loss",
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         self.output_dir = output_dir
         self.report_back = report_back
         self.dataset = dataset
         self.labels = labels
+        self.best_metric = None
+        self.monitor = monitor
 
     def on_epoch_end(self, epoch, logs=None):
-        y_pred, y_true = self.model.predict(self.dataset)
-        _confusion_matrix(y_pred, y_true, self.output_dir, self.labels)
+        metric = logs[self.monitor]
+        print(self.best_metric, metric)
+        if self.best_metric is None or self.best_metric > metric:
+            y_pred, y_true = self.model.predict(self.dataset)
+            _confusion_matrix(y_pred, y_true, self.output_dir, self.labels)
+            self.best_metric = metric
 
 
 class SaveModel(tf.keras.callbacks.Callback):
@@ -107,11 +121,8 @@ class SaveModel(tf.keras.callbacks.Callback):
         self.monitor = monitor
 
     def on_epoch_end(self, epoch, logs=None):
-        learning_rate = float(
-            tf.keras.backend.get_value(self.model.optimizer.learning_rate)
-        )
-        # Add the learning rate to the logs dictionary
-        logs["learning_rate"] = learning_rate
+        iterations = float(tf.keras.backend.get_value(self.model.optimizer.iterations))
+        logs["iteration"] = iterations
 
         metric = logs[self.monitor]
         if self.best_weights is None or self.best_metric > metric:
diff --git a/aam/cv_utils.py b/aam/cv_utils.py
index eb22607851e2df1670a002ab97c638ed3ca5048e..5ecf2708ecab54dab01b1c076180b6c7b9171c82 100644
--- a/aam/cv_utils.py
+++ b/aam/cv_utils.py
@@ -30,33 +30,41 @@ class CVModel:
 
     def fit_fold(
         self,
-        loss,
-        epochs,
-        model_save_path,
-        metric="mae",
-        patience=10,
-        early_stop_warmup=50,
-        callbacks=[],
-        lr=1e-4,
-        warmup_steps=10000,
+        loss: tf.keras.losses.Loss,
+        epochs: int,
+        model_save_path: str,
+        metric: str = "loss",
+        patience: int = 10,
+        early_stop_warmup: int = 50,
+        callbacks: list[tf.keras.callbacks.Callback] = [],
+        lr: float = 1e-4,
+        warmup_steps: int = 10000,
+        weight_decay: float = 0.004,
     ):
         if not os.path.exists(self.log_dir):
             os.makedirs(self.log_dir)
-        optimizer = tf.keras.optimizers.AdamW(
-            cos_decay_with_warmup(lr, warmup_steps), beta_2=0.95
+        print(f"weight decay: {weight_decay}")
+        optimizer = tf.keras.optimizers.Adam(
+            cos_decay_with_warmup(lr, warmup_steps),
+            beta_2=0.98,
+            # weight_decay=weight_decay,
+            # global_clipnorm=1.0,
+            # use_ema=True,
+            # ema_momentum=0.999,
+            # ema_overwrite_frequency=500,
         )
         model_saver = SaveModel(model_save_path, 10, f"val_{metric}")
         core_callbacks = [
             tf.keras.callbacks.TensorBoard(
-                log_dir=self.log_dir,
-                histogram_freq=0,
+                log_dir=self.log_dir, histogram_freq=0, write_graph=False
             ),
             tf.keras.callbacks.EarlyStopping(
                 "val_loss", patience=patience, start_from_epoch=early_stop_warmup
             ),
             model_saver,
         ]
-        self.model.compile(optimizer=optimizer, loss=loss, run_eagerly=False)
+        self.model.compile(optimizer=optimizer, loss=loss)
+        # Set up the summary writer
         self.model.fit(
             self.train_data["dataset"],
             validation_data=self.val_data["dataset"],
diff --git a/aam/data_handlers/__init__.py b/aam/data_handlers/__init__.py
index 57a7ded36556f1939af53b3d754f2def90981eb9..8650d764315c3585c9b9828632404ae76b2723a4 100644
--- a/aam/data_handlers/__init__.py
+++ b/aam/data_handlers/__init__.py
@@ -1,11 +1,13 @@
 from __future__ import annotations
 
+from .combined_generator import CombinedGenerator
 from .generator_dataset import GeneratorDataset
 from .sequence_dataset import SequenceDataset
 from .taxonomy_generator import TaxonomyGenerator
 from .unifrac_generator import UniFracGenerator
 
 __all__ = [
+    "CombinedGenerator",
     "SequenceDataset",
     "TaxonomyGenerator",
     "UniFracGenerator",
diff --git a/aam/data_handlers/combined_generator.py b/aam/data_handlers/combined_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..39b0becb45fcb3f4b0928f003edfb697d16d4796
--- /dev/null
+++ b/aam/data_handlers/combined_generator.py
@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+import os
+from typing import Iterable, Union
+
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+from biom import Table
+from biom.util import biom_open
+from skbio import DistanceMatrix
+from sklearn import preprocessing
+from unifrac import faith_pd, unweighted
+
+from aam.data_handlers.generator_dataset import GeneratorDataset, add_lock
+
+
+class CombinedGenerator(GeneratorDataset):
+    taxon_field = "Taxon"
+    levels = [f"Level {i}" for i in range(1, 8)]
+    taxonomy_fn = "taxonomy.tsv"
+
+    def __init__(self, tree_path: str, taxonomy, tax_level, **kwargs):
+        super().__init__(**kwargs)
+        self.tree_path = tree_path
+        if self.batch_size % 2 != 0:
+            raise Exception("Batch size must be a multiple of 2")
+
+        self.tax_level = f"Level {tax_level}"
+        self.taxonomy = taxonomy
+
+        self.encoder_target = self._create_encoder_target(self.rarefy_table)
+        self.encoder_dtype = (np.float32, np.float32, np.int32)
+        self.encoder_output_type = (
+            tf.TensorSpec(shape=[self.batch_size, self.batch_size], dtype=tf.float32),
+            tf.TensorSpec(shape=[self.batch_size, 1], dtype=tf.float32),
+            tf.TensorSpec(shape=[self.batch_size, None], dtype=tf.int32),
+        )
+
+    def _create_encoder_target(
+        self, table: Table
+    ) -> tuple[DistanceMatrix, pd.Series, pd.Series]:
+        if not hasattr(self, "tree_path"):
+            return None
+
+        random = np.random.random(1)[0]
+        temp_path = f"/tmp/temp{random}.biom"
+        with biom_open(temp_path, "w") as f:
+            table.to_hdf5(f, "aam")
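+        # compute both phylogenetic targets from the temporary BIOM export:
+        # pairwise unweighted UniFrac distances and per-sample Faith's PD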
+        uni = unweighted(temp_path, self.tree_path)
+        faith = faith_pd(temp_path, self.tree_path)
+        os.remove(temp_path)
+
+        obs = table.ids(axis="observation")
+        tax = self.taxonomy.loc[obs, "token"]
+        return (uni, faith, tax)
+
+    def _encoder_output(
+        self,
+        encoder_target: Iterable[object],
+        sample_ids: Iterable[str],
+        obs_ids: list[str],
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+        uni, faith, tax = encoder_target
+        uni = uni.filter(sample_ids).data
+        faith = faith.loc[sample_ids].to_numpy().reshape((-1, 1))
+
+        tax_tokens = [tax.loc[obs] for obs in obs_ids]
+        max_len = max([len(tokens) for tokens in tax_tokens])
+        tax_tokens = np.array([np.pad(t, [[0, max_len - len(t)]]) for t in tax_tokens])
+        return (uni, faith, tax_tokens)
+
+    @property
+    def taxonomy(self) -> pd.DataFrame:
+        return self._taxonomy
+
+    @taxonomy.setter
+    @add_lock
+    def taxonomy(self, taxonomy: Union[str, pd.DataFrame]):
+        if hasattr(self, "_taxon_set"):
+            raise Exception("Taxon already set")
+        if taxonomy is None:
+            self._taxonomy = taxonomy
+            return
+
+        if isinstance(taxonomy, str):
+            taxonomy = pd.read_csv(taxonomy, sep="\t", index_col=0)
+            if self.taxon_field not in taxonomy.columns:
+                raise Exception("Invalid taxonomy: missing 'Taxon' field")
+
+        taxonomy[self.levels] = taxonomy[self.taxon_field].str.split("; ", expand=True)
+        taxonomy = taxonomy.loc[self._table.ids(axis="observation")]
+
+        if self.tax_level not in self.levels:
+            raise Exception(f"Invalid level: {self.tax_level}")
+
+        level_index = self.levels.index(self.tax_level)
+        levels = self.levels[: level_index + 1]
+        taxonomy = taxonomy.loc[:, levels]
+        taxonomy.loc[:, "class"] = taxonomy.loc[:, levels].agg("; ".join, axis=1)
+
+        le = preprocessing.LabelEncoder()
+        taxonomy.loc[:, "token"] = le.fit_transform(taxonomy["class"])
+        taxonomy.loc[:, "token"] += 1  # shifts tokens to be between 1 and n
+        print(
+            "min token:", min(taxonomy["token"]), "max token:", max(taxonomy["token"])
+        )
+        self.num_tokens = (
+            max(taxonomy["token"]) + 1
+        )  # still need to add 1 to account for shift
+        self._taxonomy = taxonomy
+
+
+if __name__ == "__main__":
+    import numpy as np
+
+    from aam.data_handlers import CombinedGenerator
+
+    ug = CombinedGenerator(
+        table="/home/kalen/aam-research-exam/research-exam/healty-age-regression/agp-no-duplicate-host-bloom-filtered-5000-small-stool-only-very-small.biom",
+        tree_path="/home/kalen/aam-research-exam/research-exam/agp/data/agp-aligned.nwk",
+        taxonomy="/home/kalen/aam-research-exam/research-exam/healty-age-regression/taxonomy.tsv",
+        tax_level=7,
+        metadata="/home/kalen/aam-research-exam/research-exam/healty-age-regression/agp-healthy.txt",
+        metadata_column="host_age",
+        shift=0.0,
+        scale=100.0,
+        gen_new_tables=True,
+    )
+    data = ug.get_data()
+    for i, (x, y) in enumerate(data["dataset"]):
+        print(y)
+        break
+
+    # data = ug.get_data_by_id(ug.rarefy_tables.ids()[:16])
+    # for x, y in data["dataset"]:
+    #     print(y)
+    #     break
diff --git a/aam/data_handlers/generator_dataset.py b/aam/data_handlers/generator_dataset.py
index 9ff4f0da8bfe89878b3e470a7c6b813c3ea15dd2..1c776bc19146a3483fda1380df24ce70ca4318f5 100644
--- a/aam/data_handlers/generator_dataset.py
+++ b/aam/data_handlers/generator_dataset.py
@@ -55,7 +55,6 @@ class GeneratorDataset:
         table: Union[str, Table],
         metadata: Optional[Union[str, pd.DataFrame]] = None,
         metadata_column: Optional[str] = None,
-        is_categorical: Optional[bool] = None,
         shift: Optional[Union[str, float]] = None,
         scale: Union[str, float] = "minmax",
         max_token_per_sample: int = 1024,
@@ -65,12 +64,15 @@ class GeneratorDataset:
         gen_new_tables: bool = False,
         batch_size: int = 8,
         max_bp: int = 150,
+        is_16S: bool = True,
+        is_categorical: Optional[bool] = None,
     ):
         table, metadata
         self.table = table
 
         self.metadata_column = metadata_column
         self.is_categorical = is_categorical
+        self.include_sample_weight = is_categorical
         self.shift = shift
         self.scale = scale
         self.metadata = metadata
@@ -84,6 +86,7 @@ class GeneratorDataset:
 
         self.batch_size = batch_size
         self.max_bp = max_bp
+        self.is_16S = is_16S
 
         self.preprocessed_table = self.table
         self.obs_ids = self.preprocessed_table.ids(axis="observation")
@@ -142,7 +145,7 @@ class GeneratorDataset:
             return
         self._validate_dataframe(metadata)
         if isinstance(metadata, str):
-            metadata = pd.read_csv(metadata, sep="\t", index_col=0)
+            metadata = pd.read_csv(metadata, sep="\t", index_col=0, dtype={0: str})
 
         if self.metadata_column not in metadata.columns:
             raise Exception(f"Invalid metadata column {self.metadata_column}")
@@ -200,6 +203,9 @@ class GeneratorDataset:
         if not (y_data, pd.Series):
             raise Exception(f"Invalid y_data object: {type(y_data)}")
 
+        if not self.is_categorical:
+            return y_data.loc[sample_ids].to_numpy().reshape(-1, 1)
+
         return y_data.loc[sample_ids].to_numpy().reshape(-1, 1)
 
     def _create_table_data(
@@ -208,10 +214,12 @@ class GeneratorDataset:
         obs_ids = table.ids(axis="observation")
         sample_ids = table.ids()
 
-        obs_encodings = tf.cast(
-            tf.strings.unicode_decode(obs_ids, "UTF-8"), dtype=tf.int64
-        )
-        obs_encodings = self.lookup_table.lookup(obs_encodings).numpy()
+        obs_encodings = None
+        if self.is_16S:
+            obs_encodings = tf.cast(
+                tf.strings.unicode_decode(obs_ids, "UTF-8"), dtype=tf.int64
+            )
+            obs_encodings = self.lookup_table.lookup(obs_encodings).numpy()
 
         table_data, row, col, shape = self._table_data(table)
 
@@ -233,7 +241,11 @@ class GeneratorDataset:
             s_mask = row == s
             s_counts = counts[s_mask]
             s_obs_indices = col[s_mask]
-            s_tokens = obs_encodings[s_obs_indices]
+            if self.is_16S:
+                s_tokens = obs_encodings[s_obs_indices]
+            else:
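+                # non-16S (e.g. GOTU) tables use the observation index itself as the
+                # token, shifted by 1 so that 0 stays free for padding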
+                s_tokens = np.reshape(s_obs_indices, newshape=(-1, 1)) + 1
+
             sorted_order = np.argsort(s_counts)
             sorted_order = sorted_order[::-1]
             s_counts = s_counts[sorted_order].reshape(-1, 1)
@@ -248,7 +260,7 @@ class GeneratorDataset:
 
         if s_max_token > self.max_token_per_sample:
             print(f"\tskipping group due to exceeding token limit {s_max_token}...")
-            return None, None, None, None
+            return None, None, None, None, None
 
         s_ids = [sample_ids[s] for s in samples]
 
@@ -260,7 +272,7 @@ class GeneratorDataset:
             encoder_target = self.encoder_target
         encoder_output = self._encoder_output(encoder_target, s_ids, s_obj_ids)
 
-        return s_counts, s_tokens, y_output, encoder_output
+        return s_counts, s_tokens, y_output, encoder_output, s_obj_ids
 
     def _epoch_complete(self, processed):
         if processed < self.steps_per_epoch:
@@ -298,7 +310,7 @@ class GeneratorDataset:
 
         return table_data, y_data, encoder_target, sample_indices
 
-    def _create_epoch_generator(self):
+    def _create_epoch_generator(self, include_seq_id):
         def generator():
             processed = 0
             table_data = self.table_data
@@ -322,7 +334,9 @@ class GeneratorDataset:
                     )
 
                 while not self._epoch_complete(processed):
-                    counts, tokens, y_output, encoder_out = sample_data(minibatch)
+                    counts, tokens, y_output, encoder_out, ob_ids = sample_data(
+                        minibatch
+                    )
 
                     if counts is not None:
                         max_len = max([len(c) for c in counts])
@@ -332,6 +346,12 @@ class GeneratorDataset:
                         padded_tokens = np.array(
                             [np.pad(t, [[0, max_len - len(t)], [0, 0]]) for t in tokens]
                         )
+                        padded_ob_ids = np.array(
+                            [
+                                np.pad(o, [[0, max_len - len(o)]], constant_values="")
+                                for o in ob_ids
+                            ]
+                        )
 
                         processed += 1
                         table_output = (
@@ -344,14 +364,29 @@ class GeneratorDataset:
                             output = y_output.astype(np.float32)
 
                         if encoder_out is not None:
-                            if output is None:
-                                output = encoder_out.astype(self.encoder_dtype)
+                            if isinstance(encoder_out, tuple):
+                                encoder_out = tuple(
+                                    [
+                                        o.astype(t)
+                                        for o, t in zip(encoder_out, self.encoder_dtype)
+                                    ]
+                                )
                             else:
+                                encoder_out = encoder_out.astype(self.encoder_dtype)
+
+                            if output is not None:
                                 output = (
                                     output,
-                                    encoder_out.astype(self.encoder_dtype),
+                                    encoder_out,
                                 )
+                            else:
+                                output = encoder_out
 
+                        if include_seq_id:
+                            output = (
+                                *output,
+                                padded_ob_ids,
+                            )
                         if output is not None:
                             yield (table_output, output)
                         else:
@@ -386,8 +421,8 @@ class GeneratorDataset:
 
         return generator
 
-    def get_data(self):
-        generator = self._create_epoch_generator()
+    def get_data(self, include_seq_id=False):
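+        # when include_seq_id is True, each batch additionally yields the padded
+        # observation (ASV) ids alongside the regular targets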
+        generator = self._create_epoch_generator(include_seq_id)
         output_sig = (
             tf.TensorSpec(shape=[self.batch_size, None, self.max_bp], dtype=tf.int32),
             tf.TensorSpec(shape=[self.batch_size, None, 1], dtype=tf.int32),
@@ -404,7 +439,23 @@ class GeneratorDataset:
                 y_output_sig = (y_output_sig, self.encoder_output_type)
 
         if y_output_sig is not None:
+            if include_seq_id:
+                y_output_sig = (
+                    *y_output_sig,
+                    tf.TensorSpec(
+                        shape=(self.batch_size, None), dtype=tf.string, name=None
+                    ),
+                )
             output_sig = (output_sig, y_output_sig)
+        class_weights = None
+        if self.is_categorical:
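+            # weight each class by the inverse square root of its frequency;
+            # class 0 is forced to zero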
+            counts = self._metadata.value_counts().to_dict()
+            counts[0.0] = 0
+            class_weights = [0] * len(counts)
+            for k, v in counts.items():
+                class_weights[int(k)] = 1 / np.sqrt(v)
+            class_weights[0] = 0
+        print(class_weights)
         dataset: tf.data.Dataset = tf.data.Dataset.from_generator(
             generator,
             output_signature=output_sig,
@@ -417,6 +468,7 @@ class GeneratorDataset:
             "scale": self.scale,
             "size": self.size,
             "steps_pre_epoch": self.steps_per_epoch,
+            "class_weights": class_weights,
         }
         return data_obj
 
diff --git a/aam/data_handlers/tests/generator_dataset_test.py b/aam/data_handlers/tests/generator_dataset_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..33a2b509294294b3c86d514d087d4a59e8af52b9
--- /dev/null
+++ b/aam/data_handlers/tests/generator_dataset_test.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+from aam.data_handlers.generator_dataset import GeneratorDataset
+
+
+def test_generator_dataset():
+    """Generator should output a single
+    tensor for y. Both iteration should also
+    return same value.
+    """
+
+    table = "/home/kalen/aam-research-exam/research-exam/healty-age-regression/agp-no-duplicate-host-bloom-filtered-5000-small-stool-only-very-small.biom"
+    metadata = "/home/kalen/aam-research-exam/research-exam/healty-age-regression/agp-healthy.txt"
+    generator = GeneratorDataset(
+        table=table,
+        metadata=metadata,
+        metadata_column="host_age",
+        gen_new_tables=False,
+        shuffle=False,
+        epochs=1,
+    )
+
+    table = generator.table
+    metadata = generator.metadata
+    assert table.shape[1] == metadata.shape[0]
+    assert np.all(np.equal(table.ids(), metadata.index))
+
+    batch_size = 8
+    s_ids = table.ids()[: 5 * batch_size]
+    y_true = metadata.loc[s_ids]
+
+    data1 = generator.get_data()
+    data2 = generator.get_data()
+
+    data_ys = []
+    for ((token1, count1), y1), ((token2, count2), y2) in zip(
+        data1["dataset"].take(5), data2["dataset"].take(5)
+    ):
+        assert np.all(np.equal(token1.numpy(), token2.numpy()))
+        assert np.all(np.equal(count1.numpy(), count2.numpy()))
+        assert np.all(np.equal(y1.numpy(), y2.numpy()))
+        data_ys.append(y1)
+
+    data_ys = np.concatenate(data_ys)
+    assert np.all(np.equal(y_true.to_numpy().reshape(-1), data_ys.reshape(-1)))
diff --git a/aam/data_handlers/unifrac_generator.py b/aam/data_handlers/unifrac_generator.py
index ef4e48ac74a55604dd705c325348fa9d33569f7a..3e6dc2a840dbb5544e54a81d286a45c5a5572961 100644
--- a/aam/data_handlers/unifrac_generator.py
+++ b/aam/data_handlers/unifrac_generator.py
@@ -8,22 +8,29 @@ import tensorflow as tf
 from biom import Table
 from biom.util import biom_open
 from skbio import DistanceMatrix
-from unifrac import unweighted
+from unifrac import faith_pd, unweighted
 
 from aam.data_handlers.generator_dataset import GeneratorDataset
 
 
 class UniFracGenerator(GeneratorDataset):
-    def __init__(self, tree_path: str, **kwargs):
+    def __init__(self, tree_path: str, unifrac_metric="unifrac", **kwargs):
         super().__init__(**kwargs)
         self.tree_path = tree_path
+        self.unifrac_metric = unifrac_metric
+
         if self.batch_size % 2 != 0:
             raise Exception("Batch size must be multiple of 2")
         self.encoder_target = self._create_encoder_target(self.rarefy_table)
         self.encoder_dtype = np.float32
-        self.encoder_output_type = tf.TensorSpec(
-            shape=[self.batch_size, self.batch_size], dtype=tf.float32
-        )
+        if self.unifrac_metric == "unifrac":
+            self.encoder_output_type = tf.TensorSpec(
+                shape=[self.batch_size, self.batch_size], dtype=tf.float32
+            )
+        else:
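+            # Faith's PD is a per-sample scalar, so the target tensor is [batch_size, 1]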
+            self.encoder_output_type = tf.TensorSpec(
+                shape=[self.batch_size, 1], dtype=tf.float32
+            )
 
     def _create_encoder_target(self, table: Table) -> DistanceMatrix:
         if not hasattr(self, "tree_path"):
@@ -33,7 +40,10 @@ class UniFracGenerator(GeneratorDataset):
         temp_path = f"/tmp/temp{random}.biom"
         with biom_open(temp_path, "w") as f:
             table.to_hdf5(f, "aam")
-        distances = unweighted(temp_path, self.tree_path)
+        if self.unifrac_metric == "unifrac":
+            distances = unweighted(temp_path, self.tree_path)
+        else:
+            distances = faith_pd(temp_path, self.tree_path)
         os.remove(temp_path)
         return distances
 
@@ -43,10 +53,15 @@ class UniFracGenerator(GeneratorDataset):
         sample_ids: Iterable[str],
         ob_ids: list[str],
     ) -> np.ndarray[float]:
-        return encoder_target.filter(sample_ids).data
+        if self.unifrac_metric == "unifrac":
+            return encoder_target.filter(sample_ids).data
+        else:
+            return encoder_target.loc[sample_ids].to_numpy().reshape((-1, 1))
 
 
 if __name__ == "__main__":
+    import numpy as np
+
     from aam.data_handlers import UniFracGenerator
 
     ug = UniFracGenerator(
@@ -60,7 +75,7 @@ if __name__ == "__main__":
     )
     data = ug.get_data()
     for i, (x, y) in enumerate(data["dataset"]):
-        print(y)
+        print(np.mean(y[1]))
         break
 
     # data = ug.get_data_by_id(ug.rarefy_tables.ids()[:16])
diff --git a/aam/layers.py b/aam/layers.py
index 36848394ee2c21d46228fb5097054262fc23f70d..94d92053f82d99fa986922fe79cda07114c30ba5 100644
--- a/aam/layers.py
+++ b/aam/layers.py
@@ -38,6 +38,7 @@ class ASVEncoder(tf.keras.layers.Layer):
         dropout_rate,
         intermediate_ff,
         intermediate_activation="gelu",
+        add_token=True,
         **kwargs,
     ):
         super(ASVEncoder, self).__init__(**kwargs)
@@ -47,13 +48,23 @@ class ASVEncoder(tf.keras.layers.Layer):
         self.dropout_rate = dropout_rate
         self.intermediate_ff = intermediate_ff
         self.intermediate_activation = intermediate_activation
+        self.add_token = add_token
         self.base_tokens = 6
         self.num_tokens = self.base_tokens * self.max_bp + 2
+
+        self.asv_token = self.num_tokens - 1
+        self.nucleotide_position = tf.range(
+            0, self.base_tokens * self.max_bp, self.base_tokens, dtype=tf.int32
+        )
+
+    def build(self, input_shape):
         self.emb_layer = tf.keras.layers.Embedding(
             self.num_tokens,
-            32,
+            128,
             input_length=self.max_bp,
-            embeddings_initializer=tf.keras.initializers.GlorotNormal(),
+            embeddings_initializer=tf.keras.initializers.RandomNormal(
+                mean=0, stddev=128**0.5
+            ),
         )
 
         self.avs_attention = NucleotideAttention(
@@ -61,20 +72,18 @@ class ASVEncoder(tf.keras.layers.Layer):
             num_heads=self.attention_heads,
             num_layers=self.attention_layers,
             dropout=self.dropout_rate,
-            intermediate_ff=intermediate_ff,
+            intermediate_ff=self.intermediate_ff,
             intermediate_activation=self.intermediate_activation,
         )
-        self.asv_token = self.num_tokens - 1
-        self.nucleotide_position = tf.range(
-            0, self.base_tokens * self.max_bp, self.base_tokens, dtype=tf.int32
-        )
+        super(ASVEncoder, self).build(input_shape)
 
     def call(self, inputs, training=False):
         seq = inputs
         seq = seq + self.nucleotide_position
 
         # add <ASV> token
-        seq = tf.pad(seq, [[0, 0], [0, 0], [0, 1]], constant_values=self.asv_token)
+        if self.add_token:
+            seq = tf.pad(seq, [[0, 0], [0, 0], [0, 1]], constant_values=self.asv_token)
 
         output = self.emb_layer(seq)
         output = self.avs_attention(output, training=training)
@@ -97,6 +106,7 @@ class ASVEncoder(tf.keras.layers.Layer):
                 "dropout_rate": self.dropout_rate,
                 "intermediate_ff": self.intermediate_ff,
                 "intermediate_activation": self.intermediate_activation,
+                "add_token": self.add_token,
             }
         )
         return config
@@ -196,8 +206,13 @@ class NucleotideAttention(tf.keras.layers.Layer):
         self.epsilon = 1e-6
         self.intermediate_ff = intermediate_ff
         self.intermediate_activation = intermediate_activation
+
+    def build(self, input_shape):
         self.pos_emb = tfm.nlp.layers.PositionEmbedding(
-            self.max_bp + 1, seq_axis=2, name="nuc_pos"
+            self.max_bp + 1,
+            seq_axis=2,
+            initializer=tf.keras.initializers.RandomNormal(mean=0, stddev=128**0.5),
+            name="nuc_pos",
         )
         self.attention_layers = []
         for i in range(self.num_layers):
@@ -206,7 +221,7 @@ class NucleotideAttention(tf.keras.layers.Layer):
                     num_heads=self.num_heads,
                     dropout=self.dropout,
                     epsilon=self.epsilon,
-                    intermediate_ff=intermediate_ff,
+                    intermediate_ff=self.intermediate_ff,
                     intermediate_activation=self.intermediate_activation,
                     name=("layer_%d" % i),
                 )
@@ -214,15 +229,17 @@ class NucleotideAttention(tf.keras.layers.Layer):
         self.output_normalization = tf.keras.layers.LayerNormalization(
             epsilon=self.epsilon, dtype=tf.float32
         )
+        super(NucleotideAttention, self).build(input_shape)
 
     def call(self, attention_input, attention_mask=None, training=False):
         attention_input = attention_input + self.pos_emb(attention_input)
+        attention_input = attention_input * (9 * 3) ** (-0.25)
         for layer_idx in range(self.num_layers):
             attention_input = self.attention_layers[layer_idx](
                 attention_input, training=training
             )
-        output = self.output_normalization(attention_input)
-        return output
+        # output = self.output_normalization(attention_input)
+        return attention_input
 
     def get_config(self):
         config = super(NucleotideAttention, self).get_config()
@@ -255,6 +272,14 @@ class NucleotideAttentionBlock(tf.keras.layers.Layer):
         self.dropout = dropout
         self.epsilon = epsilon
         self.intermediate_ff = intermediate_ff
+        self.intermediate_activation = intermediate_activation
+
+    def build(self, input_shape):
+        self._shape = input_shape
+        self.nucleotides = input_shape[2]
+        self.hidden_dim = input_shape[3]
+        self.head_size = tf.cast(self.hidden_dim / self.num_heads, dtype=tf.int32)
+
         self.attention_norm = tf.keras.layers.LayerNormalization(
             epsilon=self.epsilon, dtype=tf.float32
         )
@@ -263,19 +288,23 @@ class NucleotideAttentionBlock(tf.keras.layers.Layer):
         self.ff_norm = tf.keras.layers.LayerNormalization(
             epsilon=self.epsilon, dtype=tf.float32
         )
-        self.intermediate_activation = intermediate_activation
-
-    def build(self, input_shape):
-        self._shape = input_shape
-        self.nucleotides = input_shape[2]
-        self.hidden_dim = input_shape[3]
-        self.head_size = tf.cast(self.hidden_dim / self.num_heads, dtype=tf.int32)
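+        # learnable scalar (initialized to zero) that gates the attention and
+        # feed-forward residual branches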
+        self.nuc_alpha = self.add_weight(
+            name="nuc_alpha",
+            initializer=tf.keras.initializers.Zeros(),
+            trainable=True,
+            dtype=tf.float32,
+        )
 
         wi_shape = [1, 1, self.num_heads, self.hidden_dim, self.head_size]
         self.w_qi = self.add_weight("w_qi", wi_shape, trainable=True, dtype=tf.float32)
         self.w_ki = self.add_weight("w_ki", wi_shape, trainable=True, dtype=tf.float32)
         self.w_vi = self.add_weight("w_kv", wi_shape, trainable=True, dtype=tf.float32)
-        self.o_dense = tf.keras.layers.Dense(self.hidden_dim, use_bias=False)
+
+        wo_shape = [1, 1, self.hidden_dim, self.hidden_dim]
+        self.o_dense = self.add_weight(
+            "w_o", wo_shape, trainable=True, dtype=tf.float32
+        )
+        # self.o_dense = tf.keras.layers.Dense(self.hidden_dim, use_bias=False)
 
         self.scale_dot_factor = tf.math.sqrt(
             tf.cast(self.head_size, dtype=self.compute_dtype)
@@ -303,7 +332,7 @@ class NucleotideAttentionBlock(tf.keras.layers.Layer):
     def scaled_dot_attention(self, attention_input):
         wq_tensor = self.compute_wi(attention_input, self.w_qi)
         wk_tensor = self.compute_wi(attention_input, self.w_ki)
-        wv_tensor = self.compute_wi(attention_input, self.w_vi)
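+        # the value projection is rescaled by (0.67 * 3) ** -0.25, and the output
+        # projection below uses the same factor; assumption: a DeepNet-style scaling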
+        wv_tensor = self.compute_wi(attention_input, self.w_vi * (0.67 * 3) ** -0.25)
 
         # (multihead) scaled dot product attention sublayer
         # [B, A, H, N, S] => [B, A, H, N, N]
@@ -332,20 +361,23 @@ class NucleotideAttentionBlock(tf.keras.layers.Layer):
             attention_output,
             shape=[batch_size, num_asv, self.nucleotides, self.hidden_dim],
         )
-        attention_output = self.o_dense(attention_output)
+        attention_output = tf.matmul(
+            attention_output, self.o_dense * (0.67 * 3) ** (-0.25)
+        )
+        # attention_output = self.o_dense(attention_output)
         attention_output = tf.ensure_shape(attention_output, self._shape)
         return attention_output
 
     def call(self, attention_input, training=False):
-        # scaled dot product attention sublayer
-        attention_input = self.attention_norm(attention_input)
+        # # scaled dot product attention sublayer
+        # attention_input = self.attention_norm(attention_input)
 
         # cast for mixed precision
         _attention_input = tf.cast(attention_input, dtype=self.compute_dtype)
 
         # cast back to float32
         _attention_output = self.scaled_dot_attention(_attention_input)
-        attention_output = tf.cast(_attention_output, dtype=tf.float32)
+        attention_output = tf.cast(_attention_output, dtype=tf.float32) * self.nuc_alpha
 
         # residual connection
         attention_output = tf.add(attention_input, attention_output)
@@ -353,13 +385,14 @@ class NucleotideAttentionBlock(tf.keras.layers.Layer):
         attention_output = self.attention_dropout(attention_output, training=training)
 
         # cast for mixed precision
-        ff_input = self.ff_norm(attention_output)
+        # ff_input = self.ff_norm(attention_output)
+        ff_input = attention_output  # self.ff_norm(attention_output)
         _ff_input = tf.cast(ff_input, dtype=self.compute_dtype)
         _ff_output = self.inter_ff(_ff_input)
         _ff_output = self.outer_ff(_ff_output)
 
-        # cast back to float32
-        ff_output = tf.cast(_ff_output, dtype=tf.float32)
+        # cast back to float32, residual connection
+        ff_output = tf.cast(_ff_output, dtype=tf.float32) * self.nuc_alpha
         ff_output = tf.add(ff_input, ff_output)
 
         ff_output = tf.ensure_shape(ff_output, self._shape)
diff --git a/aam/losses.py b/aam/losses.py
index 575a370877cda3dc89939efc24cfea5593f35b66..68318b275434a35ba6470917aa54fceaef57eafa 100644
--- a/aam/losses.py
+++ b/aam/losses.py
@@ -61,7 +61,7 @@ def _pairwise_distances(embeddings, squared=False):
         # (ex: on the diagonal)
         # we need to add a small epsilon where distances == 0.0
         mask = tf.cast(tf.equal(distances, 0.0), tf.float32)
-        distances = distances + mask * 1e-07
+        distances = distances + mask * 1e-12
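+        # 1e-12 is still large enough to keep sqrt() away from an exact 0.0 (and its
+        # infinite gradient) while barely perturbing the true distances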
 
         distances = tf.sqrt(distances)
 
@@ -79,7 +79,6 @@ class PairwiseLoss(tf.keras.losses.Loss):
     def call(self, y_true, y_pred):
         y_pred_dist = _pairwise_distances(y_pred, squared=False)
         differences = tf.math.square(y_pred_dist - y_true)
-        differences = tf.linalg.band_part(differences, 0, -1)
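+        # both triangles of the symmetric difference matrix contribute to the loss
+        # (no upper-triangular masking)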
         return differences
 
 
diff --git a/aam/models/__init__.py b/aam/models/__init__.py
index 2b5c6973a26dda5d88485624abf36df0d1925b48..e5da9c11cddd75f216ba25ab7cec8e42eb642ac3 100644
--- a/aam/models/__init__.py
+++ b/aam/models/__init__.py
@@ -1,12 +1,14 @@
 from __future__ import annotations
 
 from .base_sequence_encoder import BaseSequenceEncoder
+from .sequence_encoder import SequenceEncoder
 from .sequence_regressor import SequenceRegressor
 from .taxonomy_encoder import TaxonomyEncoder
 from .unifrac_encoder import UniFracEncoder
 
 __all__ = [
     "BaseSequenceEncoder",
+    "SequenceEncoder",
     "SequenceRegressor",
     "TaxonomyEncoder",
     "UniFracEncoder",
diff --git a/aam/models/base_sequence_encoder.py b/aam/models/base_sequence_encoder.py
index 9cd901a9e73d951118dc6eb92e8b81f7bd4469fb..2af54efc1ad8bf5ec24204020f180ec7759d582b 100644
--- a/aam/models/base_sequence_encoder.py
+++ b/aam/models/base_sequence_encoder.py
@@ -7,7 +7,7 @@ from aam.layers import (
     ASVEncoder,
 )
 from aam.models.transformers import TransformerEncoder
-from aam.utils import float_mask, masked_loss
+from aam.utils import float_mask
 
 
 @tf.keras.saving.register_keras_serializable(package="BaseSequenceEncoder")
@@ -25,6 +25,9 @@ class BaseSequenceEncoder(tf.keras.layers.Layer):
         nuc_attention_layers: int = 4,
         nuc_intermediate_size: int = 1024,
         intermediate_activation: str = "gelu",
+        is_16S: bool = True,
+        vocab_size: int = 6,
+        add_token: bool = True,
         **kwargs,
     ):
         super(BaseSequenceEncoder, self).__init__(**kwargs)
@@ -39,29 +42,45 @@ class BaseSequenceEncoder(tf.keras.layers.Layer):
         self.nuc_attention_layers = nuc_attention_layers
         self.nuc_intermediate_size = nuc_intermediate_size
         self.intermediate_activation = intermediate_activation
-        self.nuc_loss = tf.keras.losses.SparseCategoricalCrossentropy(
-            ignore_class=0, from_logits=False, reduction="none"
-        )
-        self.nuc_entropy = tf.keras.metrics.Mean()
+        self.is_16S = is_16S
+        self.vocab_size = vocab_size
+        self.add_token = add_token
 
         # layers used in model
-        self.asv_encoder = ASVEncoder(
-            max_bp,
-            nuc_attention_heads,
-            nuc_attention_layers,
-            0.0,
-            nuc_intermediate_size,
-            intermediate_activation=self.intermediate_activation,
-            name="asv_encoder",
-        )
-        self.nuc_logits = tf.keras.layers.Dense(
-            6, use_bias=False, name="nuc_logits", dtype=tf.float32, activation="softmax"
+        if self.is_16S:
+            self.asv_encoder = ASVEncoder(
+                max_bp,
+                nuc_attention_heads,
+                nuc_attention_layers,
+                0.0,
+                nuc_intermediate_size,
+                intermediate_activation=self.intermediate_activation,
+                add_token=self.add_token,
+                name="asv_encoder",
+            )
+        else:
+            self.asv_embeddings = tf.keras.layers.Embedding(
+                self.vocab_size,
+                output_dim=self.embedding_dim,
+                embeddings_initializer=tf.keras.initializers.RandomNormal(
+                    mean=0, stddev=self.embedding_dim**0.5
+                ),
+            )
+            self.asv_encoder = TransformerEncoder(
+                num_layers=self.sample_attention_layers,
+                num_attention_heads=self.sample_attention_heads,
+                intermediate_size=self.sample_intermediate_size,
+                activation=self.intermediate_activation,
+                dropout_rate=self.dropout_rate,
+            )
+
+        self.asv_pos = tfm.nlp.layers.PositionEmbedding(
+            self.token_limit + 5,
+            initializer=tf.keras.initializers.RandomNormal(
+                mean=0, stddev=self.embedding_dim**0.5
+            ),
         )
 
-        self.asv_scale = tf.keras.layers.Dense(self.embedding_dim, use_bias=False)
-        self.asv_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
-        self.asv_pos = tfm.nlp.layers.PositionEmbedding(self.token_limit + 5)
-
         self.sample_encoder = TransformerEncoder(
             num_layers=self.sample_attention_layers,
             num_attention_heads=self.sample_attention_heads,
@@ -69,25 +88,14 @@ class BaseSequenceEncoder(tf.keras.layers.Layer):
             activation=self.intermediate_activation,
             dropout_rate=self.dropout_rate,
         )
-        self.sample_token = self.add_weight(
-            "sample_token",
-            [1, 1, self.embedding_dim],
-            dtype=tf.float32,
-            initializer=tf.keras.initializers.Zeros(),
-            trainable=True,
-        )
-        self._base_alpha = self.add_weight(
-            name="base_alpha",
-            initializer=tf.keras.initializers.Zeros(),
-            trainable=True,
-            dtype=tf.float32,
-        )
-
-        self.linear_activation = tf.keras.layers.Activation("linear", dtype=tf.float32)
-
-    @masked_loss(sparse_cat=True)
-    def _compute_nuc_loss(self, tokens: tf.Tensor, pred_tokens: tf.Tensor) -> tf.Tensor:
-        return self.nuc_loss(tokens, pred_tokens)
+        if self.add_token:
+            self.sample_token = self.add_weight(
+                "sample_token",
+                [1, 1, self.embedding_dim],
+                dtype=tf.float32,
+                initializer="glorot_uniform",
+                trainable=True,
+            )
 
     def _add_sample_token(self, tensor: tf.Tensor) -> tf.Tensor:
         # add <SAMPLE> token embedding
@@ -102,45 +110,108 @@ class BaseSequenceEncoder(tf.keras.layers.Layer):
         return embeddings
 
     def _split_asvs(self, embeddings):
-        nuc_embeddings = embeddings[:, :, :-1, :]
-        nucleotides = self.nuc_logits(nuc_embeddings)
+        asv_embeddings = embeddings
+        if self.is_16S:
+            if self.add_token:
+                asv_embeddings = asv_embeddings[:, :, 0, :]
+            else:
+                asv_embeddings = tf.reduce_mean(asv_embeddings, axis=2)
+        else:
+            asv_embeddings = asv_embeddings[:, :, 0, :]
 
-        asv_embeddings = embeddings[:, :, 0, :]
-        asv_embeddings = self.asv_norm(asv_embeddings)
         asv_embeddings = asv_embeddings + self.asv_pos(asv_embeddings)
-
-        return asv_embeddings, nucleotides
+        return asv_embeddings
 
     def call(
+        self, inputs: tf.Tensor, random_mask: bool = None, training: bool = False
+    ) -> tuple[tf.Tensor, tf.Tensor]:
+        # need to cast inputs to int32 to avoid error
+        # because keras converts all inputs
+        # to float when calling build()
+        asv_input = tf.cast(inputs, dtype=tf.int32)
+        asv_mask = float_mask(tf.reduce_sum(inputs, axis=-1, keepdims=True))
+
+        if training and random_mask is not None:
+            asv_input = asv_input * tf.cast(random_mask, dtype=tf.int32)
+
+        if self.is_16S:
+            embeddings = self.asv_encoder(asv_input, training=training)
+        else:
+            embeddings = self.asv_embeddings(asv_input)
+        asv_embeddings = self._split_asvs(embeddings)
+
+        if self.add_token:
+            asv_mask = tf.pad(asv_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+            sample_embeddings = self._add_sample_token(asv_embeddings)
+        else:
+            sample_embeddings = asv_embeddings
+        return sample_embeddings
+
+    # def base_embeddings(
+    #     self, inputs: tf.Tensor, training: bool = False
+    # ) -> tuple[tf.Tensor, tf.Tensor]:
+    #     # need to cast inputs to int32 to avoid error
+    #     # because keras converts all inputs
+    #     # to float when calling build()
+    #     asv_input = tf.cast(inputs, dtype=tf.int32)
+
+    #     embeddings = self.asv_encoder(asv_input, training=training)
+    #     embeddings = self.asv_scale(embeddings)
+    #     asv_embeddings, nucleotides = self._split_asvs(embeddings)
+
+    #     asv_mask = float_mask(tf.reduce_sum(inputs, axis=-1, keepdims=True))
+    #     padded_asv_mask = tf.pad(asv_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+
+    #     # padded embeddings are the skip connection
+    #     # normal asv embeddings continue through next block
+    #     padded_asv_embeddings = tf.pad(
+    #         asv_embeddings, [[0, 0], [1, 0], [0, 0]], constant_values=0
+    #     )
+
+    #     sample_gated_embeddings = self._add_sample_token(asv_embeddings)
+    #     sample_gated_embeddings = self.sample_encoder(
+    #         sample_gated_embeddings, mask=padded_asv_mask, training=training
+    #     )
+
+    #     sample_embeddings = (
+    #         padded_asv_embeddings + sample_gated_embeddings * self._base_alpha
+    #     )
+    #     return sample_embeddings
+
+    def get_asv_embeddings(
         self, inputs: tf.Tensor, training: bool = False
     ) -> tuple[tf.Tensor, tf.Tensor]:
+        print("holyfucking shit")
         # need to cast inputs to int32 to avoid error
         # because keras converts all inputs
         # to float when calling build()
         asv_input = tf.cast(inputs, dtype=tf.int32)
+        asv_mask = float_mask(tf.reduce_sum(inputs, axis=-1, keepdims=True))
 
-        embeddings = self.asv_encoder(asv_input, training=training)
-        embeddings = self.asv_scale(embeddings)
+        if self.is_16S:
+            embeddings = self.asv_encoder(asv_input, training=training)
+        else:
+            embeddings = self.asv_embeddings(asv_input)
-        asv_embeddings, nucleotides = self._split_asvs(embeddings)
+        asv_embeddings = self._split_asvs(embeddings)
+        return asv_embeddings
 
+    def asv_gradient(
+        self, inputs: tf.Tensor, asv_embeddings
+    ) -> tuple[tf.Tensor, tf.Tensor]:
         asv_mask = float_mask(tf.reduce_sum(inputs, axis=-1, keepdims=True))
-        padded_asv_mask = tf.pad(asv_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
 
-        # padded embeddings are the skip connection
-        # normal asv embeddings continue through next block
-        padded_asv_embeddings = tf.pad(
-            asv_embeddings, [[0, 0], [1, 0], [0, 0]], constant_values=0
-        )
+        if self.add_token:
+            asv_mask = tf.pad(asv_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+            sample_embeddings = self._add_sample_token(asv_embeddings)
+        else:
+            sample_embeddings = asv_embeddings
 
-        sample_gated_embeddings = self._add_sample_token(asv_embeddings)
         sample_gated_embeddings = self.sample_encoder(
-            sample_gated_embeddings, mask=padded_asv_mask, training=training
-        )
-
-        sample_embeddings = (
-            padded_asv_embeddings + sample_gated_embeddings * self._base_alpha
+            sample_embeddings, mask=asv_mask, training=False
         )
-        return sample_embeddings, nucleotides
+        sample_embeddings = sample_embeddings + sample_gated_embeddings
+        return sample_embeddings
 
     def get_config(self):
         config = super(BaseSequenceEncoder, self).get_config()
@@ -156,6 +227,9 @@ class BaseSequenceEncoder(tf.keras.layers.Layer):
                 "nuc_attention_heads": self.nuc_attention_heads,
                 "nuc_attention_layers": self.nuc_attention_layers,
                 "nuc_intermediate_size": self.nuc_intermediate_size,
+                "is_16S": self.is_16S,
+                "vocab_size": self.vocab_size,
+                "add_token": self.add_token,
             }
         )
         return config
diff --git a/aam/models/conv_block.py b/aam/models/conv_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..915879c64bbf16b11f9024345923c640d4395b0b
--- /dev/null
+++ b/aam/models/conv_block.py
@@ -0,0 +1,124 @@
+import tensorflow as tf
+
+
+@tf.keras.saving.register_keras_serializable(package="ConvBlock")
+class ConvBlock(tf.keras.layers.Layer):
+    def __init__(
+        self,
+        activation="gelu",
+        use_bias=True,
+        norm_epsilon=1e-6,
+        dilation=1,
+        **kwargs,
+    ):
+        super(ConvBlock, self).__init__(**kwargs)
+        self._activation = activation
+        self._use_bias = use_bias
+        self._norm_epsilon = norm_epsilon
+        self._dilation = dilation
+
+        self.conv_norm = tf.keras.layers.LayerNormalization(epsilon=self._norm_epsilon)
+
+        self.ff_norm = tf.keras.layers.LayerNormalization(epsilon=self._norm_epsilon)
+
+    def build(self, input_shape):
+        seq_dim = input_shape[1]
+        emb_dim = input_shape[2]
+        self.conv1d = tf.keras.layers.Conv1D(
+            seq_dim,
+            kernel_size=3,
+            strides=1,
+            padding="same",
+            activation=self._activation,
+            dilation_rate=self._dilation,
+        )
+        self.ff = tf.keras.layers.Dense(
+            emb_dim, use_bias=self._use_bias, activation=self._activation
+        )
+
+    def get_config(self):
+        config = {
+            "activation": self._activation,
+            "use_bias": self._use_bias,
+            "norm_epsilon": self._norm_epsilon,
+            "dilation": self._dilation,
+        }
+        base_config = super(ConvBlock, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    def call(self, encoder_inputs, training=False):
+        """Return the output of the encoder.
+
+        Args:
+          encoder_inputs: A tensor with shape `(batch_size, input_length,
+            hidden_size)`.
+          attention_mask: A mask for the encoder self-attention layer with shape
+            `(batch_size, input_length, input_length)`.
+
+        Returns:
+          Output of encoder which is a `float32` tensor with shape
+            `(batch_size, input_length, hidden_size)`.
+        """
+        conv_inputs = tf.transpose(encoder_inputs, perm=[0, 2, 1])
+        conv_outputs = tf.transpose(self.conv1d(conv_inputs), perm=[0, 2, 1])
+        conv_outputs = encoder_inputs + self.conv_norm(conv_outputs)
+
+        ff_outputs = self.ff(conv_outputs)
+        output = conv_outputs + self.ff_norm(ff_outputs)
+        return output
+
+
+@tf.keras.saving.register_keras_serializable(package="ConvModule")
+class ConvModule(tf.keras.layers.Layer):
+    def __init__(
+        self,
+        layers,
+        activation="gelu",
+        use_bias=True,
+        norm_epsilon=1e-6,
+        **kwargs,
+    ):
+        super(ConvModule, self).__init__(**kwargs)
+        self.layers = layers
+        self._activation = activation
+        self._use_bias = use_bias
+        self._norm_epsilon = norm_epsilon
+
+        self.conv_blocks = []
+        dilation = 1
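+        # dilation doubles with each block (1, 2, 4, ...), growing the receptive
+        # field exponentially with depth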
+        for i in range(self.layers):
+            self.conv_blocks.append(
+                ConvBlock(
+                    self._activation,
+                    self._use_bias,
+                    self._norm_epsilon,
+                    dilation=(2**i) * dilation,
+                )
+            )
+
+    def get_config(self):
+        config = {
+            "layers": self.layers,
+            "activation": self._activation,
+            "use_bias": self._use_bias,
+            "norm_epsilon": self._norm_epsilon,
+        }
+        base_config = super(ConvModule, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    def call(self, encoder_inputs, training=False):
+        """Return the output of the encoder.
+
+        Args:
+          encoder_inputs: A tensor with shape `(batch_size, input_length,
+            hidden_size)`.
+          attention_mask: A mask for the encoder self-attention layer with shape
+            `(batch_size, input_length, input_length)`.
+
+        Returns:
+          Output of encoder which is a `float32` tensor with shape
+            `(batch_size, input_length, hidden_size)`.
+        """
+        for layer_idx in range(self.layers):
+            encoder_inputs = self.conv_blocks[layer_idx](encoder_inputs)
+        return encoder_inputs
diff --git a/aam/models/sequence_encoder.py b/aam/models/sequence_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..53af1d1f67d55a7c248096c9fe8b4b3985f5d81c
--- /dev/null
+++ b/aam/models/sequence_encoder.py
@@ -0,0 +1,441 @@
+from __future__ import annotations
+
+from typing import Union
+
+import tensorflow as tf
+
+from aam.losses import PairwiseLoss
+from aam.models.base_sequence_encoder import BaseSequenceEncoder
+from aam.models.transformers import TransformerEncoder
+from aam.optimizers.gradient_accumulator import GradientAccumulator
+from aam.optimizers.loss_scaler import LossScaler
+from aam.utils import apply_random_mask, float_mask
+
+
+@tf.keras.saving.register_keras_serializable(package="SequenceEncoder")
+class SequenceEncoder(tf.keras.Model):
+    def __init__(
+        self,
+        output_dim: int,
+        token_limit: int,
+        encoder_type: str,
+        dropout_rate: float = 0.0,
+        embedding_dim: int = 128,
+        attention_heads: int = 4,
+        attention_layers: int = 4,
+        intermediate_size: int = 1024,
+        intermediate_activation: str = "gelu",
+        max_bp: int = 150,
+        is_16S: bool = True,
+        vocab_size: int = 6,
+        add_token: bool = True,
+        asv_dropout_rate: float = 0.0,
+        accumulation_steps: int = 1,
+        **kwargs,
+    ):
+        super(SequenceEncoder, self).__init__(**kwargs)
+        self.output_dim = output_dim
+        self.token_limit = token_limit
+        self.encoder_type = encoder_type
+        self.dropout_rate = dropout_rate
+        self.embedding_dim = embedding_dim
+        self.attention_heads = attention_heads
+        self.attention_layers = attention_layers
+        self.intermediate_size = intermediate_size
+        self.intermediate_activation = intermediate_activation
+        self.max_bp = max_bp
+        self.is_16S = is_16S
+        self.vocab_size = vocab_size
+        self.add_token = add_token
+        self.asv_dropout_rate = asv_dropout_rate
+        self.accumulation_steps = accumulation_steps
+
+        self._get_encoder_loss()
+        self.loss_tracker = tf.keras.metrics.Mean()
+        self.encoder_tracker = tf.keras.metrics.Mean()
+
+        # layers used in model
+        self.base_encoder = BaseSequenceEncoder(
+            self.embedding_dim,
+            self.max_bp,
+            self.token_limit,
+            sample_attention_heads=self.attention_heads,
+            sample_attention_layers=self.attention_layers,
+            sample_intermediate_size=self.intermediate_size,
+            dropout_rate=self.dropout_rate,
+            nuc_attention_heads=2,
+            nuc_attention_layers=2,
+            nuc_intermediate_size=128,
+            intermediate_activation=self.intermediate_activation,
+            is_16S=self.is_16S,
+            vocab_size=self.vocab_size,
+            add_token=self.add_token,
+            name="base_encoder",
+        )
+
+        self.encoder = TransformerEncoder(
+            num_layers=self.attention_layers,
+            num_attention_heads=self.attention_heads,
+            intermediate_size=intermediate_size,
+            dropout_rate=self.dropout_rate,
+            activation=self.intermediate_activation,
+            name="encoder",
+        )
+
+        if self.encoder_type == "combined":
+            uni_out, faith_out, tax_out = self.output_dim
+            self.uni_ff = tf.keras.Sequential(
+                [
+                    tf.keras.layers.Dense(
+                        self.embedding_dim, activation="gelu", dtype=tf.float32
+                    ),
+                    tf.keras.layers.Dense(uni_out, dtype=tf.float32),
+                ]
+            )
+            self.faith_ff = tf.keras.Sequential(
+                [
+                    tf.keras.layers.Dense(
+                        self.embedding_dim, activation="gelu", dtype=tf.float32
+                    ),
+                    tf.keras.layers.Dense(faith_out, dtype=tf.float32),
+                ]
+            )
+            self.tax_ff = tf.keras.Sequential(
+                [
+                    tf.keras.layers.Dense(
+                        self.embedding_dim, activation="gelu", dtype=tf.float32
+                    ),
+                    tf.keras.layers.Dense(tax_out, dtype=tf.float32),
+                ]
+            )
+        else:
+            self.encoder_ff = tf.keras.Sequential(
+                [
+                    tf.keras.layers.Dense(
+                        self.embedding_dim, activation="gelu", dtype=tf.float32
+                    ),
+                    tf.keras.layers.Dense(self.output_dim, dtype=tf.float32),
+                ]
+            )
+
+        self.gradient_accumulator = GradientAccumulator(self.accumulation_steps)
+        self.loss_scaler = LossScaler(self.gradient_accumulator.accum_steps)
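+        # assumption: GradientAccumulator applies the optimizer update once every
+        # `accumulation_steps` calls, and LossScaler rescales the individual loss
+        # terms before they are averaged in train_step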
+
+    def _get_encoder_loss(self):
+        if self.encoder_type == "combined":
+            self._unifrac_loss = PairwiseLoss()
+            self._tax_loss = tf.keras.losses.CategoricalCrossentropy(reduction="none")
+            self.encoder_loss = self._compute_combined_loss
+            self.extract_encoder_pred = self._combined_embeddings
+        elif self.encoder_type == "unifrac":
+            self._unifrac_loss = PairwiseLoss()
+            self.encoder_loss = self._compute_unifrac_loss
+            self.extract_encoder_pred = self._unifrac_embeddings
+        elif self.encoder_type == "faith_pd":
+            self._unifrac_loss = tf.keras.losses.MeanSquaredError(reduction="none")
+            self.encoder_loss = self._compute_unifrac_loss
+            self.extract_encoder_pred = self._unifrac_embeddings
+        elif self.encoder_type == "taxonomy":
+            self._tax_loss = tf.keras.losses.CategoricalCrossentropy(reduction="none")
+            self.encoder_loss = self._compute_tax_loss
+            self.extract_encoder_pred = self._taxonomy_embeddings
+        else:
+            raise Exception(f"invalid encoder encoder_type: {self.encoder_type}")
+
+    def _combined_embeddings(self, tensor, mask):
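+        # pool the sample-level prediction either from the <SAMPLE> token (index 0)
+        # or, when no token is added, as a mask-weighted mean over the ASV positions;
+        # the taxonomy head keeps per-ASV embeddings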
+        if self.add_token:
+            pooled = tensor[:, 0, :]
+        else:
+            mask = tf.cast(mask, dtype=tf.float32)
+            pooled = tf.reduce_sum(tensor * mask, axis=1)
+            pooled /= tf.reduce_sum(mask, axis=1)
+
+        unifrac_pred = pooled
+        faith_pred = pooled
+
+        tax_pred = tensor
+        if self.add_token:
+            tax_pred = tax_pred[:, 1:, :]
+
+        return [
+            self.uni_ff(unifrac_pred),
+            self.faith_ff(faith_pred),
+            self.tax_ff(tax_pred),
+        ]
+
+    def _unifrac_embeddings(self, tensor, mask):
+        if self.add_token:
+            encoder_pred = tensor[:, 0, :]
+        else:
+            mask = tf.cast(mask, dtype=tf.float32)
+            encoder_pred = tf.reduce_sum(tensor * mask, axis=1)
+            encoder_pred /= tf.reduce_sum(mask, axis=1)
+        encoder_pred = self.encoder_ff(encoder_pred)
+        return encoder_pred
+
+    def _taxonomy_embeddings(self, tensor, mask):
+        tax_pred = tensor
+        if self.add_token:
+            tax_pred = tax_pred[:, 1:, :]
+        tax_pred = self.encoder_ff(tax_pred)
+        return tax_pred
+
+    def _compute_combined_loss(self, y_true, preds):
+        uni_true, faith_true, tax_true = y_true
+        uni_pred, faith_pred, tax_pred = preds
+
+        uni_loss = self._compute_unifrac_loss(uni_true, uni_pred)
+        faith_loss = tf.reduce_mean(tf.square(faith_true - faith_pred))
+        tax_loss = self._compute_tax_loss(tax_true, tax_pred)
+
+        return [uni_loss, faith_loss, tax_loss]
+
+    def _compute_tax_loss(
+        self,
+        tax_tokens: tf.Tensor,
+        tax_pred: tf.Tensor,
+    ) -> tf.Tensor:
+        if isinstance(self.output_dim, (list, tuple)):
+            out_dim = self.output_dim[-1]
+        else:
+            out_dim = self.output_dim
+        y_true = tf.reshape(tax_tokens, [-1])
+        y_pred = tf.reshape(tax_pred, [-1, out_dim])
+        y_pred = tf.keras.activations.softmax(y_pred, axis=-1)
+
+        mask = float_mask(y_true) > 0
+        y_true = tf.one_hot(y_true, depth=out_dim)
+
+        # smooth labels
+        y_true = y_true * 0.9 + 0.1
+
+        y_true = y_true[mask]
+        y_pred = y_pred[mask]
+
+        loss = tf.reduce_mean(self._tax_loss(y_true, y_pred))
+        return loss
+
+    def _compute_unifrac_loss(
+        self,
+        y_true: tf.Tensor,
+        unifrac_embeddings: tf.Tensor,
+    ) -> tf.Tensor:
+        loss = self._unifrac_loss(y_true, unifrac_embeddings)
+        if self.encoder_type == "unifrac":
+            batch = tf.cast(tf.shape(y_true)[0], dtype=tf.float32)
+            loss = tf.reduce_sum(loss, axis=1, keepdims=True) / (batch - 1)
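+            # each row of the pairwise matrix holds (batch - 1) non-trivial distances,
+            # so the sum is normalized per sample before the final mean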
+        return tf.reduce_mean(loss)
+
+    def _compute_encoder_loss(
+        self,
+        y_true: tf.Tensor,
+        encoder_embeddings: tf.Tensor,
+    ) -> tf.Tensor:
+        loss = self.encoder_loss(y_true, encoder_embeddings)
+        if self.encoder_type != "combined":
+            return tf.reduce_mean(loss)
+        else:
+            return loss
+
+    def _compute_loss(
+        self,
+        model_inputs: tuple[tf.Tensor, tf.Tensor],
+        y_true: Union[tf.Tensor, tuple[tf.Tensor, tf.Tensor]],
+        outputs: tuple[tf.Tensor, tf.Tensor, tf.Tensor],
+    ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        nuc_tokens, counts = model_inputs
+        embeddings, encoder_embeddings = outputs
+        return self._compute_encoder_loss(y_true, encoder_embeddings)
+
+    def predict_step(
+        self,
+        data: Union[
+            tuple[tuple[tf.Tensor, tf.Tensor], tf.Tensor],
+            tuple[tuple[tf.Tensor, tf.Tensor], tuple[tf.Tensor, tf.Tensor]],
+        ],
+    ):
+        inputs, y = data
+        embeddings, encoder_embeddings = self(inputs, training=False)
+
+        return encoder_embeddings
+
+    def train_step(
+        self,
+        data: Union[
+            tuple[tuple[tf.Tensor, tf.Tensor], tf.Tensor],
+            tuple[tuple[tf.Tensor, tf.Tensor], tuple[tf.Tensor, tf.Tensor]],
+        ],
+    ):
+        if not self.gradient_accumulator.built:
+            self.gradient_accumulator.build(self.optimizer, self)
+
+        inputs, y = data
+        y_target, encoder_target = y
+        with tf.GradientTape() as tape:
+            outputs = self(inputs, training=True)
+            encoder_loss = self._compute_loss(inputs, encoder_target, outputs)
+            if self.encoder_type == "combined":
+                scaled_losses = self.loss_scaler(encoder_loss)
+            else:
+                scaled_losses = self.loss_scaler([encoder_loss])
+            loss = tf.reduce_mean(tf.stack(scaled_losses, axis=0))
+
+        gradients = tape.gradient(
+            loss,
+            self.trainable_variables,
+            unconnected_gradients=tf.UnconnectedGradients.ZERO,
+        )
+        self.gradient_accumulator.apply_gradients(gradients)
+
+        self.loss_tracker.update_state(loss)
+        self.encoder_tracker.update_state(encoder_loss)
+        return {
+            "loss": self.loss_tracker.result(),
+            "encoder_loss": self.encoder_tracker.result(),
+            "learning_rate": self.optimizer.learning_rate,
+        }
+
+    def test_step(
+        self,
+        data: Union[
+            tuple[tuple[tf.Tensor, tf.Tensor], tf.Tensor],
+            tuple[tuple[tf.Tensor, tf.Tensor], tuple[tf.Tensor, tf.Tensor]],
+        ],
+    ):
+        inputs, y = data
+        y_target, encoder_target = y
+        outputs = self(inputs, training=False)
+        encoder_loss = self._compute_loss(inputs, encoder_target, outputs)
+        scaled_losses = self.loss_scaler([encoder_loss])
+        loss = tf.reduce_mean(tf.stack(scaled_losses, axis=0))
+
+        self.loss_tracker.update_state(loss)
+        self.encoder_tracker.update_state(encoder_loss)
+        return {
+            "loss": self.loss_tracker.result(),
+            "encoder_loss": self.encoder_tracker.result(),
+            "learning_rate": self.optimizer.learning_rate,
+        }
+
+    def call(
+        self, inputs: tuple[tf.Tensor, tf.Tensor], training: bool = False
+    ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        # Keras casts all inputs to float, so manually cast back to the expected dtypes
+        tokens, counts = inputs
+        tokens = tf.cast(tokens, dtype=tf.int32)
+        counts = tf.cast(counts, dtype=tf.int32)
+
+        # account for <SAMPLE> token
+        count_mask = float_mask(counts, dtype=tf.int32)
+        random_mask = None
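+        # ASV-level dropout: randomly zero whole ASV token rows during training
+        # (assumption: apply_random_mask drops mask entries at asv_dropout_rate)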
+        if training and self.asv_dropout_rate > 0:
+            random_mask = apply_random_mask(count_mask, self.asv_dropout_rate)
+
+        sample_embeddings = self.base_encoder(
+            tokens, random_mask=random_mask, training=training
+        )
+
+        if self.add_token:
+            count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+        count_attention_mask = count_mask
+
+        encoder_gated_embeddings = self.encoder(
+            sample_embeddings, mask=count_attention_mask, training=training
+        )
+
+        encoder_pred = self.extract_encoder_pred(encoder_gated_embeddings, count_mask)
+        encoder_embeddings = sample_embeddings + encoder_gated_embeddings
+        return [encoder_embeddings, encoder_pred]
+
+    def base_embeddings(
+        self, inputs: tuple[tf.Tensor, tf.Tensor]
+    ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        # Keras casts all inputs to float, so manually cast back to the expected dtypes
+        tokens, counts = inputs
+        tokens = tf.cast(tokens, dtype=tf.int32)
+        counts = tf.cast(counts, dtype=tf.int32)
+
+        sample_embeddings = self.base_encoder.base_embeddings(tokens)
+
+        # account for <SAMPLE> token
+        count_mask = float_mask(counts, dtype=tf.int32)
+        count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+        count_attention_mask = count_mask
+
+        encoder_gated_embeddings = self.encoder(
+            sample_embeddings, mask=count_attention_mask
+        )
+        encoder_pred = encoder_gated_embeddings[:, 0, :]
+        encoder_pred = self.encoder_ff(encoder_pred)
+
+        encoder_embeddings = sample_embeddings + encoder_gated_embeddings
+
+        return encoder_embeddings
+
+    def asv_embeddings(
+        self, inputs: tuple[tf.Tensor, tf.Tensor], training: bool = False
+    ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        # Keras casts all inputs to float, so manually cast back to the expected dtypes
+        tokens, counts = inputs
+        tokens = tf.cast(tokens, dtype=tf.int32)
+        counts = tf.cast(counts, dtype=tf.int32)
+
+        asv_embeddings = self.base_encoder.get_asv_embeddings(tokens, training=training)
+
+        return asv_embeddings
+
+    def asv_gradient(
+        self, inputs: tuple[tf.Tensor, tf.Tensor], asv_embeddings
+    ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        # Keras casts all inputs to float, so manually cast back to the expected dtypes
+        tokens, counts = inputs
+        tokens = tf.cast(tokens, dtype=tf.int32)
+        counts = tf.cast(counts, dtype=tf.int32)
+
+        sample_embeddings = self.base_encoder.asv_gradient(tokens, asv_embeddings)
+
+        # account for <SAMPLE> token
+        count_mask = float_mask(counts, dtype=tf.int32)
+        count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+        count_attention_mask = count_mask
+
+        encoder_gated_embeddings = self.encoder(
+            sample_embeddings, mask=count_attention_mask
+        )
+        encoder_pred = encoder_gated_embeddings[:, 0, :]
+        encoder_pred = self.encoder_ff(encoder_pred)
+
+        encoder_embeddings = sample_embeddings + encoder_gated_embeddings
+
+        return encoder_embeddings
+
+    def get_config(self):
+        config = super(SequenceEncoder, self).get_config()
+        config.update(
+            {
+                "output_dim": self.output_dim,
+                "token_limit": self.token_limit,
+                "encoder_type": self.encoder_type,
+                "dropout_rate": self.dropout_rate,
+                "embedding_dim": self.embedding_dim,
+                "attention_heads": self.attention_heads,
+                "attention_layers": self.attention_layers,
+                "intermediate_size": self.intermediate_size,
+                "intermediate_activation": self.intermediate_activation,
+                "max_bp": self.max_bp,
+                "is_16S": self.is_16S,
+                "vocab_size": self.vocab_size,
+                "add_token": self.add_token,
+                "asv_dropout_rate": self.asv_dropout_rate,
+                "accumulation_steps": self.accumulation_steps,
+            }
+        )
+        return config
diff --git a/aam/models/sequence_regressor.py b/aam/models/sequence_regressor.py
index 0e5d9140ab441eacec85603c076c70e667e193dc..7c22d1f7411397fbe73d0dfe834b02e726425306 100644
--- a/aam/models/sequence_regressor.py
+++ b/aam/models/sequence_regressor.py
@@ -5,10 +5,12 @@ from typing import Optional, Union
 import tensorflow as tf
 import tensorflow_models as tfm
 
-from aam.models.taxonomy_encoder import TaxonomyEncoder
+# from aam.models.unifrac_encoder import UniFracEncoder
+from aam.models.sequence_encoder import SequenceEncoder
 from aam.models.transformers import TransformerEncoder
-from aam.models.unifrac_encoder import UniFracEncoder
-from aam.utils import float_mask, masked_loss
+from aam.optimizers.gradient_accumulator import GradientAccumulator
+from aam.optimizers.loss_scaler import LossScaler
+from aam.utils import float_mask
 
 
 @tf.keras.saving.register_keras_serializable(package="SequenceRegressor")
@@ -16,30 +18,36 @@ class SequenceRegressor(tf.keras.Model):
     def __init__(
         self,
         token_limit: int,
-        num_classes: Optional[int] = None,
+        base_output_dim: Optional[int] = None,
         shift: float = 0.0,
         scale: float = 1.0,
         dropout_rate: float = 0.0,
-        num_tax_levels: Optional[int] = None,
         embedding_dim: int = 128,
         attention_heads: int = 4,
         attention_layers: int = 4,
         intermediate_size: int = 1024,
         intermediate_activation: str = "relu",
-        base_model: Union[str, TaxonomyEncoder, UniFracEncoder] = "taxonomy",
+        base_model: str = "unifrac",
         freeze_base: bool = False,
         penalty: float = 1.0,
         nuc_penalty: float = 1.0,
         max_bp: int = 150,
+        is_16S: bool = True,
+        vocab_size: int = 6,
+        out_dim: int = 1,
+        classifier: bool = False,
+        add_token: bool = True,
+        class_weights: list = None,
+        asv_dropout_rate: float = 0.0,
+        accumulation_steps: int = 1,
         **kwargs,
     ):
         super(SequenceRegressor, self).__init__(**kwargs)
         self.token_limit = token_limit
-        self.num_classes = num_classes
+        self.base_output_dim = base_output_dim
         self.shift = shift
         self.scale = scale
         self.dropout_rate = dropout_rate
-        self.num_tax_levels = num_tax_levels
         self.embedding_dim = embedding_dim
         self.attention_heads = attention_heads
         self.attention_layers = attention_layers
@@ -49,66 +57,56 @@ class SequenceRegressor(tf.keras.Model):
         self.penalty = penalty
         self.nuc_penalty = nuc_penalty
         self.max_bp = max_bp
+        self.is_16S = is_16S
+        self.vocab_size = vocab_size
+        self.out_dim = out_dim
+        self.classifier = classifier
+        self.add_token = add_token
+        self.class_weights = class_weights
+        self.asv_dropout_rate = asv_dropout_rate
+        self.accumulation_steps = accumulation_steps
         self.loss_tracker = tf.keras.metrics.Mean()
 
         # layers used in model
+        self.combined_base = False
         if isinstance(base_model, str):
-            if base_model == "taxonomy":
-                self.base_model = TaxonomyEncoder(
-                    num_tax_levels=self.num_tax_levels,
-                    token_limit=self.token_limit,
-                    dropout_rate=self.dropout_rate,
-                    embedding_dim=self.embedding_dim,
-                    attention_heads=self.attention_heads,
-                    attention_layers=self.attention_layers,
-                    intermediate_size=self.intermediate_size,
-                    intermediate_activation=self.intermediate_activation,
-                    max_bp=self.max_bp,
-                )
-            elif base_model == "unifrac":
-                self.base_model = UniFracEncoder(
-                    self.token_limit,
-                    dropout_rate=self.dropout_rate,
-                    embedding_dim=self.embedding_dim,
-                    attention_heads=self.attention_heads,
-                    attention_layers=self.attention_layers,
-                    intermediate_size=self.intermediate_size,
-                    intermediate_activation=self.intermediate_activation,
-                    max_bp=self.max_bp,
-                )
-            else:
-                raise Exception("Invalid base model option.")
+            if base_model == "combined":
+                self.combined_base = True
+
+            self.base_model = SequenceEncoder(
+                output_dim=self.base_output_dim,
+                token_limit=self.token_limit,
+                encoder_type=base_model,
+                dropout_rate=self.dropout_rate,
+                embedding_dim=self.embedding_dim,
+                attention_heads=self.attention_heads,
+                attention_layers=self.attention_layers,
+                intermediate_size=self.intermediate_size,
+                intermediate_activation=self.intermediate_activation,
+                max_bp=self.max_bp,
+                is_16S=self.is_16S,
+                vocab_size=self.vocab_size,
+                add_token=self.add_token,
+                asv_dropout_rate=self.asv_dropout_rate,
+                accumulation_steps=self.accumulation_steps,
+            )
         else:
-            if not isinstance(base_model, (TaxonomyEncoder, UniFracEncoder)):
-                raise Exception(f"Unsupported base model of type {type(base_model)}")
             self.base_model = base_model
 
-        if isinstance(self.base_model, TaxonomyEncoder):
-            self.base_losses = {"base_loss": self.base_model._compute_tax_loss}
+        self.base_losses = {"base_loss": self.base_model._compute_encoder_loss}
+        if not self.combined_base:
             self.base_metrics = {
-                "base_loss": ("tax_entropy", self.base_model.tax_tracker)
+                "base_loss": ["encoder_loss", self.base_model.encoder_tracker]
             }
         else:
-            self.base_losses = {"base_loss": self.base_model._compute_unifrac_loss}
-            self.base_metrics = {
-                "base_loss": ["unifrac_mse", self.base_model.unifrac_tracker]
-            }
-
-        self.base_losses.update({"nuc_entropy": self.base_model._compute_nuc_loss})
-        self.base_metrics.update(
-            {"nuc_entropy": ["nuc_entropy", self.base_model.base_encoder.nuc_entropy]}
-        )
+            self.uni_tracker = tf.keras.metrics.Mean()
+            # self.faith_tracker = tf.keras.metrics.Mean()
+            self.tax_tracker = tf.keras.metrics.Mean()
 
         if self.freeze_base:
             print("Freezing base model...")
             self.base_model.trainable = False
 
-        self._count_alpha = self.add_weight(
-            name="count_alpha",
-            initializer=tf.keras.initializers.Zeros(),
-            trainable=True,
-            dtype=tf.float32,
-        )
         self.count_encoder = TransformerEncoder(
             num_layers=self.attention_layers,
             num_attention_heads=self.attention_heads,
@@ -119,8 +117,14 @@ class SequenceRegressor(tf.keras.Model):
         self.count_pos = tfm.nlp.layers.PositionEmbedding(
             self.token_limit + 5, dtype=tf.float32
         )
-        self.count_out = tf.keras.layers.Dense(1, use_bias=False, dtype=tf.float32)
-        self.count_activation = tf.keras.layers.Activation("linear", dtype=tf.float32)
+        self.count_out = tf.keras.Sequential(
+            [
+                tf.keras.layers.Dense(
+                    self.embedding_dim, activation="gelu", dtype=tf.float32
+                ),
+                tf.keras.layers.Dense(1, dtype=tf.float32),
+            ]
+        )
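+        # two-layer MLP head predicting per-ASV relative abundance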
         self.count_loss = tf.keras.losses.MeanSquaredError(reduction="none")
         self.count_tracker = tf.keras.metrics.Mean()
 
@@ -131,15 +135,28 @@ class SequenceRegressor(tf.keras.Model):
             dropout_rate=self.dropout_rate,
             activation=self.intermediate_activation,
         )
+
         self.target_tracker = tf.keras.metrics.Mean()
+        if not self.classifier:
+            self.metric_tracker = tf.keras.metrics.MeanAbsoluteError()
+            self.metric_string = "mae"
+        else:
+            self.metric_tracker = tf.keras.metrics.SparseCategoricalAccuracy()
+            self.metric_string = "accuracy"
+        self.target_ff = tf.keras.Sequential(
+            [
+                tf.keras.layers.Dense(
+                    self.embedding_dim, activation="gelu", dtype=tf.float32
+                ),
+                tf.keras.layers.Dense(self.out_dim, dtype=tf.float32),
+            ]
+        )
 
-        self.target_ff = tf.keras.layers.Dense(1, use_bias=False, dtype=tf.float32)
-        self.metric_tracker = tf.keras.metrics.MeanAbsoluteError()
-        self.metric_string = "mae"
-        self.target_activation = tf.keras.layers.Activation("linear", dtype=tf.float32)
         self.loss_metrics = sorted(
             ["loss", "target_loss", "count_mse", self.metric_string]
         )
+        self.gradient_accumulator = GradientAccumulator(self.accumulation_steps)
+        self.loss_scaler = LossScaler(self.gradient_accumulator.accum_steps)
 
     def evaluate_metric(self, dataset, metric, **kwargs):
         metric_index = self.loss_metrics.index(metric)
@@ -147,15 +164,27 @@ class SequenceRegressor(tf.keras.Model):
         return evaluated_metrics[metric_index]
 
     def _compute_target_loss(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
-        return tf.reduce_mean(self.loss(y_true, y_pred))
+        if not self.classifier:
+            loss = tf.square(y_true - y_pred)
+            return tf.reduce_mean(loss)
+
+        y_true = tf.cast(y_true, dtype=tf.int32)
+        y_true = tf.reshape(y_true, shape=[-1])
+        weights = tf.gather(self.class_weights, y_true)
+        y_true = tf.one_hot(y_true, self.out_dim)
+        y_pred = tf.reshape(y_pred, shape=[-1, self.out_dim])
+        y_pred = tf.keras.activations.softmax(y_pred, axis=-1)
+        loss = self.loss(y_true, y_pred)  # * weights
+        return tf.reduce_mean(loss)
 
-    @masked_loss(sparse_cat=False)
     def _compute_count_loss(
         self, counts: tf.Tensor, count_pred: tf.Tensor
     ) -> tf.Tensor:
         relative_counts = self._relative_abundance(counts)
         loss = tf.square(relative_counts - count_pred)
-        return tf.squeeze(loss, axis=-1)
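+        # average the squared error only over non-padded ASV positions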
+        mask = float_mask(counts)
+        loss = tf.reduce_sum(loss * mask, axis=1) / tf.reduce_sum(mask, axis=1)
+        return tf.reduce_mean(loss)
 
     def _compute_loss(
         self,
@@ -165,36 +194,26 @@ class SequenceRegressor(tf.keras.Model):
             tuple[tf.Tensor, tf.Tensor, tf.Tensor],
             tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor],
         ],
-        sample_weights: Optional[tf.Tensor] = None,
+        train_step: bool = True,
     ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
         nuc_tokens, counts = model_inputs
         y_target, base_target = y_true
 
-        target_embeddings, count_pred, y_pred, base_pred, nuc_pred = outputs
-
+        target_embeddings, count_pred, y_pred, base_pred = outputs
         target_loss = self._compute_target_loss(y_target, y_pred)
         count_loss = self._compute_count_loss(counts, count_pred)
 
-        loss = target_loss + count_loss
+        base_loss = 0
         if not self.freeze_base:
-            base_loss = (
-                self.base_losses["base_loss"](base_target, base_pred) * self.penalty
-            )
-            nuc_loss = (
-                self.base_losses["nuc_entropy"](nuc_tokens, nuc_pred) * self.nuc_penalty
-            )
-            loss = loss + nuc_loss + base_loss
-        else:
-            base_loss = 0
-            nuc_loss = 0
+            if self.combined_base:
+                uni_loss, faith_loss, tax_loss = self.base_losses["base_loss"](
+                    base_target, base_pred
+                )
+                return (target_loss, count_loss, uni_loss, faith_loss, tax_loss)
+            else:
+                base_loss = self.base_losses["base_loss"](base_target, base_pred)
 
-        return (
-            loss,
-            target_loss,
-            count_loss,
-            base_loss,
-            nuc_loss,
-        )
+        return (target_loss, count_loss, base_loss)
 
     def _compute_metric(
         self,
@@ -206,10 +225,15 @@ class SequenceRegressor(tf.keras.Model):
     ):
         y_true, base_target = y_true
 
-        target_embeddings, count_pred, y_pred, base_pred, nuc_pred = outputs
-        y_true = y_true * self.scale + self.shift
-        y_pred = y_pred * self.scale + self.shift
-        self.metric_tracker.update_state(y_true, y_pred)
+        target_embeddings, count_pred, y_pred, base_pred = outputs
+        if not self.classifier:
+            y_true = y_true * self.scale + self.shift
+            y_pred = y_pred * self.scale + self.shift
+            self.metric_tracker.update_state(y_true, y_pred)
+        else:
+            y_true = tf.cast(y_true, dtype=tf.int32)
+            y_pred = tf.keras.activations.softmax(y_pred)
+            self.metric_tracker.update_state(y_true, y_pred)
 
     def predict_step(
         self,
@@ -219,12 +243,14 @@ class SequenceRegressor(tf.keras.Model):
         ],
     ):
         inputs, (y_true, _) = data
-        target_embeddings, count_pred, y_pred, base_pred, nuc_pred = self(
-            inputs, training=False
-        )
+        target_embeddings, count_pred, y_pred, base_pred = self(inputs, training=False)
 
-        y_true = y_true * self.scale + self.shift
-        y_pred = y_pred * self.scale + self.shift
+        if not self.classifier:
+            y_true = y_true * self.scale + self.shift
+            y_pred = y_pred * self.scale + self.shift
+        else:
+            y_true = tf.cast(y_true, dtype=tf.int32)
+            y_pred = tf.argmax(tf.keras.activations.softmax(y_pred), axis=-1)
         return y_pred, y_true
 
     def train_step(
@@ -234,33 +260,64 @@ class SequenceRegressor(tf.keras.Model):
             tuple[tuple[tf.Tensor, tf.Tensor], tuple[tf.Tensor, tf.Tensor]],
         ],
     ):
+        if not self.gradient_accumulator.built:
+            self.gradient_accumulator.build(self.optimizer, self)
+
         inputs, y = data
 
         with tf.GradientTape() as tape:
             outputs = self(inputs, training=True)
-            loss, target_loss, count_mse, base_loss, nuc_loss = self._compute_loss(
-                inputs, y, outputs
-            )
-
-        gradients = tape.gradient(loss, self.trainable_variables)
-        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
+            if self.combined_base:
+                target_loss, count_mse, uni_loss, faith_loss, tax_loss = (
+                    self._compute_loss(inputs, y, outputs, train_step=True)
+                )
+                scaled_losses = self.loss_scaler(
+                    [target_loss, count_mse, uni_loss, tax_loss]
+                )
+            else:
+                target_loss, count_mse, base_loss = self._compute_loss(
+                    inputs, y, outputs, train_step=True
+                )
+                scaled_losses = self.loss_scaler([target_loss, count_mse, base_loss])
+            loss = tf.reduce_mean(tf.stack(scaled_losses, axis=0))
 
+        gradients = tape.gradient(
+            loss,
+            self.trainable_variables,
+            unconnected_gradients=tf.UnconnectedGradients.ZERO,
+        )
+        self.gradient_accumulator.apply_gradients(gradients)
         self.loss_tracker.update_state(loss)
         self.target_tracker.update_state(target_loss)
         self.count_tracker.update_state(count_mse)
-        base_loss_key, base_loss_metric = self.base_metrics["base_loss"]
-        base_loss_metric.update_state(base_loss)
-        nuc_entropy_key, nuc_entropy_metric = self.base_metrics["nuc_entropy"]
-        nuc_entropy_metric.update_state(nuc_loss)
+
         self._compute_metric(y, outputs)
-        return {
-            "loss": self.loss_tracker.result(),
-            "target_loss": self.target_tracker.result(),
-            "count_mse": self.count_tracker.result(),
-            base_loss_key: base_loss_metric.result(),
-            nuc_entropy_key: nuc_entropy_metric.result(),
-            self.metric_string: self.metric_tracker.result(),
-        }
+        if self.combined_base:
+            self.uni_tracker.update_state(uni_loss)
+            # self.faith_tracker.update_state(faith_loss)
+            self.tax_tracker.update_state(tax_loss)
+            return {
+                "loss": self.loss_tracker.result(),
+                "target_loss": self.target_tracker.result(),
+                "count_mse": self.count_tracker.result(),
+                # base_loss_key: base_loss_metric.result(),
+                "uni_loss": self.uni_tracker.result(),
+                # "faith_loss": self.faith_tracker.result(),
+                "tax_loss": self.tax_tracker.result(),
+                self.metric_string: self.metric_tracker.result(),
+                "learning_rate": self.optimizer.learning_rate,
+            }
+        else:
+            base_loss_key, base_loss_metric = self.base_metrics["base_loss"]
+            base_loss_metric.update_state(base_loss)
+            return {
+                "loss": self.loss_tracker.result(),
+                "target_loss": self.target_tracker.result(),
+                "count_mse": self.count_tracker.result(),
+                base_loss_key: base_loss_metric.result(),
+                self.metric_string: self.metric_tracker.result(),
+                "learning_rate": self.optimizer.learning_rate,
+            }
 
     def test_step(
         self,
@@ -271,27 +328,52 @@ class SequenceRegressor(tf.keras.Model):
     ):
         inputs, y = data
 
-        outputs = self(inputs, training=False)
-        loss, target_loss, count_mse, base_loss, nuc_loss = self._compute_loss(
-            inputs, y, outputs
-        )
+        outputs = self(inputs, training=True)
+        if self.combined_base:
+            target_loss, count_mse, uni_loss, faith_loss, tax_loss = self._compute_loss(
+                inputs, y, outputs, train_step=True
+            )
+            scaled_losses = self.loss_scaler(
+                [target_loss, count_mse, uni_loss, tax_loss]
+            )
+        else:
+            target_loss, count_mse, base_loss = self._compute_loss(
+                inputs, y, outputs, train_step=True
+            )
+            scaled_losses = self.loss_scaler([target_loss, count_mse, base_loss])
+        loss = tf.reduce_mean(tf.stack(scaled_losses, axis=0))
 
         self.loss_tracker.update_state(loss)
         self.target_tracker.update_state(target_loss)
         self.count_tracker.update_state(count_mse)
-        base_loss_key, base_loss_metric = self.base_metrics["base_loss"]
-        base_loss_metric.update_state(base_loss)
-        nuc_entropy_key, nuc_entropy_metric = self.base_metrics["nuc_entropy"]
-        nuc_entropy_metric.update_state(nuc_loss)
+
         self._compute_metric(y, outputs)
-        return {
-            "loss": self.loss_tracker.result(),
-            "target_loss": self.target_tracker.result(),
-            "count_mse": self.count_tracker.result(),
-            base_loss_key: base_loss_metric.result(),
-            nuc_entropy_key: nuc_entropy_metric.result(),
-            self.metric_string: self.metric_tracker.result(),
-        }
+        if self.combined_base:
+            self.uni_tracker.update_state(uni_loss)
+            # self.faith_tracker.update_state(faith_loss)
+            self.tax_tracker.update_state(tax_loss)
+            return {
+                "loss": self.loss_tracker.result(),
+                "target_loss": self.target_tracker.result(),
+                "count_mse": self.count_tracker.result(),
+                # base_loss_key: base_loss_metric.result(),
+                "uni_loss": self.uni_tracker.result(),
+                # "faith_loss": self.faith_tracker.result(),
+                "tax_loss": self.tax_tracker.result(),
+                self.metric_string: self.metric_tracker.result(),
+                "learning_rate": self.optimizer.learning_rate,
+            }
+        else:
+            base_loss_key, base_loss_metric = self.base_metrics["base_loss"]
+            base_loss_metric.update_state(base_loss)
+            return {
+                "loss": self.loss_tracker.result(),
+                "target_loss": self.target_tracker.result(),
+                "count_mse": self.count_tracker.result(),
+                base_loss_key: base_loss_metric.result(),
+                self.metric_string: self.metric_tracker.result(),
+                "learning_rate": self.optimizer.learning_rate,
+            }
 
     def _relative_abundance(self, counts: tf.Tensor) -> tf.Tensor:
         counts = tf.cast(counts, dtype=tf.float32)
@@ -306,11 +388,14 @@ class SequenceRegressor(tf.keras.Model):
         attention_mask: Optional[tf.Tensor] = None,
         training: bool = False,
     ) -> tf.Tensor:
-        count_embeddings = tensor + self.count_pos(tensor) * relative_abundances
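+        # add the count_pos embedding to the inputs, then scale by relative abundance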
+        count_embeddings = (tensor + self.count_pos(tensor)) * relative_abundances
         count_embeddings = self.count_encoder(
             count_embeddings, mask=attention_mask, training=training
         )
-        count_pred = count_embeddings[:, 1:, :]
+        count_pred = count_embeddings
+        if self.add_token:
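+            # drop the leading <SAMPLE> token position before the count prediction head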
+            count_pred = count_pred[:, 1:, :]
+
         count_pred = self.count_out(count_pred)
         return count_embeddings, count_pred
 
@@ -323,7 +408,13 @@ class SequenceRegressor(tf.keras.Model):
         target_embeddings = self.target_encoder(
             tensor, mask=attention_mask, training=training
         )
-        target_out = target_embeddings[:, 0, :]
+        if self.add_token:
+            target_out = target_embeddings[:, 0, :]
+        else:
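+            # no <SAMPLE> token: mean-pool the target embeddings over unmasked positions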
+            mask = tf.cast(attention_mask, dtype=tf.float32)
+            target_out = tf.reduce_sum(target_embeddings * mask, axis=1)
+            target_out /= tf.reduce_sum(mask, axis=1)
+
         target_out = self.target_ff(target_out)
         return target_embeddings, target_out
 
@@ -342,12 +433,13 @@ class SequenceRegressor(tf.keras.Model):
         rel_abundance = self._relative_abundance(counts)
 
         # account for <SAMPLE> token
-        count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
-        rel_abundance = tf.pad(
-            rel_abundance, [[0, 0], [1, 0], [0, 0]], constant_values=1
-        )
+        if self.add_token:
+            count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+            rel_abundance = tf.pad(
+                rel_abundance, [[0, 0], [1, 0], [0, 0]], constant_values=1
+            )
         count_attention_mask = count_mask
-        base_embeddings, base_pred, nuc_embeddings = self.base_model(
+        base_embeddings, base_pred = self.base_model(
             (tokens, counts), training=training
         )
 
@@ -357,40 +449,161 @@ class SequenceRegressor(tf.keras.Model):
             attention_mask=count_attention_mask,
             training=training,
         )
-        count_embeddings = base_embeddings + count_gated_embeddings * self._count_alpha
+        count_embeddings = base_embeddings + count_gated_embeddings
+        # count_embeddings = count_gated_embeddings
 
         target_embeddings, target_out = self._compute_target_embeddings(
             count_embeddings, attention_mask=count_attention_mask, training=training
         )
 
-        return (
-            target_embeddings,
-            self.count_activation(count_pred),
-            self.target_activation(target_out),
-            base_pred,
-            nuc_embeddings,
+        return (target_embeddings, count_pred, target_out, base_pred)
+
+    def base_embeddings(
+        self, inputs: tuple[tf.Tensor, tf.Tensor]
+    ) -> Union[
+        tuple[tf.Tensor, tf.Tensor, tf.Tensor],
+        tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor],
+    ]:
+        # Keras casts all inputs to float, so we need to manually cast them back to the expected types
+        tokens, counts = inputs
+        tokens = tf.cast(tokens, dtype=tf.int32)
+        counts = tf.cast(counts, dtype=tf.int32)
+
+        count_mask = float_mask(counts, dtype=tf.int32)
+        rel_abundance = self._relative_abundance(counts)
+
+        # account for <SAMPLE> token
+        count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+        rel_abundance = tf.pad(
+            rel_abundance, [[0, 0], [1, 0], [0, 0]], constant_values=1
+        )
+        base_embeddings = self.base_model.base_embeddings((tokens, counts))
+
+        return base_embeddings
+
+    def base_gradient(
+        self, inputs: tuple[tf.Tensor, tf.Tensor], base_embeddings
+    ) -> Union[
+        tuple[tf.Tensor, tf.Tensor, tf.Tensor],
+        tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor],
+    ]:
+        # Keras casts all inputs to float, so we need to manually cast them back to the expected types
+        tokens, counts = inputs
+        tokens = tf.cast(tokens, dtype=tf.int32)
+        counts = tf.cast(counts, dtype=tf.int32)
+
+        count_mask = float_mask(counts, dtype=tf.int32)
+        rel_abundance = self._relative_abundance(counts)
+
+        # account for <SAMPLE> token
+        count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+        rel_abundance = tf.pad(
+            rel_abundance, [[0, 0], [1, 0], [0, 0]], constant_values=1
+        )
+        count_attention_mask = count_mask
+
+        count_gated_embeddings, count_pred = self._compute_count_embeddings(
+            base_embeddings,
+            rel_abundance,
+            attention_mask=count_attention_mask,
+        )
+        # count_embeddings = base_embeddings + count_gated_embeddings * self._count_alpha
+        count_embeddings = count_gated_embeddings
+
+        target_embeddings, target_out = self._compute_target_embeddings(
+            count_embeddings, attention_mask=count_attention_mask
+        )
+
+        return self.target_activation(target_out)
+
+    def asv_embeddings(
+        self, inputs: tuple[tf.Tensor, tf.Tensor], training: bool = False
+    ) -> Union[
+        tuple[tf.Tensor, tf.Tensor, tf.Tensor],
+        tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor],
+    ]:
+        # Keras casts all inputs to float, so we need to manually cast them back to the expected types
+        tokens, counts = inputs
+        tokens = tf.cast(tokens, dtype=tf.int32)
+        counts = tf.cast(counts, dtype=tf.int32)
+
+        count_mask = float_mask(counts, dtype=tf.int32)
+        rel_abundance = self._relative_abundance(counts)
+
+        # account for <SAMPLE> token
+        if self.add_token:
+            count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+            rel_abundance = tf.pad(
+                rel_abundance, [[0, 0], [1, 0], [0, 0]], constant_values=1
+            )
+        asv_embeddings = self.base_model.asv_embeddings(
+            (tokens, counts), training=False
+        )
+
+        return asv_embeddings
+
+    def asv_gradient(self, inputs, asv_embeddings):
+        # Keras casts all inputs to float, so we need to manually cast them back to the expected types
+        tokens, counts = inputs
+        tokens = tf.cast(tokens, dtype=tf.int32)
+        counts = tf.cast(counts, dtype=tf.int32)
+
+        count_mask = float_mask(counts, dtype=tf.int32)
+        rel_abundance = self._relative_abundance(counts)
+
+        # account for <SAMPLE> token
+        if self.add_token:
+            count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+            rel_abundance = tf.pad(
+                rel_abundance, [[0, 0], [1, 0], [0, 0]], constant_values=1
+            )
+        count_attention_mask = count_mask
+        base_embeddings = self.base_model.asv_gradient(
+            (tokens, counts), asv_embeddings=asv_embeddings
+        )
+
+        count_gated_embeddings, count_pred = self._compute_count_embeddings(
+            base_embeddings,
+            rel_abundance,
+            attention_mask=count_attention_mask,
+            training=False,
+        )
+        # count_embeddings = base_embeddings + count_gated_embeddings
+        count_embeddings = count_gated_embeddings
+
+        target_embeddings, target_out = self._compute_target_embeddings(
+            count_embeddings, attention_mask=count_attention_mask, training=False
         )
 
+        return self.target_activation(target_out)
+
     def get_config(self):
         config = super(SequenceRegressor, self).get_config()
         config.update(
             {
                 "token_limit": self.token_limit,
-                "base_model": tf.keras.saving.serialize_keras_object(self.base_model),
-                "num_classes": self.num_classes,
+                "base_output_dim": self.base_output_dim,
                 "shift": self.shift,
                 "scale": self.scale,
                 "dropout_rate": self.dropout_rate,
-                "num_tax_levels": self.num_tax_levels,
                 "embedding_dim": self.embedding_dim,
                 "attention_heads": self.attention_heads,
                 "attention_layers": self.attention_layers,
                 "intermediate_size": self.intermediate_size,
                 "intermediate_activation": self.intermediate_activation,
+                "base_model": tf.keras.saving.serialize_keras_object(self.base_model),
                 "freeze_base": self.freeze_base,
                 "penalty": self.penalty,
                 "nuc_penalty": self.nuc_penalty,
                 "max_bp": self.max_bp,
+                "is_16S": self.is_16S,
+                "vocab_size": self.vocab_size,
+                "out_dim": self.out_dim,
+                "classifier": self.classifier,
+                "add_token": self.add_token,
+                "class_weights": self.class_weights,
+                "asv_dropout_rate": self.asv_dropout_rate,
+                "accumulation_steps": self.accumulation_steps,
             }
         )
         return config
diff --git a/aam/models/taxonomy_encoder.py b/aam/models/taxonomy_encoder.py
index a83cfc265af9fab40bac125d18008b4833e9de2c..1ba4d27d206e6db8349b8eac48375f605eafce94 100644
--- a/aam/models/taxonomy_encoder.py
+++ b/aam/models/taxonomy_encoder.py
@@ -6,7 +6,9 @@ import tensorflow as tf
 
 from aam.models.base_sequence_encoder import BaseSequenceEncoder
 from aam.models.transformers import TransformerEncoder
-from aam.utils import float_mask, masked_loss
+from aam.optimizers.gradient_accumulator import GradientAccumulator
+from aam.optimizers.loss_scaler import LossScaler
+from aam.utils import apply_random_mask, float_mask
 
 
 @tf.keras.saving.register_keras_serializable(package="TaxonomyEncoder")
@@ -22,6 +24,12 @@ class TaxonomyEncoder(tf.keras.Model):
         intermediate_size: int = 1024,
         intermediate_activation: str = "gelu",
         max_bp: int = 150,
+        include_alpha: bool = True,
+        is_16S: bool = True,
+        vocab_size: int = 6,
+        add_token: bool = True,
+        asv_dropout_rate: float = 0.0,
+        accumulation_steps: int = 1,
         **kwargs,
     ):
         super(TaxonomyEncoder, self).__init__(**kwargs)
@@ -35,12 +43,16 @@ class TaxonomyEncoder(tf.keras.Model):
         self.intermediate_size = intermediate_size
         self.intermediate_activation = intermediate_activation
         self.max_bp = max_bp
-
+        self.include_alpha = include_alpha
+        self.is_16S = is_16S
+        self.vocab_size = vocab_size
+        self.add_token = add_token
+        self.asv_dropout_rate = asv_dropout_rate
+        self.accumulation_steps = accumulation_steps
         self.loss_tracker = tf.keras.metrics.Mean()
-        self.tax_loss = tf.keras.losses.SparseCategoricalCrossentropy(
-            ignore_class=0, from_logits=True, reduction="none"
-        )
-        self.tax_tracker = tf.keras.metrics.Mean()
+        # self.tax_loss = tf.keras.losses.CategoricalFocalCrossentropy(reduction="none")
+        self.tax_loss = tf.keras.losses.CategoricalCrossentropy(reduction="none")
+        self.encoder_tracker = tf.keras.metrics.Mean()
 
         # layers used in model
         self.base_encoder = BaseSequenceEncoder(
@@ -55,16 +67,12 @@ class TaxonomyEncoder(tf.keras.Model):
             nuc_attention_layers=3,
             nuc_intermediate_size=128,
             intermediate_activation=self.intermediate_activation,
+            is_16S=self.is_16S,
+            vocab_size=self.vocab_size,
+            add_token=self.add_token,
             name="base_encoder",
         )
 
-        self._tax_alpha = self.add_weight(
-            name="tax_alpha",
-            initializer=tf.keras.initializers.Zeros(),
-            trainable=True,
-            dtype=tf.float32,
-        )
-
         self.tax_encoder = TransformerEncoder(
             num_layers=self.attention_layers,
             num_attention_heads=self.attention_heads,
@@ -75,13 +83,12 @@ class TaxonomyEncoder(tf.keras.Model):
         )
 
         self.tax_level_logits = tf.keras.layers.Dense(
-            self.num_tax_levels,
-            use_bias=False,
-            dtype=tf.float32,
-            name="tax_level_logits",
+            self.num_tax_levels, dtype=tf.float32
         )
 
         self.loss_metrics = sorted(["loss", "target_loss", "count_mse"])
+        self.gradient_accumulator = GradientAccumulator(self.accumulation_steps)
+        self.loss_scaler = LossScaler(self.accumulation_steps)
 
     def evaluate_metric(self, dataset, metric, **kwargs):
         metric_index = self.loss_metrics.index(metric)
@@ -91,15 +98,22 @@ class TaxonomyEncoder(tf.keras.Model):
     def _compute_nuc_loss(self, nuc_tokens, nuc_pred):
         return self.base_encoder._compute_nuc_loss(nuc_tokens, nuc_pred)
 
-    @masked_loss(sparse_cat=True)
-    def _compute_tax_loss(
+    def _compute_encoder_loss(
         self,
         tax_tokens: tf.Tensor,
         tax_pred: tf.Tensor,
         sample_weights: Optional[tf.Tensor] = None,
     ) -> tf.Tensor:
-        loss = self.tax_loss(tax_tokens, tax_pred)
+        y_true = tf.reshape(tax_tokens, [-1])
+        y_pred = tf.reshape(tax_pred, [-1, self.num_tax_levels])
+        y_pred = tf.keras.activations.softmax(y_pred, axis=-1)
 
+        mask = float_mask(y_true) > 0
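+        # keep only positions with a real label; label 0 was previously skipped via ignore_class=0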
+        y_true = tf.one_hot(y_true, depth=self.num_tax_levels)
+        y_true = y_true[mask]
+        y_pred = y_pred[mask]
+
+        loss = tf.reduce_mean(self.tax_loss(y_true, y_pred))
         return loss
 
     def _compute_loss(
@@ -111,14 +125,9 @@ class TaxonomyEncoder(tf.keras.Model):
     ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
         nuc_tokens, counts = model_inputs
         taxonomy_embeddings, tax_pred, nuc_pred = outputs
-        tax_loss = self._compute_tax_loss(
-            y_true, tax_pred, sample_weights=sample_weights
-        )
-
-        nuc_loss = self._compute_nuc_loss(nuc_tokens, nuc_pred)
-
+        tax_loss = self._compute_encoder_loss(y_true, tax_pred)
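+        # nucleotide loss is zeroed out below, so only the taxonomy loss drives training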
+        nuc_loss = self._compute_nuc_loss(nuc_tokens, nuc_pred) * 0.0
         loss = tax_loss + nuc_loss
-
         return (loss, tax_loss, nuc_loss)
 
     def predict_step(
@@ -140,22 +149,32 @@ class TaxonomyEncoder(tf.keras.Model):
             tuple[tuple[tf.Tensor, tf.Tensor], tuple[tf.Tensor, tf.Tensor]],
         ],
     ):
-        inputs, y = data
+        if not self.gradient_accumulator.built:
+            self.gradient_accumulator.build(self.optimizer, self)
 
+        inputs, y = data
+        y_target, tax_target = y
         with tf.GradientTape() as tape:
             outputs = self(inputs, training=True)
-            loss, tax_loss, nuc_loss = self._compute_loss(inputs, y, outputs)
-
-        gradients = tape.gradient(loss, self.trainable_variables)
-        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
+            loss, tax_loss, nuc_loss = self._compute_loss(inputs, tax_target, outputs)
+            scaled_losses = self.loss_scaler([tax_loss])
+            loss = tf.reduce_sum(tf.stack(scaled_losses, axis=0))
+
+        gradients = tape.gradient(
+            loss,
+            self.trainable_variables,
+            unconnected_gradients=tf.UnconnectedGradients.ZERO,
+        )
+        self.gradient_accumulator.apply_gradients(gradients)
 
         self.loss_tracker.update_state(loss)
-        self.tax_tracker.update_state(tax_loss)
+        self.encoder_tracker.update_state(tax_loss)
         self.base_encoder.nuc_entropy.update_state(nuc_loss)
         return {
             "loss": self.loss_tracker.result(),
-            "tax_entropy": self.tax_tracker.result(),
+            "tax_entropy": self.encoder_tracker.result(),
             "nuc_entropy": self.base_encoder.nuc_entropy.result(),
+            "learning_rate": self.optimizer.learning_rate,
         }
 
     def test_step(
@@ -166,17 +185,18 @@ class TaxonomyEncoder(tf.keras.Model):
         ],
     ):
         inputs, y = data
-
+        y_target, tax_target = y
         outputs = self(inputs, training=False)
-        loss, tax_loss, nuc_loss = self._compute_loss(inputs, y, outputs)
+        loss, tax_loss, nuc_loss = self._compute_loss(inputs, tax_target, outputs)
 
         self.loss_tracker.update_state(loss)
-        self.tax_tracker.update_state(tax_loss)
+        self.encoder_tracker.update_state(tax_loss)
         self.base_encoder.nuc_entropy.update_state(nuc_loss)
         return {
             "loss": self.loss_tracker.result(),
-            "tax_entropy": self.tax_tracker.result(),
+            "tax_entropy": self.encoder_tracker.result(),
             "nuc_entropy": self.base_encoder.nuc_entropy.result(),
+            "learning_rate": self.optimizer.learning_rate,
         }
 
     def call(
@@ -187,23 +207,100 @@ class TaxonomyEncoder(tf.keras.Model):
         tokens = tf.cast(tokens, dtype=tf.int32)
         counts = tf.cast(counts, dtype=tf.int32)
 
-        sample_embeddings, nuc_embeddings = self.base_encoder(tokens, training=training)
+        # account for <SAMPLE> token
+        count_mask = float_mask(counts, dtype=tf.int32)
+        random_mask = None
+        if training and self.asv_dropout_rate > 0:
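+            # randomly mask out a fraction of ASV positions during training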
+            random_mask = apply_random_mask(count_mask, self.asv_dropout_rate)
+
+        sample_embeddings, nuc_embeddings = self.base_encoder(
+            tokens, random_mask=random_mask, training=training
+        )
+
+        if self.add_token:
+            count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+        count_attention_mask = count_mask
+
+        tax_gated_embeddings = self.tax_encoder(
+            sample_embeddings, mask=count_attention_mask, training=training
+        )
+        tax_pred = tax_gated_embeddings
+        if self.add_token:
+            tax_pred = tax_pred[:, 1:, :]
+        tax_pred = self.tax_level_logits(tax_pred)
+
+        tax_embeddings = sample_embeddings + tax_gated_embeddings
+        return [tax_embeddings, tax_pred, nuc_embeddings]
+
+    def base_embeddings(
+        self, inputs: tuple[tf.Tensor, tf.Tensor], training: bool = False
+    ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        # Keras casts all inputs to float, so we need to manually cast them back to the expected types
+        tokens, counts = inputs
+        tokens = tf.cast(tokens, dtype=tf.int32)
+        counts = tf.cast(counts, dtype=tf.int32)
+
+        sample_embeddings, nuc_embeddings, asv_embeddings = (
+            self.base_encoder.base_embeddings(tokens, training=training)
+        )
 
         # account for <SAMPLE> token
         count_mask = float_mask(counts, dtype=tf.int32)
         count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
         count_attention_mask = count_mask
 
-        tax_gated_embeddings = self.tax_encoder(
+        tax_gated_embeddings = self.tax_encoder(
             sample_embeddings, mask=count_attention_mask, training=training
         )
+        tax_pred = tax_gated_embeddings[:, 1:, :]
+        tax_pred = self.tax_level_logits(tax_pred)
+
+        tax_embeddings = sample_embeddings + tax_gated_embeddings
+
+        return [tax_embeddings, tax_pred, nuc_embeddings, asv_embeddings]
+
+    def asv_embeddings(
+        self, inputs: tuple[tf.Tensor, tf.Tensor], training: bool = False
+    ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        # Keras casts all inputs to float, so we need to manually cast them back to the expected types
+        tokens, counts = inputs
+        tokens = tf.cast(tokens, dtype=tf.int32)
+        counts = tf.cast(counts, dtype=tf.int32)
+
+        asv_embeddings = self.base_encoder.get_asv_embeddings(tokens, training=training)
 
-        tax_pred = tax_gated_embeddings[:, 1:, :]
+        return asv_embeddings
+
+    def asv_gradient(
+        self, inputs: tuple[tf.Tensor, tf.Tensor], asv_embeddings
+    ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        # Keras casts all inputs to float, so we need to manually cast them back to the expected types
+        tokens, counts = inputs
+        tokens = tf.cast(tokens, dtype=tf.int32)
+        counts = tf.cast(counts, dtype=tf.int32)
+
+        sample_embeddings = self.base_encoder.asv_gradient(tokens, asv_embeddings)
+
+        # account for <SAMPLE> token
+        count_mask = float_mask(counts, dtype=tf.int32)
+        if self.add_token:
+            count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+        count_attention_mask = count_mask
+
+        tax_gated_embeddings = self.tax_encoder(
+            sample_embeddings, mask=count_attention_mask, training=False
+        )
+        tax_pred = tax_gated_embeddings
+        if self.add_token:
+            tax_pred = tax_pred[:, 1:, :]
         tax_pred = self.tax_level_logits(tax_pred)
 
-        tax_embeddings = sample_embeddings + tax_gated_embeddings * self._tax_alpha
+        # tax_embeddings = sample_embeddings + tax_gated_embeddings
+        tax_embeddings = tax_gated_embeddings
 
-        return [tax_embeddings, tax_pred, nuc_embeddings]
+        return tax_embeddings
 
     def get_config(self):
         config = super(TaxonomyEncoder, self).get_config()
@@ -218,6 +315,11 @@ class TaxonomyEncoder(tf.keras.Model):
                 "intermediate_size": self.intermediate_size,
                 "intermediate_activation": self.intermediate_activation,
                 "max_bp": self.max_bp,
+                "is_16S": self.is_16S,
+                "vocab_size": self.vocab_size,
+                "add_token": self.add_token,
+                "asv_dropout_rate": self.asv_dropout_rate,
+                "accumulation_steps": self.accumulation_steps,
             }
         )
         return config
diff --git a/aam/models/transformers.py b/aam/models/transformers.py
index a48abf7402b123863c041cd44ac7bd26f7d3ddb2..29537684dcb81fbcc2e10ae33618b3948e43a020 100644
--- a/aam/models/transformers.py
+++ b/aam/models/transformers.py
@@ -14,6 +14,8 @@ class TransformerEncoder(tf.keras.layers.Layer):
         use_bias=False,
         norm_first=True,
         norm_epsilon=1e-6,
+        use_layer_norm=True,
+        share_rezero=True,
         **kwargs,
     ):
         super(TransformerEncoder, self).__init__(**kwargs)
@@ -39,10 +41,12 @@ class TransformerEncoder(tf.keras.layers.Layer):
                     inner_activation=self._activation,
                     dropout_rate=self._dropout_rate,
                     attention_dropout_rate=self._dropout_rate,
+                    use_layer_norm=False,
+                    share_rezero=True,
                     name=("layer_%d" % i),
                 )
             )
-        self.output_normalization = tf.keras.layers.LayerNormalization(epsilon=1e-6)
+        self.output_normalization = tf.keras.layers.LayerNormalization()
         super(TransformerEncoder, self).build(input_shape)
 
     def get_config(self):
@@ -76,10 +80,10 @@ class TransformerEncoder(tf.keras.layers.Layer):
         attention_mask = mask
         if attention_mask is not None:
             attention_mask = tf.matmul(attention_mask, attention_mask, transpose_b=True)
-
         for layer_idx in range(self.num_layers):
             encoder_inputs = self.encoder_layers[layer_idx](
                 [encoder_inputs, attention_mask], training=training
             )
-        output_tensor = encoder_inputs
-        return output_tensor
+        # output_tensor = self.output_normalization(encoder_inputs)
+        return encoder_inputs
diff --git a/aam/models/unifrac_encoder.py b/aam/models/unifrac_encoder.py
index dad09195e1b81b46f1d3e02bcf14d0353f92c2fc..94b0b8c082155c35cef944e83ddf9db97b1de8cf 100644
--- a/aam/models/unifrac_encoder.py
+++ b/aam/models/unifrac_encoder.py
@@ -7,7 +7,9 @@ import tensorflow as tf
 from aam.losses import PairwiseLoss
 from aam.models.base_sequence_encoder import BaseSequenceEncoder
 from aam.models.transformers import TransformerEncoder
-from aam.utils import float_mask
+from aam.optimizers.gradient_accumulator import GradientAccumulator
+from aam.optimizers.loss_scaler import LossScaler
+from aam.utils import apply_random_mask, float_mask
 
 
 @tf.keras.saving.register_keras_serializable(package="UniFracEncoder")
@@ -22,6 +24,13 @@ class UniFracEncoder(tf.keras.Model):
         intermediate_size: int = 1024,
         intermediate_activation: str = "gelu",
         max_bp: int = 150,
+        include_alpha: bool = True,
+        is_16S: bool = True,
+        vocab_size: int = 6,
+        add_token: bool = True,
+        asv_dropout_rate: float = 0.0,
+        accumulation_steps: int = 1,
+        unifrac_metric: str = "faith_pd",
         **kwargs,
     ):
         super(UniFracEncoder, self).__init__(**kwargs)
@@ -34,10 +43,22 @@ class UniFracEncoder(tf.keras.Model):
         self.intermediate_size = intermediate_size
         self.intermediate_activation = intermediate_activation
         self.max_bp = max_bp
+        self.include_alpha = include_alpha
+        self.is_16S = is_16S
+        self.vocab_size = vocab_size
+        self.add_token = add_token
+        self.asv_dropout_rate = asv_dropout_rate
+        self.accumulation_steps = accumulation_steps
+        self.unifrac_metric = unifrac_metric
 
         self.loss_tracker = tf.keras.metrics.Mean()
-        self.unifrac_loss = PairwiseLoss()
-        self.unifrac_tracker = tf.keras.metrics.Mean()
+        if self.unifrac_metric == "unifrac":
+            self.unifrac_loss = PairwiseLoss()
+            self.unifrac_out_dim = self.embedding_dim
+        else:
+            self.unifrac_loss = tf.keras.losses.MeanSquaredError(reduction="none")
+            self.unifrac_out_dim = 1
+        self.encoder_tracker = tf.keras.metrics.Mean()
 
         # layers used in model
         self.base_encoder = BaseSequenceEncoder(
@@ -52,16 +73,12 @@ class UniFracEncoder(tf.keras.Model):
             nuc_attention_layers=3,
             nuc_intermediate_size=128,
             intermediate_activation=self.intermediate_activation,
+            is_16S=self.is_16S,
+            vocab_size=self.vocab_size,
+            add_token=self.add_token,
             name="base_encoder",
         )
 
-        self._unifrac_alpha = self.add_weight(
-            name="unifrac_alpha",
-            initializer=tf.keras.initializers.Zeros(),
-            trainable=True,
-            dtype=tf.float32,
-        )
-
         self.unifrac_encoder = TransformerEncoder(
             num_layers=self.attention_layers,
             num_attention_heads=self.attention_heads,
@@ -70,12 +87,11 @@ class UniFracEncoder(tf.keras.Model):
             activation=self.intermediate_activation,
             name="unifrac_encoder",
         )
-
-        self.unifrac_ff = tf.keras.layers.Dense(
-            self.embedding_dim, use_bias=False, dtype=tf.float32, name="unifrac_ff"
-        )
+        self.unifrac_ff = tf.keras.layers.Dense(self.unifrac_out_dim, dtype=tf.float32)
 
         self.loss_metrics = sorted(["loss", "target_loss", "count_mse"])
+        self.gradient_accumulator = GradientAccumulator(self.accumulation_steps)
+        self.loss_scaler = LossScaler(self.accumulation_steps)
 
     def evaluate_metric(self, dataset, metric, **kwargs):
         metric_index = self.loss_metrics.index(metric)
@@ -85,14 +101,16 @@ class UniFracEncoder(tf.keras.Model):
     def _compute_nuc_loss(self, nuc_tokens, nuc_pred):
         return self.base_encoder._compute_nuc_loss(nuc_tokens, nuc_pred)
 
-    def _compute_unifrac_loss(
+    def _compute_encoder_loss(
         self,
         y_true: tf.Tensor,
         unifrac_embeddings: tf.Tensor,
     ) -> tf.Tensor:
         loss = self.unifrac_loss(y_true, unifrac_embeddings)
-        num_samples = tf.reduce_sum(float_mask(loss))
-        return tf.math.divide_no_nan(tf.reduce_sum(loss), num_samples)
+        if self.unifrac_metric == "unifrac":
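+            # normalize each sample's summed pairwise error by the number of other samples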
+            batch = tf.cast(tf.shape(y_true)[0], dtype=tf.float32)
+            loss = tf.reduce_sum(loss, axis=1, keepdims=True) / (batch - 1)
+        return tf.reduce_mean(loss)
 
     def _compute_loss(
         self,
@@ -102,10 +120,10 @@ class UniFracEncoder(tf.keras.Model):
     ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
         nuc_tokens, counts = model_inputs
         embeddings, unifrac_embeddings, nuc_pred = outputs
-        tax_loss = self._compute_unifrac_loss(y_true, unifrac_embeddings)
-        nuc_loss = self._compute_nuc_loss(nuc_tokens, nuc_pred)
-        loss = tax_loss + nuc_loss
-        return [loss, tax_loss, nuc_loss]
+        unifrac_loss = self._compute_encoder_loss(y_true, unifrac_embeddings)
+        # nuc_loss = self._compute_nuc_loss(nuc_tokens, nuc_pred) * 0.0
+        loss = unifrac_loss  # + nuc_loss
+        return [loss, unifrac_loss, 0]  # , nuc_loss]
 
     def predict_step(
         self,
@@ -114,10 +132,10 @@ class UniFracEncoder(tf.keras.Model):
             tuple[tuple[tf.Tensor, tf.Tensor], tuple[tf.Tensor, tf.Tensor]],
         ],
     ):
-        inputs, sample_ids = data
-        embeddings, _, _ = self(inputs, training=False)
+        inputs, y = data
+        embeddings, unifrac_embeddings, nuc_pred = self(inputs, training=False)
 
-        return embeddings, sample_ids
+        return unifrac_embeddings
 
     def train_step(
         self,
@@ -126,22 +144,34 @@ class UniFracEncoder(tf.keras.Model):
             tuple[tuple[tf.Tensor, tf.Tensor], tuple[tf.Tensor, tf.Tensor]],
         ],
     ):
-        inputs, y = data
+        if not self.gradient_accumulator.built:
+            self.gradient_accumulator.build(self.optimizer, self)
 
+        inputs, y = data
+        y_target, unifrac_target = y
         with tf.GradientTape() as tape:
             outputs = self(inputs, training=True)
-            loss, tax_loss, nuc_loss = self._compute_loss(inputs, y, outputs)
-
-        gradients = tape.gradient(loss, self.trainable_variables)
-        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
+            loss, unifrac_loss, nuc_loss = self._compute_loss(
+                inputs, unifrac_target, outputs
+            )
+            scaled_losses = self.loss_scaler([unifrac_loss])
+            loss = tf.reduce_sum(tf.stack(scaled_losses, axis=0))
+
+        gradients = tape.gradient(
+            loss,
+            self.trainable_variables,
+            unconnected_gradients=tf.UnconnectedGradients.ZERO,
+        )
+        self.gradient_accumulator.apply_gradients(gradients)
 
         self.loss_tracker.update_state(loss)
-        self.unifrac_tracker.update_state(tax_loss)
+        self.encoder_tracker.update_state(unifrac_loss)
         self.base_encoder.nuc_entropy.update_state(nuc_loss)
         return {
             "loss": self.loss_tracker.result(),
-            "unifrac_mse": self.unifrac_tracker.result(),
+            "unifrac_mse": self.encoder_tracker.result(),
             "nuc_entropy": self.base_encoder.nuc_entropy.result(),
+            "learning_rate": self.optimizer.learning_rate,
         }
 
     def test_step(
@@ -152,17 +182,20 @@ class UniFracEncoder(tf.keras.Model):
         ],
     ):
         inputs, y = data
-
+        y_target, unifrac_target = y
         outputs = self(inputs, training=False)
-        loss, tax_loss, nuc_loss = self._compute_loss(inputs, y, outputs)
+        loss, unifrac_loss, nuc_loss = self._compute_loss(
+            inputs, unifrac_target, outputs
+        )
 
         self.loss_tracker.update_state(loss)
-        self.unifrac_tracker.update_state(tax_loss)
+        self.encoder_tracker.update_state(unifrac_loss)
         self.base_encoder.nuc_entropy.update_state(nuc_loss)
         return {
             "loss": self.loss_tracker.result(),
-            "unifrac_mse": self.unifrac_tracker.result(),
+            "unifrac_mse": self.encoder_tracker.result(),
             "nuc_entropy": self.base_encoder.nuc_entropy.result(),
+            "learning_rate": self.optimizer.learning_rate,
         }
 
     def call(
@@ -173,7 +206,45 @@ class UniFracEncoder(tf.keras.Model):
         tokens = tf.cast(tokens, dtype=tf.int32)
         counts = tf.cast(counts, dtype=tf.int32)
 
-        sample_embeddings, nuc_embeddings = self.base_encoder(tokens, training=training)
+        # account for <SAMPLE> token
+        count_mask = float_mask(counts, dtype=tf.int32)
+        random_mask = None
+        if training and self.asv_dropout_rate > 0:
+            random_mask = apply_random_mask(count_mask, self.asv_dropout_rate)
+
+        sample_embeddings, nuc_embeddings = self.base_encoder(
+            tokens, random_mask=random_mask, training=training
+        )
+
+        if self.add_token:
+            count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+        count_attention_mask = count_mask
+
+        unifrac_gated_embeddings = self.unifrac_encoder(
+            sample_embeddings, mask=count_attention_mask, training=training
+        )
+
+        if self.add_token:
+            unifrac_pred = unifrac_gated_embeddings[:, 0, :]
+        else:
+            mask = tf.cast(count_mask, dtype=tf.float32)
+            unifrac_pred = tf.reduce_sum(unifrac_gated_embeddings * mask, axis=1)
+            unifrac_pred /= tf.reduce_sum(mask, axis=1)
+
+        unifrac_pred = self.unifrac_ff(unifrac_pred)
+        unifrac_embeddings = sample_embeddings + unifrac_gated_embeddings
+        # unifrac_embeddings = unifrac_gated_embeddings
+        return [unifrac_embeddings, unifrac_pred, nuc_embeddings]
+
+    def base_embeddings(
+        self, inputs: tuple[tf.Tensor, tf.Tensor]
+    ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        # Keras casts all inputs to float, so we need to manually cast them back to the expected types
+        tokens, counts = inputs
+        tokens = tf.cast(tokens, dtype=tf.int32)
+        counts = tf.cast(counts, dtype=tf.int32)
+
+        sample_embeddings = self.base_encoder.base_embeddings(tokens)
 
         # account for <SAMPLE> token
         count_mask = float_mask(counts, dtype=tf.int32)
@@ -181,7 +252,7 @@ class UniFracEncoder(tf.keras.Model):
         count_attention_mask = count_mask
 
         unifrac_gated_embeddings = self.unifrac_encoder(
-            sample_embeddings, mask=count_attention_mask, training=training
+            sample_embeddings, mask=count_attention_mask
         )
         unifrac_pred = unifrac_gated_embeddings[:, 0, :]
         unifrac_pred = self.unifrac_ff(unifrac_pred)
@@ -190,7 +261,46 @@ class UniFracEncoder(tf.keras.Model):
             sample_embeddings + unifrac_gated_embeddings * self._unifrac_alpha
         )
 
-        return [unifrac_embeddings, unifrac_pred, nuc_embeddings]
+        return unifrac_embeddings
+
+    def asv_embeddings(
+        self, inputs: tuple[tf.Tensor, tf.Tensor], training: bool = False
+    ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        # Keras casts all inputs to float, so we need to manually cast them back to the expected types
+        tokens, counts = inputs
+        tokens = tf.cast(tokens, dtype=tf.int32)
+        counts = tf.cast(counts, dtype=tf.int32)
+
+        asv_embeddings = self.base_encoder.asv_embeddings(tokens, training=training)
+
+        return asv_embeddings
+
+    def asv_gradient(
+        self, inputs: tuple[tf.Tensor, tf.Tensor], asv_embeddings
+    ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        # Keras casts all inputs to float, so we need to manually cast them back to the expected types
+        tokens, counts = inputs
+        tokens = tf.cast(tokens, dtype=tf.int32)
+        counts = tf.cast(counts, dtype=tf.int32)
+
+        sample_embeddings = self.base_encoder.asv_gradient(tokens, asv_embeddings)
+
+        # account for <SAMPLE> token
+        count_mask = float_mask(counts, dtype=tf.int32)
+        count_mask = tf.pad(count_mask, [[0, 0], [1, 0], [0, 0]], constant_values=1)
+        count_attention_mask = count_mask
+
+        unifrac_gated_embeddings = self.unifrac_encoder(
+            sample_embeddings, mask=count_attention_mask
+        )
+        unifrac_pred = unifrac_gated_embeddings[:, 0, :]
+        unifrac_pred = self.unifrac_ff(unifrac_pred)
+
+        unifrac_embeddings = sample_embeddings + unifrac_gated_embeddings
+
+        return unifrac_embeddings
 
     def get_config(self):
         config = super(UniFracEncoder, self).get_config()
@@ -204,6 +314,13 @@ class UniFracEncoder(tf.keras.Model):
                 "intermediate_size": self.intermediate_size,
                 "intermediate_activation": self.intermediate_activation,
                 "max_bp": self.max_bp,
+                "include_alpha": self.include_alpha,
+                "is_16S": self.is_16S,
+                "vocab_size": self.vocab_size,
+                "add_token": self.add_token,
+                "asv_dropout_rate": self.asv_dropout_rate,
+                "accumulation_steps": self.accumulation_steps,
+                "unifrac_metric": self.unifrac_metric,
             }
         )
         return config
diff --git a/aam/models/utils.py b/aam/models/utils.py
index 324d31d8dc4b1a1deb062d82dec356193884ff6e..65bb17f42f263bb552f6d90c36ed8f00af52393e 100644
--- a/aam/models/utils.py
+++ b/aam/models/utils.py
@@ -5,12 +5,9 @@ import tensorflow as tf
 class TransformerLearningRateSchedule(
     tf.keras.optimizers.schedules.LearningRateSchedule
 ):
-    def __init__(
-        self, d_model, warmup_steps=100, decay_method="cosine", initial_lr=3e-4
-    ):
+    def __init__(self, warmup_steps=100, decay_method="cosine", initial_lr=3e-4):
         super(TransformerLearningRateSchedule, self).__init__()
 
-        self.d_model = d_model
         self.warmup_steps = warmup_steps
         self.decay_method = decay_method
         self.initial_lr = initial_lr
@@ -24,12 +21,9 @@ class TransformerLearningRateSchedule(
 
         if self.decay_method == "cosine":
             # Cosine decay after warmup
-            cosine_decay = tf.keras.optimizers.schedules.CosineDecayRestarts(
+            cosine_decay = tf.keras.optimizers.schedules.CosineDecay(
                 initial_learning_rate=self.initial_lr,
-                first_decay_steps=1000,  # Change according to your training steps
-                t_mul=2.0,  # How quickly to increase the restart periods
-                m_mul=0.9,  # Multiplier for reducing max learning rate after each restart
-                alpha=0.0,  # Minimum learning rate
+                decay_steps=self.warmup_steps,
             )
             learning_rate = tf.cond(
                 step < self.warmup_steps,
@@ -51,7 +45,6 @@ class TransformerLearningRateSchedule(
         config = {}
         config.update(
             {
-                "d_model": self.d_model,
                 "warmup_steps": self.warmup_steps,
                 "decay_method": self.decay_method,
                 "initial_lr": self.initial_lr,
@@ -61,9 +54,15 @@ class TransformerLearningRateSchedule(
 
 
 def cos_decay_with_warmup(lr, warmup_steps=5000):
-    # Learning rate schedule: Warmup followed by cosine decay
-    lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
-        initial_learning_rate=lr, first_decay_steps=warmup_steps
+    # # Learning rate schedule: Warmup followed by cosine decay
+    # lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
+    #     initial_learning_rate=lr, first_decay_steps=warmup_steps
+    # )
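+    # linear warmup from 0 to lr over warmup_steps, then cosine decay over the next 4000 steps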
+    lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
+        initial_learning_rate=0.0,
+        decay_steps=4000,
+        warmup_target=lr,
+        warmup_steps=warmup_steps,
     )
     return lr_schedule
 
@@ -73,19 +72,23 @@ if __name__ == "__main__":
     import numpy as np
     import tensorflow as tf
 
-    # Learning rate schedule: Warmup followed by cosine decay
-    lr_schedule = cos_decay_with_warmup()
+    cosine_decay = tf.keras.optimizers.schedules.CosineDecay(
+        initial_learning_rate=0.0,
+        decay_steps=4000,
+        warmup_target=0.0003,
+        warmup_steps=100,
+    )
 
-    # Compute learning rates for each step
-    steps = np.arange(10000)
-    learning_rates = [lr_schedule(step).numpy() for step in steps]
+    # Generate learning rates for a range of steps
+    steps = np.arange(4000, dtype=np.float32)
+    learning_rates = [cosine_decay(step).numpy() for step in steps]
 
     # Plot the learning rate schedule
     plt.figure(figsize=(10, 6))
     plt.plot(steps, learning_rates)
-    plt.title("Learning Rate Schedule: Warmup + Cosine Decay")
-    plt.xlabel("Training Steps")
+    plt.xlabel("Training Step")
     plt.ylabel("Learning Rate")
+    plt.title("Transformer Learning Rate Schedule")
     plt.grid(True)
-    plt.savefig("test.png")
-    plt.close()
+    plt.show()
diff --git a/aam/optimizers/gradient_accumulator.py b/aam/optimizers/gradient_accumulator.py
new file mode 100644
index 0000000000000000000000000000000000000000..943f1bf68032038aa6d0ec5343a2655546b2d43b
--- /dev/null
+++ b/aam/optimizers/gradient_accumulator.py
@@ -0,0 +1,92 @@
+import tensorflow as tf
+
+
+class GradientAccumulator:
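+    """Accumulate gradients for `accumulation_steps` batches, then apply the summed
+    gradients once through the wrapped optimizer and reset the accumulators."""
+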
+    def __init__(self, accumulation_steps):
+        self.accum_steps = tf.constant(
+            accumulation_steps, dtype=tf.int32, name="accum_steps"
+        )
+
+        self.accum_step_counter = tf.Variable(
+            0,
+            dtype=tf.int32,
+            trainable=False,
+            name="accum_counter",
+        )
+        # self.log_steps = tf.constant(100, dtype=tf.int32, name="accum_steps")
+        # self.log_step_counter = tf.Variable(
+        #     0,
+        #     dtype=tf.int32,
+        #     trainable=False,
+        #     name="accum_counter",
+        # )
+        self.built = False
+        # self.file_writer = tf.summary.create_file_writer(
+        #     "/home/kalen/aam-research-exam/research-exam/healty-age-regression/results/logs/gradients"
+        # )
+
+    def build(self, optimizer, model):
+        self.built = True
+        self.optimizer = optimizer
+        self.model = model
+
+        # reinitialize gradient accumulator
+        self.gradient_accumulation = [
+            tf.Variable(
+                tf.zeros_like(v, dtype=tf.float32),
+                trainable=False,
+                name="accum_" + str(i),
+                synchronization=tf.VariableSynchronization.ON_READ,
+                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
+            )
+            for i, v in enumerate(self.model.trainable_variables)
+        ]
+
+    def apply_gradients(self, gradients):
+        self.accum_step_counter.assign_add(1)
+
+        # Accumulate batch gradients
+        for i in range(len(self.gradient_accumulation)):
+            self.gradient_accumulation[i].assign_add(
+                gradients[i],
+                read_value=False,
+            )
+
+        # apply accumulated gradients
+        tf.cond(
+            tf.equal(self.accum_step_counter, self.accum_steps),
+            true_fn=self.apply_accu_gradients,
+            false_fn=lambda: None,
+        )
+
+    def apply_accu_gradients(self):
+        """Performs gradient update and resets slots afterwards."""
+        self.optimizer.apply_gradients(
+            zip(self.gradient_accumulation, self.model.trainable_variables)
+        )
+
+        # self.log_step_counter.assign_add(1)
+        # tf.cond(
+        #     tf.equal(self.log_step_counter, self.log_steps),
+        #     true_fn=self.log_gradients,
+        #     false_fn=lambda: None,
+        # )
+
+        # reset
+        self.accum_step_counter.assign(0)
+        for i in range(len(self.gradient_accumulation)):
+            self.gradient_accumulation[i].assign(
+                tf.zeros_like(self.model.trainable_variables[i], dtype=tf.float32),
+                read_value=False,
+            )
+
+    # def log_gradients(self):
+    #     self.log_step_counter.assign(0)
+    #     with self.file_writer.as_default():
+    #         for grad, var in zip(
+    #             self.gradient_accumulation, self.model.trainable_variables
+    #         ):
+    #             if grad is not None:
+    #                 tf.summary.histogram(
+    #                     var.name + "/gradient", grad, step=self.optimizer.iterations
+    #                 )
diff --git a/aam/optimizers/loss_scaler.py b/aam/optimizers/loss_scaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e6242e584bd8b2be17fdfb715256dce48095214
--- /dev/null
+++ b/aam/optimizers/loss_scaler.py
@@ -0,0 +1,48 @@
+import tensorflow as tf
+
+
+class LossScaler:
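+    """Divide each loss by an exponential moving average of its magnitude (times the
+    gradient-accumulation step count) so multiple task losses contribute on a
+    comparable scale."""
+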
+    def __init__(self, gradient_accum_steps):
+        self.gradient_accum_steps = tf.cast(gradient_accum_steps, dtype=tf.float32)
+        self.moving_avg = None
+        self.decay = 0.99
+        self.accum_loss = tf.Variable(
+            initial_value=1,
+            trainable=False,
+            dtype=tf.int32,
+            name="accum_loss",
+        )
+
+    def __call__(self, losses):
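+        # lazily create one moving-average (and scaled-loss) variable per loss on the first call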
+        if self.moving_avg is None:
+            self.scaled_loss = [
+                tf.Variable(
+                    initial_value=loss,
+                    trainable=False,
+                    dtype=tf.float32,
+                    name=f"scaled_loss_{i}",
+                )
+                for i, loss in enumerate(losses)
+            ]
+            self.moving_avg = [
+                tf.Variable(
+                    initial_value=loss,
+                    trainable=False,
+                    dtype=tf.float32,
+                    name=f"avg_loss_{i}",
+                )
+                for i, loss in enumerate(losses)
+            ]
+
+        for i in range(len(self.moving_avg)):
+            self.moving_avg[i].assign(
+                self.moving_avg[i] * self.decay + (1 - self.decay) * losses[i],
+                read_value=False,
+            )
+
+        return [
+            tf.math.divide_no_nan(
+                losses[i], self.moving_avg[i] * self.gradient_accum_steps
+            )
+            for i in range(len(self.moving_avg))
+        ]
diff --git a/aam/utils.py b/aam/utils.py
index 03fb8009c5243ebb543208b20c4fd3afe0363fd3..410a52343bc26bacfe7a7946497f2fdd76de04a4 100644
--- a/aam/utils.py
+++ b/aam/utils.py
@@ -49,7 +49,8 @@ def masked_loss(sparse_cat: bool = False):
 
             total = tf.cast(tf.reduce_sum(mask), dtype=tf.float32)
             loss = tf.reduce_sum(loss * mask)
-            return tf.math.divide_no_nan(loss, total)
+            # return tf.math.divide_no_nan(loss, total)
+            return loss
 
         return wrapper