# Training

In this notebook we will go through the basic process of training an automatic mixing model. This involves combining a dataset with a model and an appropriate training loop. For this demonstration we will use [PyTorch Lightning](https://www.pytorchlightning.ai/) to facilitate the training. 

## Dataset
For this demonstration we will use a subset of the [DSD100 dataset](https://sigsep.github.io/datasets/dsd100.html). This is a music source separation dataset, but we will use it to demonstrate how you can train a mixing model. The subset is very small so that it can be downloaded easily, which means we should not expect our model to perform very well after training. 

This notebook can be used as a starting point, for example by swapping in a different dataset such as [ENST-drums](https://perso.telecom-paristech.fr/grichard/ENST-drums/) or [MedleyDB](https://medleydb.weebly.com/) once it has been downloaded. Since those datasets are quite large, we will focus only on this small subset for demonstration purposes. 

## GPU

This notebook supports training on a GPU. In Colab, you can enable one by selecting `Runtime` > `Change runtime type` in the menu bar at the top and setting the hardware accelerator to `GPU`.

## Learn More

If you want to train these models on your own server and have much more control than this demo offers, we encourage you to take a look at the training recipes we provide in the [automix-toolkit](https://github.com/csteinmetz1/automix-toolkit) repository.

For now, let's get started by installing the automix-toolkit.

In [None]:
!pip install git+https://github.com/csteinmetz1/automix-toolkit

In [None]:
import os
import torch
import pytorch_lightning as pl
import IPython
import numpy as np
import IPython.display as ipd
import matplotlib.pyplot as plt
import librosa.display

from argparse import Namespace

%matplotlib inline
%load_ext autoreload
%autoreload 2

from automix.data import DSD100Dataset
from automix.system import System

First we will download the pretrained encoder checkpoint, then download and unzip the dataset subset.

In [None]:
os.makedirs("checkpoints/", exist_ok=True)
!wget https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/encoder.ckpt
!mv encoder.ckpt checkpoints/encoder.ckpt
    
!wget https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/DSD100subset.zip
!unzip -o DSD100subset.zip 
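
Outside Colab, `wget` and `unzip` may not be available. As a rough alternative, the same two downloads could be done with Python's standard library. This is only a sketch: the helper names `fetch` and `extract` are illustrative and are not part of the automix-toolkit.

```python
import os
import urllib.request
import zipfile

BASE_URL = "https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main"


def fetch(url, dest):
    """Download url to dest unless dest already exists (hypothetical helper)."""
    if os.path.exists(dest):
        return dest
    parent = os.path.dirname(dest)
    if parent:
        os.makedirs(parent, exist_ok=True)
    urllib.request.urlretrieve(url, dest)
    return dest


def extract(zip_path, out_dir="."):
    """Unpack a zip archive, overwriting existing files like `unzip -o`."""
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(out_dir)
        return archive.namelist()


# Usage (performs network downloads, so it is left commented out here):
# fetch(f"{BASE_URL}/encoder.ckpt", "checkpoints/encoder.ckpt")
# extract(fetch(f"{BASE_URL}/DSD100subset.zip", "DSD100subset.zip"))
```

Skipping files that already exist makes the cell safe to re-run without downloading everything again.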

# Configuration
Here we select whether we want to train on the CPU or GPU and which model we will use. 

In [None]:
!nvidia-smi # check for GPU
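
If `nvidia-smi` reports no device, training with `"accelerator" : "gpu"` will fail. A minimal sketch of choosing the value programmatically is shown below; `pick_accelerator` is a hypothetical helper, not something the toolkit provides.

```python
def pick_accelerator():
    """Return "gpu" when PyTorch can see a CUDA device, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        # PyTorch is needed for training anyway; this only keeps the sketch importable.
        return "cpu"
    return "gpu" if torch.cuda.is_available() else "cpu"


accelerator = pick_accelerator()
print(f"accelerator: {accelerator}")
```

The result could then be used as the `"accelerator"` entry of the configuration. Recent PyTorch Lightning versions also accept `accelerator="auto"`, which makes a similar choice internally.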

In [None]:
args = {
    "dataset_dir" :  "./DSD100subset",
    "dataset_name" : "DSD100",
    "automix_model" : "dmc",
    "pretrained_encoder" : True,
    "train_length" : 65536,
    "val_length" : 65536,
    "accelerator" : "gpu", # you can select "cpu" or "gpu"
    "devices" : 1, 
    "batch_size" : 4,
    "lr" : 3e-4,
    "max_epochs" : 10, 
    "schedule" : "none",
    "recon_losses" : ["sd"],
    "recon_loss_weights" : [1.0],
    "sample_rate" : 44100,
    "num_workers" : 2,
}
args = Namespace(**args)
    
pl.seed_everything(42, workers=True)

In [None]:
# setup callbacks
callbacks = [
    #LogAudioCallback(),
    pl.callbacks.LearningRateMonitor(logging_interval="step"),
    pl.callbacks.ModelCheckpoint(
        filename=f"{args.dataset_name}-{args.automix_model}"
        + "_epoch-{epoch}-step-{step}",
        monitor="val/loss_epoch",
        mode="min",
        save_last=True,
        auto_insert_metric_name=False,
    ),
]

# we will not use Weights & Biases
#wandb_logger = WandbLogger(save_dir=log_dir, project="automix-notebook")

# create PyTorch Lightning trainer
# trainer = pl.Trainer(args, callbacks=callbacks)

trainer = pl.Trainer(
    max_epochs=args.max_epochs,
    accelerator=args.accelerator,
    devices=args.devices,
    callbacks=callbacks,
    # Add other trainer arguments here if needed
)

# create the System
system = System(**vars(args))

# Dataset
Now we will create datasets for train/val/test but we will use the same four songs across all sets here for demonstration purposes.

In [None]:
train_dataset = DSD100Dataset(
    args.dataset_dir,
    args.train_length,
    args.sample_rate,
    indices=[0, 4],
    num_examples_per_epoch=100,
)
val_dataset = DSD100Dataset(
    args.dataset_dir,
    args.val_length,
    args.sample_rate,
    indices=[0, 4],
    num_examples_per_epoch=100,
)
test_dataset = DSD100Dataset(
    args.dataset_dir,
    args.train_length,
    args.sample_rate,
    indices=[0, 4],
    num_examples_per_epoch=100,
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    persistent_workers=True,
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    persistent_workers=False,
)
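
Reusing the same songs for every split is fine for a demo, but with a full dataset you would normally want disjoint splits. The sketch below assumes, as the `indices=[0, 4]` argument and the "four songs" description suggest, that `indices` is a `[start, stop)` range over the song list. That is an assumption about `DSD100Dataset`, and `split_indices` is a hypothetical helper.

```python
def split_indices(num_songs, train_frac=0.8, val_frac=0.1):
    """Split range(num_songs) into contiguous, disjoint [start, stop) ranges."""
    if not 0 < train_frac + val_frac < 1:
        raise ValueError("train_frac + val_frac must be between 0 and 1")
    train_stop = int(num_songs * train_frac)
    val_stop = train_stop + int(num_songs * val_frac)
    return {
        "train": [0, train_stop],
        "val": [train_stop, val_stop],
        "test": [val_stop, num_songs],
    }


# example with a 100-song collection
splits = split_indices(100)
print(splits)
```

Each range could then be passed as the `indices` argument of the corresponding dataset, if the assumption about its meaning holds.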

# Logging
We can launch an instance of TensorBoard within our notebook to monitor the training process. Be patient, it can take ~60 seconds for the window to show.

In [None]:
%load_ext tensorboard
%tensorboard  --logdir="lightning_logs"

# Train!
Now we are ready to launch the training process.
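
To get a feel for how much work this run involves, the arithmetic below estimates the training budget from the configuration values used earlier. It assumes that the dataset length equals `num_examples_per_epoch`, that the `DataLoader` keeps its default `drop_last=False`, and that a single device is used.

```python
import math

sample_rate = 44100           # args.sample_rate
train_length = 65536          # args.train_length, in samples
num_examples_per_epoch = 100  # passed to DSD100Dataset above
batch_size = 4                # args.batch_size
max_epochs = 10               # args.max_epochs

# duration of one training example in seconds
segment_seconds = train_length / sample_rate
# number of optimizer steps per epoch and in total
steps_per_epoch = math.ceil(num_examples_per_epoch / batch_size)
total_steps = steps_per_epoch * max_epochs
# total amount of audio the model sees during training
audio_minutes = num_examples_per_epoch * max_epochs * segment_seconds / 60

print(f"segment length: {segment_seconds:.2f} s")
print(f"steps per epoch: {steps_per_epoch}, total steps: {total_steps}")
print(f"audio seen during training: {audio_minutes:.1f} min")
```

Under these assumptions each example is roughly 1.5 seconds of audio and the whole run takes 250 optimizer steps, which is why it finishes quickly.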


In [None]:
trainer.fit(system, train_dataloader, val_dataloader)

# Test
After training for a few epochs we will test the system by creating a mix from one of the songs that was in the training set. 

In [None]:
import glob
import torchaudio
start_sample = 262144 * 2
end_sample = 262144 * 3

# load the input tracks
track_dir = "DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/"
track_ext = "wav"

track_filepaths = glob.glob(os.path.join(track_dir, f"*.{track_ext}"))
track_filepaths = sorted(track_filepaths)
track_names = []
tracks = []
for idx, track_filepath in enumerate(track_filepaths):
    x, sr = torchaudio.load(track_filepath)
    x = x[:, start_sample:end_sample]

    # treat each channel of the file as a separate mono track
    for n in range(x.shape[0]):
        x_sub = x[n:n+1, :]

        # apply a random gain between -12 dB and +12 dB (in place)
        gain_dB = np.random.rand() * 12
        gain_dB *= np.random.choice([1.0, -1.0])
        gain_ln = 10 ** (gain_dB / 20.0)
        x_sub *= gain_ln

        tracks.append(x_sub)
        track_names.append(os.path.basename(track_filepath))
        IPython.display.display(ipd.Audio(x[n, :].view(1, -1).numpy(), rate=sr, normalize=True))
        print(idx + 1, os.path.basename(track_filepath))

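# add dummy tracks of silence if needed
# (each track is mono with shape (1, seq_len), so pad with matching shapes until there are 8)
if system.hparams.automix_model == "mixwaveunet":
    while len(tracks) < 8:
        tracks.append(torch.zeros_like(tracks[0]))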

# stack tracks into a tensor
tracks = torch.stack(tracks, dim=0)
tracks = tracks.permute(1, 0, 2)
# tracks have shape (1, num_tracks, seq_len)
print(tracks.shape)

# listen to the input (mono) before mixing
input_mix = tracks.sum(dim=1, keepdim=True)
input_mix /= input_mix.abs().max()
print(input_mix.shape)
plt.figure(figsize=(10, 2))
# the mix is mono, so flatten it to 1-D rather than splitting it into two fake channels
librosa.display.waveshow(input_mix.view(-1).numpy(), sr=sr, zorder=3)
plt.ylim([-1,1])
plt.grid(c="lightgray")
plt.show()
IPython.display.display(ipd.Audio(input_mix.view(1,-1).numpy(), rate=sr, normalize=False))

Above we can hear the tracks with a simple mono mix. Now we will create a mix with the model we just trained.
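
The random gains and the mono mix in the cell above come down to two small operations: converting a gain in decibels to a linear factor with 10^(dB/20), and dividing the summed signal by its peak. The sketch below reproduces both on toy data in plain Python. The helper names `db_to_linear` and `peak_normalize` are illustrative only.

```python
import random


def db_to_linear(gain_db):
    """Convert a gain in decibels to a linear amplitude factor."""
    return 10 ** (gain_db / 20.0)


def peak_normalize(signal):
    """Scale a signal so that its largest absolute sample is 1.0."""
    peak = max(abs(s) for s in signal)
    return [s / peak for s in signal] if peak > 0 else list(signal)


random.seed(0)
# two toy "tracks" of four samples each
tracks = [[0.5, -0.5, 0.25, 0.0], [0.1, 0.2, -0.3, 0.4]]

mixed = [0.0] * len(tracks[0])
for track in tracks:
    gain_db = random.uniform(-12.0, 12.0)  # same range as the cell above
    gain = db_to_linear(gain_db)
    mixed = [m + gain * s for m, s in zip(mixed, track)]

mono_mix = peak_normalize(mixed)
print(mono_mix)
```

A range of ±12 dB corresponds to linear factors between roughly 0.25 and 4, so the unmixed tracks can differ in level by a wide margin before the model sees them.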

In [None]:
tracks = tracks.view(1,8,-1)

with torch.no_grad():
  y_hat, p = system(tracks)
  
# view the mix
print(y_hat.shape)
y_hat /= y_hat.abs().max()
plt.figure(figsize=(10, 2))
librosa.display.waveshow(y_hat.view(2,-1).cpu().numpy(), sr=sr, zorder=3)
plt.ylim([-1,1])
plt.grid(c="lightgray")
plt.show()
IPython.display.display(ipd.Audio(y_hat.view(2,-1).cpu().numpy(), rate=sr, normalize=True))

# print the parameters
if system.hparams.automix_model == "dmc":
    for track_fp, param in zip(track_names, p.squeeze()):
        print(os.path.basename(track_fp), param)

You should be able to hear that the levels have been adjusted and the sources panned to sound more like the original mix. This indicates that our system has learned to overfit the songs in our very small training set. 