utilities's Introduction

Lightning Utilities

This repository covers the following use-cases:

  1. Reusable GitHub workflows
  2. Shared GitHub actions
  3. CLI python -m lightning_utilities.cli --help
  4. General Python utilities in lightning_utilities.core

1. Reusable workflows

Usage:

name: Check schema

on: [push]

jobs:

  check-schema:
    uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@main
    with:
      azure-dir: ""  # skip Azure check

  check-code:
    uses: Lightning-AI/utilities/.github/workflows/check-code.yml@main
    with:
      actions-ref: main  # normally you should use the same tag/branch as the workflow itself

See usage of other workflows in .github/workflows/ci-use-checks.yaml.

2. Reusable composite actions

See the available composite actions in .github/actions/.

Usage:

name: Do something with cache

on: [push]

jobs:
  pytest:
    runs-on: ubuntu-20.04
    steps:
    - uses: actions/checkout@v3
    - uses: actions/setup-python@v4
      with:
        python-version: 3.9
    - uses: Lightning-AI/utilities/.github/actions/cache@main
      with:
        python-version: 3.9
        requires: oldest # or latest

3. CLI lightning_utilities.cli

The package provides common CLI commands.

Installation

From source:

pip install https://github.com/Lightning-AI/utilities/archive/refs/heads/main.zip

From PyPI:

pip install lightning_utilities[cli]

Usage:

python -m lightning_utilities.cli [group] [command]
Example of setting the minimum (oldest) supported versions in a requirements file:
$ cat requirements/test.txt
coverage>=5.0
codecov>=2.1
pytest>=6.0
pytest-cov
pytest-timeout
$ python -m lightning_utilities.cli requirements set-oldest
$ cat requirements/test.txt
coverage==5.0
codecov==2.1
pytest==6.0
pytest-cov
pytest-timeout

4. General Python utilities lightning_utilities.core

Installation

From PyPI:

pip install lightning_utilities

Usage:

Example of optional imports:

from lightning_utilities.core.imports import module_available

if module_available("some_package.something"):
    from some_package import something
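
The same module also provides the RequirementCache helper referenced by several issues below; a minimal usage sketch (the torch>=2.0 requirement string is only an illustration):

from lightning_utilities.core.imports import RequirementCache

# the check is evaluated lazily on first use and cached afterwards
_TORCH_GREATER_EQUAL_2_0 = RequirementCache("torch>=2.0")

if not _TORCH_GREATER_EQUAL_2_0:
    # str() of the cache yields a human-readable explanation of the failure
    raise ModuleNotFoundError(str(_TORCH_GREATER_EQUAL_2_0))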

utilities's People

Contributors

akihironitta, alanhdu, alexandervaneck, awaelchli, borda, carmocca, ce11an, dependabot[bot], ehofesmann, ethanwharris, gdoongmathew, groodt, janebert, justusschock, mauvilsa, pre-commit-ci[bot], shenoynikhil, skaftenicki, wrran


utilities's Issues

Provide installation instructions in RequirementCache error messaging

🚀 Feature

Initial discussion here.

Motivation

Lightning's messaging in errors is always oriented towards providing a resolution for the user. The previous approach was to say
"pip install package-name" when a package could not be imported. This was recently removed in #14715 with the motivation to add it back to RequirementCache.

Pitch

As proposed by @carmocca

Alternatives

Additional context

Importerror on types.py

๐Ÿ› Bug

ImportError: cannot import name '_PATH' from 'pytorch_lightning.utilities.types' (C:\Users\tmcquaig\anaconda3\envs\Dreambooth-SD\lib\site-packages\pytorch_lightning\utilities\types.py)

To Reproduce

Steps to reproduce the behavior:

  1. Install Anaconda and https://github.com/JoePenna/Dreambooth-Stable-Diffusion
  2. In the Dreambooth-SD environment created for it in Anaconda, run "python main.py"
(Dreambooth-SD) E:\Dreambooth-Stable-Diffusion>python main.py
Traceback (most recent call last):
  File "E:\Dreambooth-Stable-Diffusion\main.py", line 2, in <module>
    from ldm.modules.pruningckptio import PruningCheckpointIO
  File "E:\Dreambooth-Stable-Diffusion\ldm\modules\pruningckptio.py", line 4, in <module>
    from pytorch_lightning.utilities.types import _PATH
ImportError: cannot import name '_PATH' from 'pytorch_lightning.utilities.types' (C:\Users\tmcquaig\anaconda3\envs\Dreambooth-SD\lib\site-packages\pytorch_lightning\utilities\types.py)

(Dreambooth-SD) E:\Dreambooth-Stable-Diffusion>pip show lightning-utilities
Name: lightning-utilities
Version: 0.4.2
Summary: PyTorch Lightning Sample project.
Home-page: https://github.com/Lightning-AI/utilities
Author: Lightning AI et al.
Author-email: [email protected]
License: Apache-2.0
Location: c:\users\tmcquaig\anaconda3\envs\dreambooth-sd\lib\site-packages
Requires:
Required-by: pytorch-lightning

(Dreambooth-SD) E:\Dreambooth-Stable-Diffusion>

Code sample

main.py, line 2:

from ldm.modules.pruningckptio import PruningCheckpointIO

pruningckptio.py, line 4:

from pytorch_lightning.utilities.types import _PATH

Expected behavior

main.py should successfully run and Dreambooth Stable Diffusion should run :)

Environment

(Dreambooth-SD) E:\Dreambooth-Stable-Diffusion>conda info

     active environment : Dreambooth-SD
    active env location : C:\Users\tmcquaig\anaconda3\envs\Dreambooth-SD
            shell level : 2
       user config file : C:\Users\tmcquaig\.condarc
 populated config files : C:\Users\tmcquaig\.condarc
          conda version : 22.11.1
    conda-build version : 3.22.0
         python version : 3.9.13.final.0
       virtual packages : __archspec=1=x86_64
                          __cuda=12.0=0
                          __win=0=0
       base environment : C:\Users\tmcquaig\anaconda3  (writable)
      conda av data dir : C:\Users\tmcquaig\anaconda3\etc\conda
  conda av metadata url : None
           channel URLs : https://repo.anaconda.com/pkgs/main/win-64
                          https://repo.anaconda.com/pkgs/main/noarch
                          https://repo.anaconda.com/pkgs/r/win-64
                          https://repo.anaconda.com/pkgs/r/noarch
                          https://repo.anaconda.com/pkgs/msys2/win-64
                          https://repo.anaconda.com/pkgs/msys2/noarch
          package cache : C:\Users\tmcquaig\anaconda3\pkgs
                          C:\Users\tmcquaig\.conda\pkgs
                          C:\Users\tmcquaig\AppData\Local\conda\conda\pkgs
       envs directories : C:\Users\tmcquaig\anaconda3\envs
                          C:\Users\tmcquaig\.conda\envs
                          C:\Users\tmcquaig\AppData\Local\conda\conda\envs
               platform : win-64
             user-agent : conda/22.11.1 requests/2.28.1 CPython/3.9.13 Windows/10 Windows/10.0.22621
          administrator : False
             netrc file : None
           offline mode : False


(Dreambooth-SD) E:\Dreambooth-Stable-Diffusion>

(Dreambooth-SD) E:\Dreambooth-Stable-Diffusion>pip list
Package             Version
------------------- -----------
aiohttp             3.8.3
aiosignal           1.3.1
async-timeout       4.0.2
attrs               22.1.0
certifi             2022.9.24
charset-normalizer  2.1.1
colorama            0.4.6
frozenlist          1.3.3
fsspec              2022.11.0
idna                3.4
lightning-utilities 0.4.2
multidict           6.0.3
numpy               1.23.5
packaging           22.0
pip                 22.3.1
protobuf            3.20.1
pytorch-lightning   1.8.4.post0
PyYAML              6.0
requests            2.28.1
setuptools          65.5.0
tensorboardX        2.5.1
torch               1.13.0
torchmetrics        0.11.0
tqdm                4.64.1
typing_extensions   4.4.0
urllib3             1.26.13
wheel               0.37.1
wincertstore        0.2
yarl                1.8.2

(Dreambooth-SD) E:\Dreambooth-Stable-Diffusion>

Additional context

Add documentation to the rank_zero_only function

📚 Documentation

Document that the rank_zero_only function only works as intended once the actual rank is set, i.e., rank_zero_only.rank = x. This needs to be handled by the application that uses this function.
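
For illustration, a minimal sketch of what such documentation could show (reading the rank from the RANK environment variable is an assumption about the launcher, not part of the library):

import os

from lightning_utilities.core.rank_zero import rank_zero_only

# the application must set the rank before the decorator can act on it
rank_zero_only.rank = int(os.environ.get("RANK", 0))

@rank_zero_only
def log_something() -> None:
    print("This runs on global rank 0 only.")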

false positive with `package_available`

๐Ÿ› Bug

Seems we have a bug in the utils package: with only PL installed from source, the code below returns True for both packages...

To Reproduce

Lightning-AI/pytorch-lightning#16595 (comment)

Code sample

from lightning_utilities.core.imports import package_available
print(package_available("lightning"))
print(package_available("pytorch_lightning"))

Expected behavior

Environment

  • lightning-utilities==0.6.0.post0

Additional context

Update CHANGELOG

🚀 Feature

CHANGELOG.md is currently empty. We use GitHub's release-notes auto-generation.

Motivation

Be consistent

Pitch

Update CHANGELOG.md and the existing releases.

Alternatives

Keep using the auto-generation button

`apply_to_collection` breaks named tuples (force casts to tuples)

๐Ÿ› Bug

lightning_utilities.core.apply_func.apply_to_collection breaks homogeneous named tuples, forcing conversion to simple tuples. This results in a forced conversion of NamedTuple-based batches from the dataloader to tuples in Lightning.

To Reproduce

Reproducing the bug in Lightning requires creating a model and a dataloader with NamedTuple output.

Reproducing the problem in Lightning-Utilities is easier; see the snippet below:

from lightning_utilities.core.apply_func import apply_to_collection
import torch
from typing import NamedTuple

class NamedTupleBatch(NamedTuple):
    x: torch.Tensor
    y: torch.Tensor

def test_apply_to_collection():
    batch = NamedTupleBatch(x=torch.rand(10, 10), y=torch.rand(10, 10))
    assert isinstance(batch, NamedTupleBatch)  # before
    batch_out = apply_to_collection(batch, torch.Tensor, lambda x: x.to("cpu"))
    assert isinstance(batch_out, NamedTupleBatch)  # after - broken

Expected behavior

apply_to_collection should return the NamedTuple (if the input is of a NamedTuple type) instead of force-casting it to a simple tuple.

Additional context

I think that the problem was introduced in #160
Lightning Utilities version 0.9.0 is fine, but 0.10.0 breaks named tuples.

I think that the behavior is incorrect in:

if isinstance(data, tuple) and all(isinstance(x, dtype) for x in data):  # 1d homogeneous tuple
    return tuple(function(x, *args, **kwargs) for x in data)

Potential fix:

if isinstance(data, tuple) and all(isinstance(x, dtype) for x in data):  # 1d homogeneous tuple
    if is_namedtuple(data):
        return type(data)(*[function(x, *args, **kwargs) for x in data])
    else:
        return tuple(function(x, *args, **kwargs) for x in data)

Environment details
  • PyTorch Version (e.g., 1.0): any (problem is not related to PyTorch)
  • OS (e.g., Linux): MacOS, Ubuntu 20.04
  • How you installed PyTorch (conda, pip, source):
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

GHA for docs link check

🚀 Feature

Motivation

In several projects we made link validation (detecting dead links) optional, so the job still passes when it fails. That makes it almost as if the check were not there, because people ignore it. Instead, have an action that parses the results and comments on the PR with the dead/broken links.

Alternatives

Make the link check required again, as part of our build.

Additional context

see: Lightning-AI/torchmetrics#1626 (comment)

Rename `lightning_utilities.dev` to `lightning_utilities.cli`

🚀 Feature

See title.

Motivation

dev sounds very ambiguous to me.

Alternatives

  • Keep as is.
  • Rename it to a different name.

Additional context

Since the cli subpackage will be used only in our CI, this issue shouldn't be highly prioritised.

Fail to import lightning due to missing dependency on setuptools

Bug description

Lightning appears to depend on the setuptools package at runtime via an import of the pkg_resources module. However, setuptools is not listed as a dependency, causing import failures in some environments.

In particular, when using the PDM package manager, the virtualenv it uses behind the scenes does not by default include a setuptools distribution (as it is typically only needed in the build environment). This means that installing lightning via pdm leads to a failure at import time. However, the issue can be reproduced in a standard virtualenv by simply uninstalling setuptools.

This bug seems to have been introduced in 2.1.x, and I have confirmed it on Python 3.9.18 and 3.11.6.

What version are you seeing the problem on?

v2.1

How to reproduce the bug

python -m venv venv
source venv/bin/activate
pip uninstall -y setuptools
pip install lightning

python -c "import lightning" || python -c "import lightning.pytorch"

Error messages and logs

  File "<string>", line 1, in <module>
  File ".../venv/lib/python3.11/site-packages/lightning/__init__.py", line 18, in <module>
    from lightning.fabric.fabric import Fabric  # noqa: E402
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/lightning/fabric/__init__.py", line 5, in <module>
    from lightning_utilities.core.imports import package_available
  File ".../venv/lib/python3.11/site-packages/lightning_utilities/__init__.py", line 6, in <module>
    from lightning_utilities.core.apply_func import apply_to_collection
  File ".../venv/lib/python3.11/site-packages/lightning_utilities/core/__init__.py", line 4, in <module>
    from lightning_utilities.core.imports import compare_version, module_available
  File ".../venv/lib/python3.11/site-packages/lightning_utilities/core/imports.py", line 13, in <module>
    import pkg_resources
ModuleNotFoundError: No module named 'pkg_resources'

Environment

Current environment
- PyTorch Lightning Version: 2.1.x
- Python: version 3.9, 3.11 (tested)
- MacOS: 13.5.1 (m1 processor)
- How you installed Lightning: pip, pdm

More info

No response

Cannot import `lightning_utilities.test.warning.no_warning_call`

๐Ÿ› Bug

The test module in lightning_utilities is not part of the package.

To Reproduce

Steps to reproduce the behavior:

from lightning_utilities.test.warning import no_warning_call

Code sample

from lightning_utilities.test.warning import no_warning_call

Expected behavior

Environment

Latest version 0.4.1

  • PyTorch Version (e.g., 1.0):
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, source):
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

I got the import from #55

Automatic device transfer converts namedtuples into regular tuples

๐Ÿ› Bug

In the latest version of the utilities library, automatic device transfer converts namedtuples into regular tuples, causing loss of attribute access provided by the namedtuple.

To Reproduce

from collections import namedtuple
import torch
from pytorch_lightning import LightningModule, Trainer

Batch = namedtuple("Batch", ["features", "targets"])

class SimpleModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 1)
        
    def forward(self, batch):
        print(type(batch))
        print(batch.features)
        return self.linear(batch.features)

    def training_step(self, batch, batch_idx):
        return self(batch)

    def train_dataloader(self):
        features = torch.randn(10, 5)
        targets = torch.randint(0, 1, (10,))
        dataset = [Batch(features[i], targets[i]) for i in range(10)]
        return torch.utils.data.DataLoader(dataset, batch_size=1)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

model = SimpleModule()
trainer = Trainer(fast_dev_run=True)
trainer.fit(model)

This script fails with AttributeError: 'tuple' object has no attribute 'features' when the library version is 0.10.0, and runs without errors on 0.9.0. The PyTorch Lightning version was unchanged, 2.1.2 for both.

Additional context

Environment details
  • PyTorch Version (e.g., 1.0): 2.1.1+cu121
  • OS (e.g., Linux): Ubuntu
  • How you installed PyTorch (conda, pip, source): pip
  • Python version: 3.10.13
  • CUDA/cuDNN version: 12.2

This happened for both single-GPU and multi-GPU runs with DDP. It was a bit of work to track down, since the pinned pytorch_lightning version was unchanged.
This is exactly the same as a previous Lightning issue, although that one seems to have been caused by PyTorch itself?

Introduce ModuleAvailableCache

🚀 Feature

Motivation

RequirementCache exists, but it is very general: it works for modules as well as packages, and it performs a very strict check of a package's requirements. This check can fail even if the package itself imports successfully; see the example: Lightning-AI/pytorch-lightning#16464

In these cases, the module_available check is better suited for what we want to verify, but it can only be called as a function, not used as a cache like RequirementCache.

Pitch

Add ModuleAvailableCache, with a cache implementation like RequirementCache, but using the module_available function as the check.
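
A minimal sketch of what such a cache could look like (the class body is illustrative, not the final implementation):

from lightning_utilities.core.imports import module_available

class ModuleAvailableCache:
    """Cache the result of a module_available() check (illustrative sketch)."""

    def __init__(self, module: str) -> None:
        self.module = module
        self._available = None

    def __bool__(self) -> bool:
        if self._available is None:
            self._available = module_available(self.module)
        return self._available

    def __str__(self) -> str:
        state = "" if bool(self) else "not "
        return f"Module {self.module!r} is {state}available."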

Alternatives

User has to resolve any package conflicts in their environment, even if they are harmless.

Additional context

See Lightning-AI/pytorch-lightning#16464

Function to get the minimum version of a dependency of a package

🚀 Feature

Add a function to lightning_utilities/core/imports.py that returns the minimum version of a dependency of a package.

Motivation

The minimum version of a dependency may be needed in more than one place in the code. When a change is needed, it has to be made in multiple places, and it is easy to forget to update one location. An example is requirements/pytorch/extra.txt#L8 and src/pytorch_lightning/cli.py#L32.

Maybe this could be useful in other cases.

Pitch

Be able to change

_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.12.0")

to

jsonargparse_min_version = get_minimum_dependency_version("pytorch-lightning", "jsonargparse")
_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache(f"jsonargparse[signatures]>={jsonargparse_min_version}")

A possible implementation of this function could be:

import re
from importlib import metadata

def get_minimum_dependency_version(package_name: str, dependency_name: str) -> str:
    """Return the minimum version of a dependency of a package.

    >>> get_minimum_dependency_version("pytorch-lightning", "jsonargparse")
    '4.12.0'
    """
    for dependency in metadata.requires(package_name):
        if re.match(rf"^{dependency_name}(|[^\w].*)$", dependency):
            return re.sub(r".*>=([\d.]+).*", r"\1", dependency)
    raise ValueError(f"dependency {dependency_name!r} not found in package {package_name!r}")

Alternatives

Keep updating multiple places in the code. When there is a mistake, fix it later.

Additional context

None

Fabric support for TensorDict

Description & Motivation

TensorDict is a dictionary-like class, developed by the PyTorch team, that inherits tensor properties such as indexing, shape operations, and casting to device.
Currently, Fabric does not support it:

from lightning.fabric import Fabric
import torch
from tensordict import TensorDict


def main():
    fabric = Fabric(devices=2, accelerator="cpu", strategy="ddp")
    fabric.launch()
    d = TensorDict({"a": torch.rand(10, 1, 3), "b": torch.rand(10, 2, 7)}, batch_size=[10])
    gathered = fabric.all_gather(d)
    fabric.print(gathered)
    reduced = fabric.all_reduce(d)
    fabric.print(reduced)


if __name__ == "__main__":
    main()

gives the following error:

Traceback (most recent call last):
  File "/home/belerico/Desktop/lightning-apps/lightning/examples/fabric/reinforcement_learning/test.py", line 15, in <module>
    main()
  File "/home/belerico/Desktop/lightning-apps/lightning/examples/fabric/reinforcement_learning/test.py", line 11, in main
    gathered = fabric.all_gather(d)
  File "/home/belerico/Desktop/lightning-apps/lightning/src/lightning/fabric/fabric.py", line 496, in all_gather
    data = convert_to_tensors(data, device=self.device)
  File "/home/belerico/Desktop/lightning-apps/lightning/src/lightning/fabric/utilities/apply_func.py", line 107, in convert_to_tensors
    data = apply_to_collection(data, src_dtype, conversion_func, device=device)
  File "/home/belerico/miniconda3/envs/lightning-ai/lib/python3.9/site-packages/lightning_utilities/core/apply_func.py", line 73, in apply_to_collection
    return elem_type(OrderedDict(out))
  File "/home/belerico/miniconda3/envs/lightning-ai/lib/python3.9/site-packages/tensordict/tensordict.py", line 2888, in __init__
    self._batch_size = self._parse_batch_size(source, batch_size)
  File "/home/belerico/miniconda3/envs/lightning-ai/lib/python3.9/site-packages/tensordict/tensordict.py", line 2905, in _parse_batch_size
    raise ValueError(
ValueError: batch size was not specified when creating the TensorDict instance and it could not be retrieved from source.

Pitch

Let TensorDict be supported by Fabric's distributed functions all_gather and all_reduce.

Alternatives

No response

Additional context

With the following addition at https://github.com/Lightning-AI/utilities/blob/main/src/lightning_utilities/core/apply_func.py#L71:

+from tensordict import make_tensordict
+from tensordict.tensordict import TensorDictBase

         if isinstance(data, defaultdict):
             return elem_type(data.default_factory, OrderedDict(out))
+        elif isinstance(data, TensorDictBase):
+            return make_tensordict(OrderedDict(out), device=kwargs.get("device", None))  # batch_size is inferred automatically

the above script runs without errors:

Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 10, 1, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([2, 10, 2, 7]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2, 10]),
    device=None,
    is_shared=False)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10, 1, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 2, 7]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

I don't know if this modification is enough in every case, but I can investigate more if this becomes a possible feature.

cc @Borda

Utilise durations of test cases across previous CI runs

Originally suggested by @Borda in Lightning-AI/pytorch-lightning#13673 (comment)

🚀 Feature

See title.

Motivation

This feature can further reduce the (GPU) CI time by optimizing the split of standalone tests.

If we somehow sort the tests by duration, we can gain a bit more, because each batch is only as long as its slowest test, right?
We could store that metadata on S3 from previous runs and re-order. Some extra optimization, but not sure we can win that much.
Lightning-AI/pytorch-lightning#13673 (comment)
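
For illustration, a sketch of the kind of duration-balanced split this could enable (a greedy longest-first assignment; the function and its inputs are hypothetical):

from heapq import heappop, heappush

def split_by_duration(durations: dict, num_groups: int) -> list:
    """Greedily assign the longest tests first to the currently lightest group."""
    heap = [(0.0, idx, []) for idx in range(num_groups)]  # (total time, group id, tests)
    for test in sorted(durations, key=durations.get, reverse=True):
        total, idx, tests = heappop(heap)  # the lightest group so far
        tests.append(test)
        heappush(heap, (total + durations[test], idx, tests))
    return [tests for _, _, tests in sorted(heap, key=lambda group: group[1])]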

Pitch

To support this feature, these actions may be useful:

Alternatives

Do nothing.

Additional context

This is definitely one of the low-priority items for now; just creating this issue as possible future work to explore.

Collect and unify pytest results across different runs

🚀 Feature

Create a workflow that collects and merges pytest results from CI runs across different operating systems, accelerators and software versions.

With this feature implemented, we will be able to see a merged list of all test cases succeeded/failed/skipped across all CI configurations.

Motivation

We have tests running across different operating systems, accelerators and software versions, and currently each CI run has its own result only, making it almost impossible to monitor which tests are run or skipped across all such configurations.

Due to this lack of observability, we experienced an issue a while ago where none of the horovod tests had run for a long period of time in the PL repo.

Pitch

To be explored.
(I guess we could somehow utilise https://github.com/pytest-dev/pytest-reportlog)

Alternatives

To be explored.

Additional context

Codecov automatically merges coverage results uploaded from different CI runs: https://app.codecov.io/gh/Lightning-AI/lightning/
AFAIK, the coverage result doesn't hold any pytest outcomes, so we need to find another way to collect each test case's status from the different CI settings.

Open for any suggestions 💜

Adapt reusable workflows in Lightning Ecosystem

🚀 List of workflows to replace

Here's a tracking issue for replacing existing workflows with the new reusable workflows across Lightning repositories.

`apply_func.apply_to_collection` force updating its return type.

๐Ÿ› Bug

While using torchmetrics, I implemented a custom classification metric that returns a namedtuple. Up to lightning_utilities 0.9.0 this worked, because the returned namedtuple instance was not changed by apply_func.apply_to_collection. Since this behavior changed in 0.10.0, I'd like to propose that apply_to_collection be updated to keep the input type.

To Reproduce

Steps to reproduce the behavior...

>>> from torchmetrics.utilities.data import _squeeze_scalar_element_tensor, _squeeze_if_scalar
>>> from collections import namedtuple
>>> import torch
>>> State = namedtuple("State",["gt", "tp", "fp", "tn", "fn"])
>>> state = State(torch.tensor(1), torch.tensor(1) ,torch.tensor(1), torch.tensor(1), torch.tensor(1))
>>> x = _squeeze_if_scalar(state)
>>> x
(tensor(1), tensor(1), tensor(1), tensor(1), tensor(1))

Code sample

Expected behavior

>>> x = _squeeze_if_scalar(state)
>>> x
State(gt=tensor(1), tp=tensor(1), fp=tensor(1), tn=tensor(1), fn=tensor(1))

Additional context

Environment details
  • lightning_utilities: 0.10.0

Proposal

# proposed change inside lightning_utilities/core/apply_func.py
from typing import Any, Callable, Optional, Tuple, Union

def apply_to_collection(
    data: Any,
    dtype: Union[type, Any, Tuple[Union[type, Any]]],
    function: Callable,
    *args: Any,
    wrong_dtype: Optional[Union[type, Tuple[type, ...]]] = None,
    include_none: bool = True,
    allow_frozen: bool = False,
    **kwargs: Any,
) -> Any:
    if include_none is False or wrong_dtype is not None or allow_frozen is True:
        # not worth implementing these on the fast path: go with the slower option
        return _apply_to_collection_slow(
            data,
            dtype,
            function,
            *args,
            wrong_dtype=wrong_dtype,
            include_none=include_none,
            allow_frozen=allow_frozen,
            **kwargs,
        )
    # fast path for the most common cases:
    if isinstance(data, dtype):  # single element
        return function(data, *args, **kwargs)
    ori_class = data.__class__
    if isinstance(data, list) and all(isinstance(x, dtype) for x in data):  # 1d homogeneous list
        return ori_class(function(x, *args, **kwargs) for x in data)
    if isinstance(data, tuple) and all(isinstance(x, dtype) for x in data):  # 1d homogeneous tuple
        out = [function(x, *args, **kwargs) for x in data]
        # namedtuples take positional fields; plain tuples take an iterable
        return ori_class(*out) if hasattr(data, "_fields") else ori_class(out)
    if isinstance(data, dict) and all(isinstance(x, dtype) for x in data.values()):  # 1d homogeneous dict
        return ori_class(**{k: function(v, *args, **kwargs) for k, v in data.items()})

Define global action.yaml

🚀 Feature

Motivation

It seems that particular actions/workflows are not recognized by Dependabot, so several of our repos/projects are still using old versions... On the other hand, pinning main has some overhead, as we could introduce breaking changes and thus break a particular CI until it is aligned.

So one global action would aggregate the suggested group of workflows: https://github.com/Lightning-AI/utilities/blob/main/.github/workflows/ci-use-checks.yaml
Also, we could consider turning some workflows off based on an input/parameter 🍿

Pitch

Allow Dependabot to pick up whenever we make a new release.

Alternatives

switch to use main by default

Additional context

see: https://docs.github.com/en/actions/creating-actions

Non-roundtrippable sequence subclasses raise errors in `apply_to_collection`

๐Ÿ› Bug

Some of our dataloaders use custom Sequence subclasses that cause errors in apply_to_collection.

The reason is that these classes do not round-trip, which the implementation of apply_to_collection assumes.

# minimal example
class X(list):
    def __init__(self, x):
        super().__init__(range(x))

data = X(4)

assert data == type(data)(list(data))  # is False, because of how class X's constructor is implemented
                                       # but this is how apply_to_collection handles Sequences

The above is assumed implicitly here:

out = []
for d in data:
    v = _apply_to_collection_slow(
        d,
        dtype,
        function,
        *args,
        wrong_dtype=wrong_dtype,
        include_none=include_none,
        allow_frozen=allow_frozen,
        **kwargs,
    )
    if include_none or v is not None:
        out.append(v)
return elem_type(*out) if is_namedtuple_ else elem_type(out)

To Reproduce

See example above.

Expected behavior

If an instance of a sequence subclass can't round-trip, it should just be passed through.
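
A minimal sketch of such a guard (hypothetical, not the library's actual fix):

def _rebuild_or_passthrough(data, out, elem_type, is_namedtuple_):
    # try to rebuild the container from the transformed elements; if the
    # subclass constructor does not round-trip, pass the input through unchanged
    try:
        return elem_type(*out) if is_namedtuple_ else elem_type(out)
    except TypeError:
        return data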

Additional context

N/A

Cheers,
Andreas
