
class-resolver's Introduction

Class Resolver


Lookup and instantiate classes with style.

💪 Getting Started

from class_resolver import ClassResolver
from dataclasses import dataclass

class Base: pass

@dataclass
class A(Base):
    name: str

@dataclass
class B(Base):
    name: str

# Index
resolver = ClassResolver([A, B], base=Base)

# Lookup
assert A == resolver.lookup('A')

# Instantiate with a dictionary
assert A(name='hi') == resolver.make('A', {'name': 'hi'})

# Instantiate with kwargs
assert A(name='hi') == resolver.make('A', name='hi')

# A pre-instantiated class will simply be passed through
assert A(name='hi') == resolver.make(A(name='hi'))

🤖 Writing Extensible Machine Learning Models with class-resolver

Assume you've implemented a simple multi-layer perceptron in PyTorch:

from itertools import chain

from more_itertools import pairwise
from torch import nn

class MLP(nn.Sequential):
    def __init__(self, dims: list[int]):
        super().__init__(chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                nn.ReLU(),
            )
            for in_features, out_features in pairwise(dims)
        ))

This MLP uses a hard-coded rectified linear unit (ReLU) as the non-linear activation function between layers. We can generalize the MLP to support a variety of non-linear activation functions by adding an argument to its __init__() function:

from itertools import chain

from more_itertools import pairwise
from torch import nn

class MLP(nn.Sequential):
    def __init__(self, dims: list[int], activation: str = "relu"):
        if activation == "relu":
            activation = nn.ReLU()
        elif activation == "tanh":
            activation = nn.Tanh()
        elif activation == "hardtanh":
            activation = nn.Hardtanh()
        else:
            raise KeyError(f"Unsupported activation: {activation}")
        super().__init__(chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation,
            )
            for in_features, out_features in pairwise(dims)
        ))

The first issue with this implementation is that it relies on a hard-coded set of conditional statements, which makes it hard to extend. It can be improved by using a dictionary lookup:

from itertools import chain

from more_itertools import pairwise
from torch import nn

activation_lookup: dict[str, nn.Module] = {
    "relu": nn.ReLU(),
    "tanh": nn.Tanh(),
    "hardtanh": nn.Hardtanh(),
}

class MLP(nn.Sequential):
    def __init__(self, dims: list[int], activation: str = "relu"):
        activation = activation_lookup[activation]
        super().__init__(chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation,
            )
            for in_features, out_features in pairwise(dims)
        ))

This approach is rigid because it requires pre-instantiation of the activations. If we needed to vary the arguments to the nn.Hardtanh class, the previous approach wouldn't work. We can change the implementation to look up the class before instantiation and then optionally pass some arguments:

from itertools import chain
from typing import Any

from more_itertools import pairwise
from torch import nn

activation_lookup: dict[str, type[nn.Module]] = {
    "relu": nn.ReLU,
    "tanh": nn.Tanh,
    "hardtanh": nn.Hardtanh,
}

class MLP(nn.Sequential):
    def __init__(
        self,
        dims: list[int],
        activation: str = "relu",
        activation_kwargs: None | dict[str, Any] = None,
    ):
        activation_cls = activation_lookup[activation]
        activation = activation_cls(**(activation_kwargs or {}))
        super().__init__(chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation,
            )
            for in_features, out_features in pairwise(dims)
        ))

This is pretty good, but it still has a few issues:

  1. you have to manually maintain the activation_lookup dictionary,
  2. you can't pass an instance or class through the activation keyword,
  3. you have to get the casing just right,
  4. the default is hard-coded as a string, which means it has to be copied (error-prone) into every place that creates an MLP, and
  5. you have to re-write this logic for all of your classes.
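To make these issues concrete, here is a toy, stdlib-only sketch of the bookkeeping they imply. This is not class-resolver's actual implementation; the class names and normalization rule are illustrative only:

```python
class SimpleResolver:
    """Toy resolver: maps normalized class names to classes (illustrative only)."""

    def __init__(self, classes, default=None):
        self.default = default
        self.lookup_dict = {self.normalize(cls.__name__): cls for cls in classes}

    @staticmethod
    def normalize(name: str) -> str:
        # Case- and underscore-insensitive keys avoid "get the casing just right"
        return name.lower().replace("_", "")

    def make(self, query, kwargs=None):
        if query is None:
            cls = self.default          # centralized default (issue 4)
        elif isinstance(query, type):
            cls = query                 # a class was passed directly (issue 2)
        elif isinstance(query, str):
            cls = self.lookup_dict[self.normalize(query)]
        else:
            return query                # already an instance: pass it through
        return cls(**(kwargs or {}))

class ReLU: ...
class Tanh: ...

resolver = SimpleResolver([ReLU, Tanh], default=ReLU)
assert isinstance(resolver.make("relu"), ReLU)   # string lookup
assert isinstance(resolver.make(None), ReLU)     # default
assert isinstance(resolver.make(Tanh), Tanh)     # class passed directly
```

Maintaining this machinery per project is exactly the repetition that issue 5 complains about, which is what the package below factors out.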

Enter the class_resolver package, which takes care of all of these issues:

from itertools import chain
from typing import Any

from class_resolver import ClassResolver, Hint
from more_itertools import pairwise
from torch import nn

activation_resolver = ClassResolver(
    [nn.ReLU, nn.Tanh, nn.Hardtanh],
    base=nn.Module,
    default=nn.ReLU,
)

class MLP(nn.Sequential):
    def __init__(
        self,
        dims: list[int],
        activation: Hint[nn.Module] = None,  # Hint = Union[None, str, nn.Module, type[nn.Module]]
        activation_kwargs: None | dict[str, Any] = None,
    ):
        super().__init__(chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation_resolver.make(activation, activation_kwargs),
            )
            for in_features, out_features in pairwise(dims)
        ))

Because this is such a common pattern, we've made it available through the contrib module in class_resolver.contrib.torch:

from itertools import chain
from typing import Any

from class_resolver import Hint
from class_resolver.contrib.torch import activation_resolver
from more_itertools import pairwise
from torch import nn

class MLP(nn.Sequential):
    def __init__(
        self,
        dims: list[int],
        activation: Hint[nn.Module] = None,
        activation_kwargs: None | dict[str, Any] = None,
    ):
        super().__init__(chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation_resolver.make(activation, activation_kwargs),
            )
            for in_features, out_features in pairwise(dims)
        ))

Now, you can instantiate the MLP with any of the following:

MLP(dims=[10, 200, 40])  # uses default, which is ReLU
MLP(dims=[10, 200, 40], activation="relu")  # uses lowercase
MLP(dims=[10, 200, 40], activation="ReLU")  # uses stylized
MLP(dims=[10, 200, 40], activation=nn.ReLU)  # uses class
MLP(dims=[10, 200, 40], activation=nn.ReLU())  # uses instance

MLP(dims=[10, 200, 40], activation="hardtanh", activation_kwargs={"min_val": 0.0, "max_val": 6.0})  # uses kwargs
MLP(dims=[10, 200, 40], activation=nn.Hardtanh, activation_kwargs={"min_val": 0.0, "max_val": 6.0})  # uses kwargs
MLP(dims=[10, 200, 40], activation=nn.Hardtanh(0.0, 6.0))  # uses instance

In practice, it makes sense to stick to strings in combination with hyper-parameter optimization libraries like Optuna.
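One reason strings work well here: unlike classes or instances, a string-based configuration is trivially serializable, which is what hyper-parameter optimization and experiment-tracking tools rely on. A small stdlib-only sketch (the configuration dictionary is hypothetical):

```python
import json

# Hypothetical trial configuration: because the activation is referenced by
# string, the whole model configuration round-trips through JSON. A class or
# instance in the same slot would raise a TypeError on serialization.
config = {
    "dims": [10, 200, 40],
    "activation": "hardtanh",
    "activation_kwargs": {"min_val": 0.0, "max_val": 6.0},
}
serialized = json.dumps(config)
assert json.loads(serialized) == config
```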

โฌ‡๏ธ Installation

The most recent release can be installed from PyPI with:

$ pip install class_resolver

The most recent code and data can be installed directly from GitHub with:

$ pip install git+https://github.com/cthoyt/class-resolver.git

To install in development mode, use the following:

$ git clone https://github.com/cthoyt/class-resolver.git
$ cd class-resolver
$ pip install -e .

๐Ÿ™ Contributing

Contributions, whether filing an issue, making a pull request, or forking, are appreciated. See CONTRIBUTING.rst for more information on getting involved.

👋 Attribution

โš–๏ธ License

The code in this package is licensed under the MIT License.

๐Ÿช Cookiecutter

This package was created with @audreyfeldroy's cookiecutter package using @cthoyt's cookiecutter-snekpack template.

๐Ÿ› ๏ธ For Developers

See developer instructions

This final section of the README is for those who want to get involved by making a code contribution.

โ“ Testing

After cloning the repository and installing tox with pip install tox, the unit tests in the tests/ folder can be run reproducibly with:

$ tox

Additionally, these tests are automatically re-run with each commit in a GitHub Action.

📦 Making a Release

After installing the package in development mode and installing tox with pip install tox, the commands for making a new release are contained within the finish environment in tox.ini. Run the following from the shell:

$ tox -e finish

This script does the following:

  1. Uses BumpVersion to switch the version number in setup.cfg and src/class_resolver/version.py to not have the -dev suffix
  2. Packages the code in both a tar archive and a wheel
  3. Uploads to PyPI using twine. Be sure to have a .pypirc file configured to avoid the need for manual input at this step
  4. Pushes to GitHub. You'll need to make a release on GitHub for the commit where the version was bumped.
  5. Bumps the version to the next patch. If you made big changes and want to bump the minor version instead, you can run tox -e bumpversion minor afterwards.

class-resolver's People

Contributors

cmungall, cthoyt, mberr, pkalita-lbl, rusty1s


class-resolver's Issues

Entrypoint loader

Add an extra option to the initializer that corresponds to an entrypoint for classes

Error with torch ReduceLROnPlateau after upgrading

I get an error importing the torch aggregation resolver, as seen below. This started happening once I upgraded to torch 2.2.0 from 2.1.2.

Traceback (most recent call last):
  File "/home/guyaglionby/workflow/runner.py", line 10, in <module>
    from pykeen.nn.representation import TextRepresentation
  File "/opt/conda/lib/python3.10/site-packages/pykeen/nn/__init__.py", line 7, in <module>
    from . import init
  File "/opt/conda/lib/python3.10/site-packages/pykeen/nn/init.py", line 19, in <module>
    from .text import TextEncoder, text_encoder_resolver
  File "/opt/conda/lib/python3.10/site-packages/pykeen/nn/text.py", line 11, in <module>
    from class_resolver.contrib.torch import aggregation_resolver
  File "/opt/conda/lib/python3.10/site-packages/class_resolver/contrib/torch.py", line 213, in <module>
    lr_scheduler_resolver.register(ReduceLROnPlateau)
  File "/opt/conda/lib/python3.10/site-packages/class_resolver/base.py", line 170, in register
    raise RegistrationNameConflict(self, key, element, label="name")
class_resolver.base.RegistrationNameConflict: Conflict on registration of name reducelronplateau:
Existing: <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>
Proposed: <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>

Looks like the line that's an issue:

lr_scheduler_resolver.register(ReduceLROnPlateau)

I wonder if something like the below might fix it, but I'm not familiar enough with this library to be sure:

import torch
from packaging import version

if version.parse(torch.__version__) < version.parse("2.2.0"):
    lr_scheduler_resolver.register(ReduceLROnPlateau)

Thanks for the library!

Docstring for optimizer_resolver is wrong

Hi,
the docstring for optimizer_resolver is wrong, because the Optimizer class does not have a model parameter. I think you don't have xdoctests set up for this project? The offending lines:

def train(optimizer: Hint[Optimizer] = "adam", optimizer_kwargs: OptionalKwargs = None):
    model = [Parameter(torch.randn(2, 2, requires_grad=True))]
    optimizer = optimizer_resolver.make(optimizer, optimizer_kwargs, model=model)

The problem is that even xdoctests would not catch this error, because the inside of the train function is not evaluated.
Maybe something like this would be better?

import torch
from torch import nn
from class_resolver import Hint, OptionalKwargs
from class_resolver.contrib.torch import optimizer_resolver
from torch.nn import Parameter
from torch.optim import Optimizer

optimizer: Hint[Optimizer] = "adam"
optimizer_kwargs: OptionalKwargs = None
model = nn.Linear(10, 2)
optimizer = optimizer_resolver.make(optimizer, optimizer_kwargs, params=model.parameters())

I can do a PR to setup xdoctests and fix the error (and any other upcoming ones) if you want.

Document the synonyms mechanism (in particular, normalizing underscores)

We use the synonym mechanism in OAK and many other projects, the docs for this are a bit buried.

Also there appears to be magic that happens that can cause hard to debug errors; e.g.

my_parser_resolver.synonyms.update({
    "my_name": ActualParser,
    ...
})

Invalid Parser name: my_name Valid choices are: ['my_name']

What seems to be happening is an undocumented normalization that removes all underscores.
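The behavior being described appears to be a normalizer that lowercases keys and strips separators; a sketch of such a rule (illustrative only; check the class-resolver source for the actual normalization):

```python
def normalize(name: str) -> str:
    """Normalize a lookup key: lowercase and drop underscores and hyphens."""
    return name.lower().replace("_", "").replace("-", "")

# If registration stores the raw "my_name" while lookups normalize their
# input to "myname", the query never matches, which would explain the
# confusing "Invalid ... Valid choices are: ['my_name']" error above.
assert normalize("my_name") == "myname"
assert normalize("My-Name") == "myname"
```

Registering synonyms through the same normalization path as lookups would avoid the mismatch.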

Add method to lookup class

Sometimes we want to look up a class, i.e., not pass instances through, but map them to their class.

Currently, we can do so with the following work-around:

instance_or_cls = tokenizer_resolver.lookup(instance)
cls = instance_or_cls.__class__ if isinstance(instance_or_cls, BaseClass) else instance_or_cls
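Until such a method exists, the work-around can be wrapped in a small helper. The stub resolver below stands in for a real resolver whose lookup() passes instances through, as the issue assumes; all names are hypothetical:

```python
class Base: ...
class Child(Base): ...

class StubResolver:
    """Stand-in for a resolver whose lookup() passes instances through."""
    def lookup(self, query):
        return query if isinstance(query, (Base, type)) else Base

def lookup_class(resolver, base_cls, query):
    """Resolve `query` to a class, mapping instances back to their class."""
    instance_or_cls = resolver.lookup(query)
    if isinstance(instance_or_cls, base_cls):
        return type(instance_or_cls)  # an instance came back: take its class
    return instance_or_cls

resolver = StubResolver()
assert lookup_class(resolver, Base, Child()) is Child  # instance -> class
assert lookup_class(resolver, Base, Child) is Child    # class passes through
```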

Discussion: Infer suffix from base class name

If a base class defines some interface, e.g. Similarity, the child classes often use its name as a suffix, e.g. DotProductSimilarity and CosineSimilarity. In such cases, using the base class name as a default suffix may be useful.

Thus, I propose to extend Resolver.from_subclasses to use the (normalized) name of the base class as the suffix if None is provided. To disable this, a user could explicitly pass the empty string.

Conflict between PyTorch and PyTorch Lightning

Thanks for the package! It looks like there is a conflict between PyTorch and PyTorch Lightning when making use of class_resolver.contrib.torch. The following can re-produce the issue:

import torch
import pytorch_lightning
from class_resolver.contrib.torch import optimizer_resolver

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/rusty1s/miniconda3/lib/python3.7/site-packages/class_resolver/contrib/torch.py", line 71, in <module>
    suffix="LR",
  File "/Users/rusty1s/miniconda3/lib/python3.7/site-packages/class_resolver/api.py", line 159, in from_subclasses
    **kwargs,
  File "/Users/rusty1s/miniconda3/lib/python3.7/site-packages/class_resolver/api.py", line 119, in __init__
    suffix=suffix,
  File "/Users/rusty1s/miniconda3/lib/python3.7/site-packages/class_resolver/base.py", line 106, in __init__
    self.register(element)
  File "/Users/rusty1s/miniconda3/lib/python3.7/site-packages/class_resolver/base.py", line 153, in register
    raise RegistrationNameConflict(self, key, element, label="name")
class_resolver.base.RegistrationNameConflict: Conflict on registration of name exponential:
Existing: <class 'pytorch_lightning.tuner.lr_finder._ExponentialLR'>
Proposed: <class 'torch.optim.lr_scheduler.ExponentialLR'>

Combine subclass and entrypoint instantiation

While the entrypoint loader in PyKEEN seems to be very clever, it's still sort of a problem for Colab users. It might be better to provide a mixed version where some base classes live in PyKEEN, plus the possibility to add extra stuff.
