# Inference

In this notebook we will demonstrate how to use two pretrained models to generate mixes of drum multitrack recordings. We provide models trained on the ENST-drums dataset, which features a few hundred drum multitracks along with mixes of these multitracks made by professional audio engineers. We trained two different multitrack mixing model architectures: the Differentiable Mixing Console (DMC) and the MixWaveUNet. First we will download the model checkpoints and some test audio, then load the models and the audio tracks, and finally generate a mix that we can listen to. At the end, we also run a DMC model trained on MedleyDB on the stems of a multi-instrument song.

Note: This notebook assumes that you have already installed the `automix` package. If you have not done so, you can run the following:

In [None]:
!pip install git+https://github.com/csteinmetz1/automix-toolkit

In [None]:
import os
import glob
import torch
import torchaudio
import numpy as np

import IPython
import IPython.display as ipd
import matplotlib.pyplot as plt
import librosa.display

%matplotlib inline
%load_ext autoreload
%autoreload 2

from automix.system import System

## Download the pretrained models and multitracks
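First we will download three pretrained models: the DMC and MixWaveUNet models trained on ENST-drums, and a DMC model trained on MedleyDB. Then we will download two `.zip` files, one containing a drum multitrack and one containing the stems of the demo song. Neither multitrack was seen during training.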

In [None]:
# download the pretrained DMC and MixWaveUNet models (ENST-drums) and the DMC model (MedleyDB)
os.makedirs("checkpoints", exist_ok=True)
!wget https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/enst-drums-dmc.ckpt
!wget https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/enst-drums-mixwaveunet.ckpt
!wget https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/medleydb-16-dmc.ckpt
!mv enst-drums-dmc.ckpt checkpoints/enst-drums-dmc.ckpt
!mv enst-drums-mixwaveunet.ckpt checkpoints/enst-drums-mixwaveunet.ckpt
!mv medleydb-16-dmc.ckpt checkpoints/medleydb-16-dmc.ckpt

# then download and extract a drum multitrack from the test set
!wget https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/drums-test-rock.zip
!unzip -o drums-test-rock.zip

!wget https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/flare-dry-stems.zip
!unzip -o flare-dry-stems.zip -d flare-dry-stems

In [None]:
!ls

## Set configuration
There are two drum checkpoints to choose from.

If we select `enst-drums-dmc.ckpt`, we use the pretrained Differentiable Mixing Console (DMC), which directly predicts gain and panning parameters for each track. Alternatively, we can select `enst-drums-mixwaveunet.ckpt`, which uses a multi-input Wave-U-Net to create the mix of the tracks directly. To make computation faster, we can restrict the maximum number of samples that are processed with `max_samples`. The default `max_samples = 262144` mixes roughly the first 6 seconds of each track (262144 / 44100 ≈ 5.9 s at a 44.1 kHz sample rate). You can try increasing this value to see how the results change.

Note: MixWaveUNet requires `max_samples` to be a power of 2.
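A quick way to relate `max_samples` to duration and to respect the power-of-2 requirement is sketched below. It assumes 44.1 kHz audio (the notebook itself reads the actual rate `sr` from the files), and the helper names are hypothetical, not part of `automix`.

```python
def samples_to_seconds(num_samples: int, sample_rate: int = 44100) -> float:
    """Duration in seconds of num_samples at the given sample rate (44.1 kHz assumed)."""
    return num_samples / sample_rate

def is_power_of_two(n: int) -> bool:
    """True if n is a positive power of 2 (exactly one bit set)."""
    return n > 0 and (n & (n - 1)) == 0

def next_power_of_two(n: int) -> int:
    """Smallest power of 2 that is >= n."""
    return 1 << (n - 1).bit_length() if n > 1 else 1

print(f"{samples_to_seconds(262144):.2f} s")  # 5.94 s
print(is_power_of_two(262144))                # True
# request roughly 10 s at 44.1 kHz, rounded up to a power of 2 for MixWaveUNet
print(next_power_of_two(int(10 * 44100)))     # 524288
```

Rounding up rather than down keeps at least the requested duration; the model then simply processes a little more audio.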

In [None]:
track_dir = "./drums-test-rock/tracks"
track_ext = "wav"

dmc_ckpt_path = "checkpoints/enst-drums-dmc.ckpt"
mwun_ckpt_path = "checkpoints/enst-drums-mixwaveunet.ckpt"

max_samples = 262144

## Load pretrained models


In [None]:
# load both pretrained models on the CPU and put them in evaluation mode
dmc_system = System.load_from_checkpoint(dmc_ckpt_path, pretrained_encoder=False, map_location="cpu").eval()
mwun_system = System.load_from_checkpoint(mwun_ckpt_path, map_location="cpu").eval()

## Load multitrack 
Now we will read the tracks from disk and create a tensor containing all of them. As we load them, we peak normalize each track to -12 dB, which is the level the models expect. Since MixWaveUNet expects 8 input tracks, we add silent padding tracks if fewer than 8 are provided. The DMC model, on the other hand, can accept any number of tracks, whether more or fewer than it was trained with.
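The normalization and padding steps can be factored into two small helpers. This is a sketch with hypothetical names (`peak_normalize`, `pad_tracks`); random noise stands in for real audio, and the target of 8 tracks comes from the MixWaveUNet requirement above.

```python
import torch

def peak_normalize(x: torch.Tensor, peak_db: float = -12.0) -> torch.Tensor:
    """Scale x so that its absolute peak sits at peak_db dBFS."""
    x = x / x.abs().max().clamp(min=1e-8)  # peak normalize to 1.0 (avoids divide-by-zero on silence)
    return x * 10 ** (peak_db / 20.0)      # convert dB to a linear gain

def pad_tracks(tracks: list, num_tracks: int = 8) -> list:
    """Append silent tracks (same shape as the first track) until there are num_tracks."""
    padded = list(tracks)
    while len(padded) < num_tracks:
        padded.append(torch.zeros_like(tracks[0]))
    return padded

torch.manual_seed(0)
tracks = [peak_normalize(torch.randn(1, 1024)) for _ in range(5)]  # 5 mono "tracks"
tracks = torch.stack(pad_tracks(tracks), dim=0).permute(1, 0, 2)   # (1, num_tracks, seq_len)
print(tracks.shape)  # torch.Size([1, 8, 1024])
```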

We can also create a simple mono mixture of these tracks to hear what the multitrack sounds like before we do any mixing. 

In [None]:
# load the input tracks
track_filepaths = glob.glob(os.path.join(track_dir, f"*.{track_ext}"))
track_filepaths = sorted(track_filepaths)
tracks = []
for idx, track_filepath in enumerate(track_filepaths):
    x, sr = torchaudio.load(track_filepath)
    x = x[:, : max_samples]
    x /= x.abs().max().clamp(1e-8) # peak normalize
    x *= 10 ** (-12/20.0) # set peak to -12 dB
    tracks.append(x)

    plt.figure(figsize=(10, 2))
    librosa.display.waveshow(x.view(-1).numpy(), sr=sr, zorder=3)
    plt.title(f"{idx+1} {os.path.basename(track_filepath)}")
    plt.ylim([-1,1])
    plt.grid(c="lightgray")
    plt.show()
    IPython.display.display(ipd.Audio(x.view(-1).numpy(), rate=sr, normalize=True))    

# add silent tracks until there are 8 (required by MixWaveUNet)
while len(tracks) < 8:
    tracks.append(torch.zeros_like(tracks[0]))

# stack tracks into a tensor
tracks = torch.stack(tracks, dim=0)
tracks = tracks.permute(1, 0, 2)
# tracks have shape (1, num_tracks, seq_len)

# listen to the input (mono) before mixing
input_mix = tracks.sum(dim=1, keepdim=True)
print(input_mix.shape)
plt.figure(figsize=(10, 2))
plt.title("Mono Mix")
librosa.display.waveshow(input_mix.view(-1).numpy(), sr=sr, zorder=3, color="tab:orange")
plt.ylim([-1,1])
plt.grid(c="lightgray")
plt.show()
IPython.display.display(ipd.Audio(input_mix.view(-1).numpy(), rate=sr, normalize=False))

## Generate the DMC mix
Now we can listen to the predicted mix. Since the Differentiable Mixing Console predicts explicit mixing parameters, we can also print the gain (in dB) and pan parameter for each track.
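To make the printed gain and pan values concrete, here is a minimal sketch of how such parameters could be applied to mono tracks to form a stereo mix. We have not confirmed the pan convention or pan law used inside `automix`; this assumes pan values in [0, 1] (0 = hard left, 1 = hard right) with a constant-power sine/cosine pan law, and `gain_pan_mix` is a hypothetical name.

```python
import math
import torch

def gain_pan_mix(tracks: torch.Tensor, gain_db: torch.Tensor, pan: torch.Tensor) -> torch.Tensor:
    """
    tracks:  (num_tracks, seq_len) mono tracks
    gain_db: (num_tracks,) gain per track in dB
    pan:     (num_tracks,) pan in [0, 1], 0 = hard left, 1 = hard right (assumed convention)
    returns: (2, seq_len) stereo mix
    """
    gain_lin = 10 ** (gain_db / 20.0)   # dB -> linear amplitude
    theta = pan * (math.pi / 2)         # map [0, 1] onto [0, pi/2]
    left = torch.cos(theta) * gain_lin  # constant-power pan law (assumed)
    right = torch.sin(theta) * gain_lin
    # weight every track per channel, then sum over the track dimension
    mix_l = (left.unsqueeze(-1) * tracks).sum(dim=0)
    mix_r = (right.unsqueeze(-1) * tracks).sum(dim=0)
    return torch.stack([mix_l, mix_r], dim=0)

x = torch.ones(3, 4)
mix = gain_pan_mix(x, torch.tensor([0.0, -6.0, 0.0]), torch.tensor([0.5, 0.0, 1.0]))
print(mix.shape)  # torch.Size([2, 4])
```

With a constant-power law, a centered track is sent to each side at cos(π/4) ≈ 0.707, so its perceived loudness stays roughly constant as it moves across the stereo field.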

In [None]:
# pass tracks to the model and create a mix
with torch.no_grad(): # no need to compute gradients
    mix, params = dmc_system(tracks[:, :len(track_filepaths), :])  # drop any silent padding tracks
print(mix.shape, params.shape)

# view the mix
mix /= mix.abs().max()
plt.figure(figsize=(10, 2))
plt.title("Differentiable Mixing Console")
librosa.display.waveshow(mix.view(2,-1).numpy(), sr=sr, zorder=3)
plt.ylim([-1,1])
plt.grid(c="lightgray")
plt.show()
IPython.display.display(ipd.Audio(mix.view(2,-1).numpy(), rate=sr, normalize=True))

for track_fp, param in zip(track_filepaths, params.squeeze()):
    print(os.path.basename(track_fp), param)

## Generate the Mix-Wave-U-Net Mix
If we use the MixWaveUNet, there are no parameters to show, since this model uses a *direct transformation* method that does not produce intermediate mixing parameters.

In [None]:
with torch.no_grad(): # no need to compute gradients
    mwun_mix, _ = mwun_system(tracks)  # the second output (mixing parameters) is not used for MixWaveUNet
print(mwun_mix.shape)

# view the mix
mwun_mix /= mwun_mix.abs().max()
plt.figure(figsize=(10, 2))
plt.title("Mix-Wave-U-Net")
librosa.display.waveshow(mwun_mix.view(2,-1).numpy(), sr=sr, zorder=3)
plt.ylim([-1,1])
plt.grid(c="lightgray")
plt.show()
IPython.display.display(ipd.Audio(mwun_mix.view(2,-1).numpy(), rate=sr, normalize=True))

# MedleyDB
Now we will run a DMC model trained on MedleyDB, a dataset that includes many types of instruments. This model was trained on all songs with 16 or fewer tracks.

In [None]:
dmc_ckpt_path = "checkpoints/medleydb-16-dmc.ckpt"

# load pretrained model
medley_dmc_system = System.load_from_checkpoint(dmc_ckpt_path, pretrained_encoder=False, map_location="cpu").eval()


## Load tracks
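We will use the stems from the song that Gary mixed in the first part of the tutorial. Since each input to the model is a mono track, the vocal and bass stems contribute only their left channel, while every other stem is split into separate left and right tracks.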

In [None]:
track_dir = "./flare-dry-stems"
track_ext = "wav"

# select a 40 s excerpt starting at 32 s (assumes 44.1 kHz audio)
start_sample = int(32 * 44100)
end_sample = start_sample + int(40 * 44100)

# load the input tracks
track_filepaths = glob.glob(os.path.join(track_dir, f"*.{track_ext}"))
track_filepaths = sorted(track_filepaths)
tracks = []
track_names = []
for idx, track_filepath in enumerate(track_filepaths):
    x, sr = torchaudio.load(track_filepath)

    if "Vocal" in track_filepath or "Bass" in track_filepath:
        # vocal and bass stems: keep only the left channel as one mono track
        x_L = x[0:1, start_sample:end_sample]
        #x_L /= x_L.abs().max().clamp(1e-8) # peak normalize
        #x_L *= 10 ** (-12/20.0) # set peak to -12 dB
        tracks.append(x_L)
        track_names.append(os.path.basename(track_filepath))

    else:
        # all other stems: split the stereo stem into separate left and right mono tracks
        x_L = x[0:1, start_sample:end_sample]
        x_R = x[1:2, start_sample:end_sample]

        #x_L /= x_L.abs().max().clamp(1e-8) # peak normalize
        #x_L *= 10 ** (-12/20.0) # set peak to -12 dB

        #x_R /= x_R.abs().max().clamp(1e-8) # peak normalize
        #x_R *= 10 ** (-12/20.0) # set peak to -12 dB

        tracks.append(x_L)
        tracks.append(x_R)
        track_names.append(os.path.basename(track_filepath) + "-L")
        track_names.append(os.path.basename(track_filepath) + "-R")

    plt.figure(figsize=(10, 2))
    librosa.display.waveshow(x_L.view(-1).numpy(), sr=sr, zorder=3)
    plt.title(f"{idx+1} {os.path.basename(track_filepath)}")
    plt.ylim([-1,1])
    plt.grid(c="lightgray")
    plt.show()
    IPython.display.display(ipd.Audio(x_L.view(-1).numpy(), rate=sr, normalize=True))    

# stack tracks into a tensor
tracks = torch.stack(tracks, dim=0)
tracks = tracks.permute(1, 0, 2)
# tracks have shape (1, num_tracks, seq_len)

# listen to the input (mono) before mixing
input_mix = tracks.sum(dim=1, keepdim=True).clamp(-1, 1)
plt.figure(figsize=(10, 2))
plt.title("Mono Mix")
librosa.display.waveshow(input_mix.view(-1).numpy(), sr=sr, zorder=3, color="tab:orange")
plt.ylim([-1,1])
plt.grid(c="lightgray")
plt.show()
IPython.display.display(ipd.Audio(input_mix.view(-1).numpy(), rate=sr, normalize=False))

Now we can create a gain and panning mix of these stems.
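Because this excerpt is 40 seconds long, the cell below calls `block_based_forward(tracks, 262144, 262144//2)`, which by its arguments appears to process the audio in blocks of 262144 samples with a hop of half that (50% overlap). We have not checked how `automix` combines the blocks internally, so the following is only a generic windowed overlap-add sketch. `overlap_add` and `toy_mixer` are hypothetical, `toy_mixer` stands in for the mixing model, and the sketch works on `(num_tracks, seq_len)` tensors rather than the notebook's `(1, num_tracks, seq_len)`.

```python
import math
import torch
import torch.nn.functional as F

def overlap_add(x: torch.Tensor, fn, block_size: int, hop_size: int) -> torch.Tensor:
    """Hypothetical helper: run fn on overlapping windowed blocks and overlap-add the results.

    x:  (num_tracks, seq_len) mono tracks
    fn: maps a (num_tracks, block_size) block to a (2, block_size) stereo block
    """
    _, seq_len = x.shape
    pad = hop_size  # pad both ends so every kept sample gets non-zero window weight
    total = seq_len + 2 * pad
    num_blocks = math.ceil(max(total - block_size, 0) / hop_size) + 1
    padded_len = (num_blocks - 1) * hop_size + block_size
    x_pad = F.pad(x, (pad, padded_len - seq_len - pad))

    window = torch.hann_window(block_size, periodic=True, dtype=x.dtype)
    out = torch.zeros(2, padded_len, dtype=x.dtype)
    norm = torch.zeros(padded_len, dtype=x.dtype)
    for b in range(num_blocks):
        start = b * hop_size
        y = fn(x_pad[:, start:start + block_size])  # (2, block_size)
        out[:, start:start + block_size] += y * window
        norm[start:start + block_size] += window
    # divide by the summed window weights, then trim the padding
    out = out / norm.clamp(min=1e-8)
    return out[:, pad:pad + seq_len]

# toy stand-in for the model: a fixed gain per track, summed into both channels
gains = torch.tensor([1.0, 0.5, 0.25])
def toy_mixer(block: torch.Tensor) -> torch.Tensor:
    mono = (gains.unsqueeze(-1) * block).sum(dim=0)
    return torch.stack([mono, mono], dim=0)

torch.manual_seed(0)
x = torch.randn(3, 1000)
mix = overlap_add(x, toy_mixer, block_size=256, hop_size=128)
print(mix.shape)  # torch.Size([2, 1000])
```

Dividing by the accumulated window weight makes the result independent of the exact window shape; with a periodic Hann window and a 50% hop the weights already sum to 1 in the interior, so the division mainly corrects the edges.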

In [None]:
# pass tracks to the model and create a mix
with torch.no_grad(): # no need to compute gradients
    mix = medley_dmc_system.model.block_based_forward(tracks, 262144, 262144//2)
#print(mix.shape, params.shape)

# view the mix
mix /= mix.abs().max()
plt.figure(figsize=(10, 2))
plt.title("Differentiable Mixing Console")
librosa.display.waveshow(mix.view(2,-1).numpy(), sr=sr, zorder=3)
plt.ylim([-1,1])
plt.grid(c="lightgray")
plt.show()
IPython.display.display(ipd.Audio(mix.view(2,-1).numpy(), rate=sr, normalize=True))

#for track_fp, param in zip(track_names, params.squeeze()):
#    print(os.path.basename(track_fp), param)

It is certainly not a perfect mix, but notice that the model has learned to raise the level of the vocal, pan it to the center, and try to pan the other elements toward the sides.