import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import Cars3dData
from disent.frameworks.ae import Ae
from disent.metrics import metric_dci, metric_mig
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.dataset.transform import ToImgTensorF32
from disent.util import is_test_run # you can ignore and remove this
# prepare the data
data = Cars3dData()
# `size` is the single source of truth for the value given to the transform
# (previously it was defined here but the transform hard-coded 64 again).
size = 64
# Passed to the transform as `mean` / `std`; three values each
# (presumably one per image channel -- confirm against ToImgTensorF32).
vis_mean = [0.8976676149976628, 0.8891658020067508, 0.885147515814868]
vis_std = [0.22503195531503034, 0.2399461278981261, 0.24792106319684404]
dataset_train = DisentDataset(
    data,
    transform=ToImgTensorF32(size=size, mean=vis_mean, std=vis_std),
)
# dataset_val = ?
# dataset_test = ?
dataloader_train = DataLoader(
    dataset=dataset_train,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)
# create the pytorch lightning system
# - the encoder and decoder are built with the same observation shape and
#   latent size, so both values are named once and reused
obs_shape = (3, 64, 64)
latent_size = 6
autoencoder = AutoEncoder(
    encoder=EncoderConv64(x_shape=obs_shape, z_size=latent_size),
    decoder=DecoderConv64(x_shape=obs_shape, z_size=latent_size),
)
# - optimiser and loss settings for the framework
ae_cfg = Ae.cfg(
    optimizer="adam",
    optimizer_kwargs={"lr": 1e-3},
    loss_reduction="mean_sum",
)
module: pl.LightningModule = Ae(model=autoencoder, cfg=ae_cfg)
# train the model
# - `is_test_run()` decides whether the trainer does a fast dev run
dev_run = is_test_run()
# - request a single GPU only when CUDA is available, otherwise pass None
if torch.cuda.is_available():
    num_gpus = 1
else:
    num_gpus = None
trainer = pl.Trainer(
    max_steps=10,
    checkpoint_callback=False,
    fast_dev_run=dev_run,
    gpus=num_gpus,
)
trainer.fit(module, dataloader_train)
# compute disentanglement metrics
# - this will take a while to run
@torch.no_grad()
def get_repr(x):
    """Encode a batch of observations with the trained module.

    The batch is moved to the module's device first, because we cannot
    guarantee which device the representation is on. Gradient tracking is
    disabled since this function is only used for evaluation.
    """
    return module.encode(x.to(module.device))


metrics = {
    **metric_dci(
        dataset_train, get_repr, num_train=1000, num_test=500, show_progress=True
    ),
    **metric_mig(dataset_train, get_repr, num_train=2000),
}
# evaluate
print("metrics:", metrics)
# Any hints are highly appreciated. Thank you for providing this package!