
Comments (3)

Rainbow-Six66 commented on August 24, 2024

This is my code:
from disent.metrics import metric_dci, metric_mig, metric_factor_vae
import torch
from torch.utils.data import DataLoader
from disent.dataset import DisentDataset
from disent.dataset.data import DSpritesData
from disent.dataset.transform import ToImgTensorF32
from model.β_vae import BetaVAE_H

def train():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # wrap the ground-truth data so the metrics can access the factors
    data = DSpritesData()
    dataset = DisentDataset(data, transform=ToImgTensorF32(), augment=None)
    dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)  # not used by the metrics below

    # load the trained model from its checkpoint
    checkpoint = torch.load('./checkpoints/beta_vae50.pt', map_location=device)
    model = BetaVAE_H().to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # we cannot guarantee which device the representation is on
    get_repr = lambda x: model.mlp_encoder(x.to(device))

    # evaluate
    return {
        **metric_dci(dataset, get_repr, num_train=20, boost_mode='sklearn'),
        **metric_mig(dataset, get_repr, num_train=20),
        **metric_factor_vae(dataset, get_repr, num_train=20),
    }

a_results = train()
print('beta=4: ', a_results)


nmichlo commented on August 24, 2024

Hi there, and thank you!

Unfortunately the docs for this are sparse. I understand this is not ideal, and I would gladly accept PRs to fix it.

However, for context, the MIG, DCI and FactorVAE scores are largely based on those from https://github.com/google-research/disentanglement_lib (the default values should be similar). From what I remember, without looking at the code, num_train and batch_size affect the sample size of the underlying data that is used to compute the metrics. With too little data the metrics will be inaccurate; with too much, the processing time becomes excessive. Often, for metrics during training, I would lower these values and then do a final, larger compute at the end with the default values.
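For example, using the dataset and get_repr from your snippet above, that pattern might look roughly like this (the numbers here are illustrative, not the library defaults):

from disent.metrics import metric_mig

# quick, lower-fidelity check while training (small sample, illustrative values)
quick_scores = metric_mig(dataset, get_repr, num_train=2000, batch_size=64)

# final, more reliable score at the end of training (library defaults)
final_scores = metric_mig(dataset, get_repr)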


nmichlo commented on August 24, 2024

Hydra configs for the experiment metrics:
https://github.com/nmichlo/disent/tree/8f061a87076adeae8d6e5b0fa984b660cd40e026/experiment/config/metrics

actual code that selects these:

disent/experiment/run.py

Lines 208 to 209 in 8f061a8

train_metric = [R.METRICS[name].compute_fast] if settings.get("on_train", default_on_train) else None
final_metric = [R.METRICS[name].compute] if settings.get("on_final", default_on_final) else None
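In other words, each registered metric exposes a default and a fast version, and the experiment just picks one. A rough sketch of calling them directly (assuming the METRICS registry in disent.registry and a 'mig' key; check the registry for the exact names):

from disent import registry as R

mig = R.METRICS['mig']                               # Metric wrapper, see below
fast_scores  = mig.compute_fast(dataset, get_repr)   # reduced kwargs, used during training
final_scores = mig.compute(dataset, get_repr)        # default kwargs, used for the final run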

metric wrapper:

  • see compute and compute_fast

class Metric(Generic[T]):

    def __init__(
        self,
        name: str,
        metric_fn: T,  # Callable[[...], Dict[str, Number]]
        default_kwargs: Optional[Dict[str, Any]] = None,
        fast_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self._name = name
        self._orig_fn = metric_fn
        self._metric_fn_default = wrapped_partial(self._orig_fn, **(default_kwargs if default_kwargs else {}))
        self._metric_fn_fast = wrapped_partial(self._orig_fn, **(fast_kwargs if fast_kwargs else {}))

    # How do we get a type hint for `__call__` so that its signature matches `T`?
    def __call__(self, *args, **kwargs) -> Dict[str, Number]:
        return self._metric_fn_default(*args, **kwargs)

    @property
    def compute(self) -> T:
        return self._metric_fn_default

    @property
    def compute_fast(self) -> T:
        return self._metric_fn_fast

    @property
    def unwrap(self) -> T:
        return self._orig_fn

    @property
    def name(self) -> str:
        return self._name

    def __str__(self):
        return f"metric-{self.name}"


def make_metric(
    name: str,
    default_kwargs: Optional[Dict[str, Any]] = None,
    fast_kwargs: Optional[Dict[str, Any]] = None,
) -> Callable[[T], Union[Metric[T], T]]:
    """
    Metrics should be decorated using this function to set defaults!

    Two versions of the metric should exist:
    1. Recommended settings
       - This should give reliable results, but may be very slow: multiple minutes to half an
         hour or more for some metrics, depending on the underlying model, data and ground-truth factors.
    2. Faster settings
       - This should give decent results, but should be decently fast: a few seconds/minutes at most.
         This is not used for testing.
    """
    # `Union[Metric[T], T]` is a hack to get a type hint on `__call__`
    def _wrap_fn_as_metric(metric_fn: T) -> Union[Metric[T], T]:
        return Metric(name=name, metric_fn=metric_fn, default_kwargs=default_kwargs, fast_kwargs=fast_kwargs)
    return _wrap_fn_as_metric

fast version kwargs:

NOTE: kwargs for fast versions were arbitrarily chosen. The standard versions should follow kwargs from disentanglement_lib.

NOTE: batch_size is like the batch size from dataset loaders; the model is often used within these metrics and is run on the GPU if possible.
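For illustration, this is roughly how a metric gets its two versions via the make_metric decorator above; the name and kwargs here are made up, the real values live next to each metric in disent.metrics:

# hypothetical registration sketch -- the kwargs shown are illustrative only
@make_metric('mig', fast_kwargs=dict(num_train=2000))
def _metric_mig(dataset, representation_function, num_train=10000, batch_size=16):
    ...  # compute scores and return a Dict[str, Number]

# _metric_mig is now a Metric instance:
#   _metric_mig.compute(...)       -> defaults following disentanglement_lib
#   _metric_mig.compute_fast(...)  -> the reduced fast_kwargs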

