# test_contrastive_vae_tuner.py (11.63 KiB) — authored by Ellison Scheuller
import torch ### DO NOT REMOVE THIS !!!!!
import os
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model, load_model
import qkeras
from qkeras import QBatchNormalization
from qkeras.qlayers import QDense, QActivation
from qkeras.quantizers import quantized_bits, quantized_relu
import tensorflow as tf
import h5py as h5
import numpy as np
from tensorflow import keras
from tqdm import tqdm
from model import *
from utilities import *
import wandb
import ray
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch
from ray import tune
import random
import gc
import argparse
def distance(model, data):
    """Per-sample anomaly score: MSE between input and its VAE reconstruction.

    Encodes ``data`` with the model's encoder, decodes the latent mean, and
    returns the element-wise mean-squared reconstruction error as a numpy
    array (one score per sample).
    """
    latent_mean, _ = model.encoder(data)
    reconstruction = model.decoder(latent_mean)
    mse = tf.keras.losses.mean_squared_error(data, reconstruction)
    return mse.numpy()
def run(config):
    """Run one Ray Tune trial: VICReg contrastive pre-training, then VAE training.

    Phase 1 trains a quantized VICReg backbone on two augmented views of the
    background data.  Phase 2 embeds the datasets with that backbone and
    trains a variational autoencoder on the embeddings, logging per-epoch
    metrics to Weights & Biases and reporting them to Ray Tune (consumed by
    the ASHA scheduler for early stopping).

    Parameters
    ----------
    config : dict
        Hyperparameter sample from the Tune search space.  Keys read here:
        'blur_p', 'blur_m', 'blur_s', 'mask_p', 'beta', 'vic_lr', 'vae_lr',
        'alpha', 'encoder_nodes', 'vae_nodes', 'vae_latent'.

    NOTE(review): relies on module-level globals ``ap_fixed_kernel``,
    ``ap_fixed_bias``, ``ap_fixed_act`` (set in the ``__main__`` block) and on
    names star-imported from ``model``/``utilities`` (FastFeatureBlur,
    FastObjectMask, FastLorentzRotation, ModelBackbone, ModelProjector,
    VICReg, VAE_Encoder, VAE_Decoder, VariationalAutoEncoder,
    cosine_with_warmup, cosine_annealing_warm_restart_with_warmup,
    fast_score) — presumably defined there; verify.
    """
    ###################################################################
    # Seed every RNG in play (Python, NumPy, Torch, TF) for reproducibility.
    seed = 123
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    tf.random.set_seed(seed)
    tf.config.experimental.enable_op_determinism()
    ####################################################################
    # Let TF grow GPU memory on demand instead of grabbing the whole card
    # (several trials may share one GPU).
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            # Memory growth must be set before the GPUs are initialized.
            print(e)
    # SECURITY NOTE(review): hard-coded W&B API key committed to source —
    # rotate this credential and load it from the environment instead.
    wandb.login(key="24d1d60ce26563c74d290d7b487cb104fc251271")
    wandb.init(project = "ContrastiveVAEQkeras",
               settings=wandb.Settings(_disable_stats=True),
               config = config)
    run_name = wandb.run.name  # NOTE(review): assigned but never used below
    # ---- hyperparameters sampled by Ray Tune ---------------------------------
    blur_p = config['blur_p']
    blur_m = config['blur_m']
    blur_s = config['blur_s']
    mask_p = config['mask_p']
    beta = config['beta']
    VIC_lr = config['vic_lr']
    VAE_lr = config['vae_lr']
    alpha = config['alpha']
    reco_scale = alpha * (1 - beta)  # NOTE(review): computed but unused in this function
    kl_scale = beta  # NOTE(review): computed but unused in this function
    device = 'cuda:0' # This will be set later
    Epochs_contrastive = 50 # This will be set later
    Epochs_VAE = 480
    Batch_size = 4096
    vic_encoder_nodes = config['encoder_nodes']
    projector_features = vic_encoder_nodes[-1] * 4  # VICReg projector width
    vae_encoder_nodes = config['vae_nodes']
    vae_latent_dim = config['vae_latent']
    # Making it symmetric: decoder mirrors [backbone output dim] + encoder nodes.
    vae_decoder_nodes = [vic_encoder_nodes[-1]] + vae_encoder_nodes.copy()
    vae_decoder_nodes.reverse()
    # --------------------------------------------------------------------------
    # Load pre-processed background data plus the normalisation constants the
    # Lorentz-rotation augmentation needs to work in physical coordinates.
    f = h5.File('/axovol/training/v5/embeddedData_jan25.h5', 'r')
    x_train = f['Background_data']['Train']['DATA'][:]
    x_test = f['Background_data']['Test']['DATA'][:]
    x_train_background = np.reshape(x_train, (x_train.shape[0], -1))
    x_test_background = np.reshape(x_test, (x_test.shape[0], -1))
    scale = f['Normalisation']['norm_scale'][:]
    bias = f['Normalisation']['norm_bias'][:]
    l1_bits_bkg_test = f['Background_data']['Test']['L1bits'][:]
    # --------------------------------------------------------------------------
    # Two independent augmentation pipelines -> the two VICReg views.
    feature_blur = FastFeatureBlur(p=blur_p, strength=blur_s, magnitude=blur_m, device=device)
    feature_blur_prime = FastFeatureBlur(p=blur_p, strength=blur_s, magnitude=blur_m, device=device)
    object_mask = FastObjectMask(p=mask_p, device=device)
    object_mask_prime = FastObjectMask(p=mask_p, device=device)
    lorentz_rot = FastLorentzRotation(p=0.5, norm_scale=scale, norm_bias=bias, device=device)
    lorentz_rot_prime = FastLorentzRotation(p=0.5, norm_scale=scale, norm_bias=bias, device=device)
    # --------------------------------------------------------------------------
    # Augmentations run in torch on the GPU; training itself is TF/QKeras.
    dataset = torch.tensor(x_train_background, dtype=torch.float32, device=device)
    dataset_test = torch.tensor(x_test_background, dtype=torch.float32, device=device)
    del x_train_background
    gc.collect()
    # --------------------------------------------------------------------------
    # Phase 1: quantized VICReg backbone + projector.
    # NOTE(review): ap_fixed_* come from module globals set in __main__.
    Backbone = ModelBackbone(nodes=vic_encoder_nodes,
                             ap_fixed_kernel = ap_fixed_kernel,
                             ap_fixed_bias= ap_fixed_bias,
                             ap_fixed_activation = ap_fixed_act)
    Projection = ModelProjector(projector_features)
    model = VICReg(backbone=Backbone, projector=Projection, num_features=projector_features, batch_size=Batch_size)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=VIC_lr))
    scheduler = cosine_with_warmup(
        max_lr=VIC_lr,
        warmup_epochs=10,
        decay_epochs=Epochs_contrastive - 10)
    for present_epoch in tqdm(range(0, Epochs_contrastive, 1)):
        train_loss = 0   # NOTE(review): never updated below
        train_steps = 0  # NOTE(review): never updated below
        index = torch.randperm(dataset.shape[0])  # fresh shuffle each epoch
        current_lr = scheduler.step()
        model.optimizer.learning_rate.assign(current_lr)
        for i in range(dataset.shape[0] // Batch_size):
            # Build the two augmented views in torch without autograd overhead.
            with torch.no_grad():
                batch = dataset[index[i * Batch_size : (i + 1) * Batch_size]]
                batch_x = batch.clone()
                batch_y = batch.clone()
                batch_x = feature_blur(batch_x)
                batch_x = object_mask(batch_x)
                batch_x = lorentz_rot(batch_x)
                batch_y = feature_blur_prime(batch_y)
                batch_y = object_mask_prime(batch_y)
                batch_y = lorentz_rot_prime(batch_y)
                # Hand the views to TF as numpy arrays.
                batch_x,batch_y = batch_x.cpu().numpy(),batch_y.cpu().numpy()
            metrics = model.train_step((batch_x, batch_y))
        epoch_loss = model.loss_tracker.result().numpy()
        epoch_repr = model.loss_tracker_repr.result().numpy()
        epoch_std = model.loss_tracker_std.result().numpy()
        epoch_cov = model.loss_tracker_cov.result().numpy()
        # Reset metrics for next epoch
        model.loss_tracker.reset_states()
        model.loss_tracker_repr.reset_states()
        model.loss_tracker_std.reset_states()
        model.loss_tracker_cov.reset_states()
        metric_embed = {}
        metric_embed['TrainLossC'] = epoch_loss
        metric_embed['EpochC'] = present_epoch
        metric_embed['LrC'] = current_lr
        wandb.log(metric_embed)
    #########################################################################################################################
    # Phase 2: train a quantized VAE on the frozen backbone's embeddings.
    #########################################################################################################################
    vic_encoder = model.backbone
    encoder = VAE_Encoder(nodes=vae_encoder_nodes,
                          feature_size=vae_latent_dim,
                          ap_fixed_kernel = ap_fixed_kernel,
                          ap_fixed_bias= ap_fixed_bias,
                          ap_fixed_activation = ap_fixed_act)
    decoder = VAE_Decoder(nodes=vae_decoder_nodes,
                          ap_fixed_kernel = ap_fixed_kernel,
                          ap_fixed_bias= ap_fixed_bias,
                          ap_fixed_activation = ap_fixed_act)
    model = VariationalAutoEncoder(encoder=encoder, decoder=decoder)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=VAE_lr))
    scheduler = cosine_annealing_warm_restart_with_warmup(first_cycle_steps = 32,
                                                          cycle_mult = 2,
                                                          max_lr=VAE_lr,
                                                          warmup_epochs=10,
                                                          gamma=0.5)
    # Embed train/test background once; the backbone is not updated further.
    dataset_latent = vic_encoder(dataset.cpu().numpy()).numpy()
    dataset_latent_test = vic_encoder(dataset_test.cpu().numpy()).numpy()
    #############################################
    # Signal Data: embed every signal sample and keep its L1 trigger bits
    # for the fast_score evaluation below.
    #############################################
    SIGNAL_NAMES = list(f['Signal_data'].keys())
    signal_data_dict = {}
    signal_l1_dict = {}
    for signal_name in SIGNAL_NAMES:
        x_signal = f['Signal_data'][signal_name]['DATA'][:]
        x_signal = np.reshape(x_signal, (x_signal.shape[0], -1))
        x_signal = vic_encoder(x_signal).numpy()
        l1_bits = f['Signal_data'][signal_name]['L1bits'][:]
        signal_data_dict[signal_name] = x_signal
        signal_l1_dict[signal_name] = l1_bits
    f.close()
    for present_epoch in tqdm(range(0, Epochs_VAE, 1)):
        train_loss = 0   # NOTE(review): never updated below
        train_steps = 0  # NOTE(review): never updated below
        index = torch.randperm(dataset_latent.shape[0]).numpy()
        current_lr = scheduler.step()
        model.optimizer.learning_rate.assign(current_lr)
        for i in range(dataset_latent.shape[0] // Batch_size):
            batch = dataset_latent[index[i * Batch_size : (i + 1) * Batch_size]]
            model.train_step(batch)
        # Evaluate signal efficiencies every epoch using the MSE-based score.
        metric = fast_score(
            model=model,
            data_bkg=dataset_latent_test,
            bkg_l1_bits=l1_bits_bkg_test,
            distance_func=distance,
            data_signal=signal_data_dict,
            signal_l1_bits=signal_l1_dict,
            evaluation_threshold=1,
        )
        total_loss = model.total_loss_tracker.result().numpy()
        reco_loss = model.reconstruction_loss_tracker.result().numpy()
        kl_loss = model.kl_loss_tracker.result().numpy()
        # Reset metrics for next epoch
        model.total_loss_tracker.reset_states()
        model.reconstruction_loss_tracker.reset_states()
        model.kl_loss_tracker.reset_states()
        metric['EpochVae'] = present_epoch
        metric['LrVae'] = current_lr
        metric['TotalLossVae'] = total_loss
        metric['RecoLossVae'] = reco_loss
        metric['KLLossVae'] = kl_loss
        wandb.log(metric)
        # Report to Ray Tune so ASHA can early-stop underperforming trials.
        ray.train.report(metrics=metric)
    # NOTE(review): trailing scheduler step after the final epoch — the LR is
    # re-assigned at the top of each epoch, so this looks like a leftover with
    # no effect; confirm before removing.
    scheduler.step()
if __name__ == '__main__':
    # Fixed-point quantization settings [total_bits, integer_bits].
    # run() reads these as module-level globals.  ### To be further tuned !!!!
    ap_fixed_kernel = [6, 2]
    ap_fixed_bias = [10, 6]
    ap_fixed_act = [10, 6]
    ap_fixed_data = [8, 5]

    # Optional --address lets the launcher point at an explicit Ray cluster;
    # otherwise fall back to auto-discovery of a running cluster.
    cli = argparse.ArgumentParser()
    cli.add_argument('--address', type=str, default=None)
    cli_args = cli.parse_args()
    ray.init(address=cli_args.address if cli_args.address else 'auto')

    # Hyperparameter search space: continuous knobs for the augmentations and
    # learning rates, plus sampled layer-width lists for both networks.
    search_space = {
        'vic_lr': tune.loguniform(1e-4, 1e-3),
        'vae_lr': tune.loguniform(1e-4, 1e-3),
        'blur_p': tune.uniform(0, 1),
        'blur_m': tune.uniform(0, 1),
        'blur_s': tune.uniform(0, 1),
        'mask_p': tune.uniform(0, 1),
        'beta': tune.uniform(0, 1),
        'alpha': tune.uniform(0, 1),
        'encoder_nodes': tune.sample_from(
            lambda spec: [tune.randint(24, 32).sample(), tune.randint(8, 18).sample()]),
        'vae_latent': tune.sample_from(lambda spec: tune.randint(4, 6).sample()),
        'vae_nodes': tune.sample_from(
            lambda spec: [tune.randint(8, 12).sample(), tune.randint(6, 8).sample()]),
    }

    # Optuna proposes trials; ASHA prunes the weak ones early.  Both optimize
    # the same signal-efficiency metric reported by run().
    optuna_search = OptunaSearch(metric='raw-pure/haa4b-ma15', mode='max')
    asha = ASHAScheduler(
        metric='raw-pure/haa4b-ma15',
        mode='max',
        max_t=480,
        grace_period=32,
        reduction_factor=2,
    )

    analysis = tune.run(
        run,
        config=search_space,
        storage_path='/axovol/raytune',
        search_alg=optuna_search,
        scheduler=asha,
        num_samples=2000,
        resources_per_trial={'cpu': 4},
        max_concurrent_trials=16,
    )