
tensordict's Introduction


TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming | TensorDict for parameter serialization | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The batch_size must match the leading dimensions of every tensor. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to live on a dedicated device, in which case every tensor written to it is sent to that device:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> data["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert data["key 3"].device == torch.device("cuda:0")

But that is not all: you can also store nested values in a tensordict:

>>> data["nested", "key"] = torch.zeros(3, 4) # the batch-size must match

and any nested tuple structure will be unravelled, which keeps code easy to read and makes it easy to write ops programmatically:

>>> data["nested", ("supernested", ("key",))] = torch.zeros(3, 4) # the batch-size must match
>>> assert (data["nested", "supernested", "key"] == 0).all()
>>> assert (("nested",), "supernested", (("key",),)) in data.keys(include_nested=True)  # this works too!
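To make the unravelling rule concrete, here is a minimal pure-Python sketch of how a nested tuple key could be flattened into a single path. The helper `unravel_key` below is hypothetical and written for illustration only; it is not tensordict's implementation, and collapsing a one-element path to a plain string is an assumption of this sketch.

```python
from typing import Tuple, Union

NestedKey = Union[str, Tuple["NestedKey", ...]]


def unravel_key(key: NestedKey) -> Union[str, Tuple[str, ...]]:
    """Flatten an arbitrarily nested tuple key into a flat tuple of strings (hypothetical helper)."""
    flat = []

    def _walk(k):
        if isinstance(k, str):
            flat.append(k)
        elif isinstance(k, tuple):
            for sub in k:
                _walk(sub)  # depth-first, left to right, so the order of names is preserved
        else:
            raise TypeError(f"keys must be str or tuple, got {type(k).__name__}")

    _walk(key)
    # in this sketch, a path made of a single name is returned as that name
    return flat[0] if len(flat) == 1 else tuple(flat)


print(unravel_key(("nested", ("supernested", ("key",)))))
print(unravel_key((("nested",), "supernested", (("key",),))))
```

Both calls resolve to the same path, which is why the two spellings in the example above address the same entry.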

You can also store non-tensor data in tensordicts:

>>> data = TensorDict({"a-tensor": torch.randn(1, 2)}, batch_size=[1, 2])
>>> data["non-tensor"] = "a string!"
>>> assert data["non-tensor"] == "a string!"

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = data[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating single tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])
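The batch-size arithmetic behind the assertions above can be written down explicitly. The sketch below works on plain lists of ints; `stack_batch_size` and `cat_batch_size` are hypothetical helpers, not part of tensordict. Stacking inserts a new dimension of size `len(tensordicts)` at `dim`, while concatenation sums the sizes along an existing `dim` and requires every other dimension to match.

```python
from typing import List, Sequence


def stack_batch_size(batch_sizes: Sequence[Sequence[int]], dim: int) -> List[int]:
    """Batch size obtained by stacking items along a new dimension `dim` (hypothetical helper)."""
    first = list(batch_sizes[0])
    if any(list(bs) != first for bs in batch_sizes):
        raise ValueError("stack requires identical batch sizes")
    ndim = len(first) + 1  # stacking adds one dimension
    if dim < 0:
        dim += ndim
    return first[:dim] + [len(batch_sizes)] + first[dim:]


def cat_batch_size(batch_sizes: Sequence[Sequence[int]], dim: int) -> List[int]:
    """Batch size obtained by concatenating items along an existing dimension `dim` (hypothetical helper)."""
    first = list(batch_sizes[0])
    if dim < 0:
        dim += len(first)
    for bs in batch_sizes[1:]:
        if len(bs) != len(first) or any(a != b for i, (a, b) in enumerate(zip(first, bs)) if i != dim):
            raise ValueError("cat requires matching batch sizes outside `dim`")
    out = list(first)
    out[dim] = sum(bs[dim] for bs in batch_sizes)  # sizes add up along the cat dimension
    return out


print(stack_batch_size([[3, 4], [3, 4]], dim=1))  # [3, 2, 4]
print(cat_batch_size([[3, 4], [3, 4]], dim=0))  # [6, 4]
```

Each entry's full shape is then the resulting batch size followed by its own trailing dimensions, e.g. [3, 2, 4] + [5] for "key 1" after stacking.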

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(data.view(-1).shape)
torch.Size([12])
>>> print(data.reshape(-1).shape)
torch.Size([12])
>>> print(data.unsqueeze(-1).shape)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or out-of-place, split them, unbind them, expand them, etc.

If some functionality is missing, it is easy to implement it with apply() or apply_():

tensordict_uniform = data.apply(lambda tensor: tensor.uniform_())

apply() can also be used to filter a tensordict: entries for which the function returns None are dropped. For instance:

data = TensorDict({"a": torch.tensor(1.0, dtype=torch.float), "b": torch.tensor(1, dtype=torch.int64)}, [])
data_float = data.apply(lambda x: x if x.dtype == torch.float else None) # contains only the "a" key
assert "b" not in data_float
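The filtering behaviour can be mimicked on a nested Python dict to show the rule at work: the function is applied to every leaf, and leaves for which it returns None are left out of the result. `apply_nested` is a hypothetical stand-in, not tensordict's apply(); dropping sub-dicts that end up empty is a choice made for this sketch only.

```python
from typing import Any, Callable, Dict, Optional


def apply_nested(data: Dict[str, Any], fn: Callable[[Any], Optional[Any]]) -> Dict[str, Any]:
    """Return a new nested dict with `fn` applied to every leaf; leaves mapped to None are dropped."""
    out = {}
    for key, value in data.items():
        if isinstance(value, dict):
            sub = apply_nested(value, fn)  # recurse into nested levels
            if sub:  # sketch-specific choice: omit sub-dicts that end up empty
                out[key] = sub
        else:
            new = fn(value)
            if new is not None:
                out[key] = new
    return out


data = {"a": 1.0, "b": 1, "nested": {"c": 2.5, "d": 3}}
floats_only = apply_nested(data, lambda x: x if isinstance(x, float) else None)
print(floats_only)  # {'a': 1.0, 'nested': {'c': 2.5}}
```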

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.

TensorDict for functional programming

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to stack model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> params = TensorDict.from_module(model)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> with params.to_module(model):
...     out = model(x)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> def func(x, params):
...     with params.to_module(model):
...         return model(x)
>>> y = vmap(func, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and (soon) torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

TensorDict for parameter serialization and building datasets

TensorDict offers an API for parameter serialization that can be >3x faster than regular calls to torch.save(state_dict). Moreover, because tensors will be saved independently on disk, you can deserialize your checkpoint on an arbitrary slice of the model.

>>> model = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 3))
>>> params = TensorDict.from_module(model)
>>> params.memmap("/path/to/saved/folder/", num_threads=16)  # adjust num_threads for speed
>>> # load params
>>> params = TensorDict.load_memmap("/path/to/saved/folder/", num_threads=16)
>>> params.to_module(model)  # load onto model
>>> params["0"].to_module(model[0])  # load on a slice of the model
>>> # in the latter case we could also have loaded only the slice we needed
>>> params0 = TensorDict.load_memmap("/path/to/saved/folder/0", num_threads=16)
>>> params0.to_module(model[0])  # load on a slice of the model
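Loading only a slice works because each leaf lives in its own file. The sketch below reproduces that idea with NumPy `.npy` files laid out in a folder tree that mirrors the keys. It assumes NumPy is installed; the layout and the helpers `save_leaves`/`load_leaves` are hypothetical and are not the on-disk format that memmap() actually writes.

```python
import os
import tempfile

import numpy as np


def save_leaves(tree, folder):
    """Write each leaf array of a nested dict to its own .npy file, mirroring the key hierarchy."""
    os.makedirs(folder, exist_ok=True)
    for key, value in tree.items():
        path = os.path.join(folder, key)
        if isinstance(value, dict):
            save_leaves(value, path)  # sub-dict -> sub-folder
        else:
            np.save(path + ".npy", value)


def load_leaves(folder):
    """Memory-map every .npy file under `folder` back into a nested dict."""
    tree = {}
    for name in sorted(os.listdir(folder)):
        path = os.path.join(folder, name)
        if os.path.isdir(path):
            tree[name] = load_leaves(path)
        elif name.endswith(".npy"):
            tree[name[: -len(".npy")]] = np.load(path, mmap_mode="r")
    return tree


params = {
    "0": {"weight": np.ones((4, 3), dtype=np.float32), "bias": np.zeros(4, dtype=np.float32)},
    "1": {"weight": np.ones((3, 4), dtype=np.float32), "bias": np.zeros(3, dtype=np.float32)},
}
with tempfile.TemporaryDirectory() as root:
    save_leaves(params, root)
    # only the files under "0" are opened; layer "1" is never touched
    first_layer = load_leaves(os.path.join(root, "0"))
    w00 = float(first_layer["weight"][0, 0])  # read while the files still exist

print(sorted(first_layer), first_layer["weight"].shape, w00)
```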

The same functionality can be used to access data in a dataset stored on disk. Storing a single contiguous tensor on disk, accessed through the tensordict.MemoryMappedTensor primitive, and reading slices of it is not only much faster than loading individual files one at a time, but also easier and safer (because no pickling or third-party library is involved):

# allocate memory of the dataset on disk
data = TensorDict({
    "images": torch.zeros((128, 128, 3), dtype=torch.uint8),
    "labels": torch.zeros((), dtype=torch.int)}, batch_size=[])
data = data.expand(1000000)
data = data.memmap_like("/path/to/dataset")
# ==> Fill your dataset here
# Let's get 3 items of our dataset:
data[torch.tensor([1, 10000, 500000])]  # This is much faster than loading the 3 images independently
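The same access pattern can be sketched with `numpy.memmap`, assuming NumPy is available: one contiguous file is allocated up front, filled, then reopened read-only so that a single fancy-indexing call gathers a few items. The sizes are shrunk (100 items) and the fill values are made up so the example runs quickly; this illustrates the idea behind MemoryMappedTensor rather than its API.

```python
import os
import tempfile

import numpy as np

num_items = 100  # small stand-in for the 1,000,000 items above
shape = (num_items, 128, 128, 3)
with tempfile.TemporaryDirectory() as root:
    path = os.path.join(root, "images.bin")
    # allocate one contiguous uint8 array on disk for all images
    images = np.memmap(path, dtype=np.uint8, mode="w+", shape=shape)
    labels = np.zeros(num_items, dtype=np.int64)  # kept in RAM here for brevity
    for i in range(num_items):  # "fill your dataset here" with made-up values
        images[i] = i % 256
        labels[i] = i
    images.flush()

    # reopen read-only and gather 3 items with a single fancy-indexing call
    reader = np.memmap(path, dtype=np.uint8, mode="r", shape=shape)
    idx = np.array([1, 10, 50])
    batch_images = np.array(reader[idx])  # copies only the requested images into memory
    batch_labels = labels[idx]
    del images, reader  # drop the mappings before the folder is removed

print(batch_images.shape, batch_labels.tolist())  # (3, 128, 128, 3) [1, 10, 50]
```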

Preprocessing with TensorDict.map

Preprocessing huge contiguous (or not!) datasets can be done via TensorDict.map which will dispatch a task to various workers:

import torch
from tensordict import TensorDict, MemoryMappedTensor
import tempfile

def process_data(data):
    images = data.get("images").flip(-2).clone()
    labels = data.get("labels") // 10
    # we update the td inplace
    data.set_("images", images)  # flip image
    data.set_("labels", labels)  # cluster labels

if __name__ == "__main__":
    # `data` is the memory-mapped dataset created in the previous section
    data_preproc = data.map(process_data, num_workers=4, chunksize=0, pbar=True)  # process one image at a time

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    data = TensorDict({}, batch_size=[])
    data["a"] = torch.randn(3)
    data["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return data

and you would like to call this function repeatedly. You could do this in two ways. The first is simply to stack the results of the calls:

data = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

data = TensorDict({}, batch_size=[N])
for i in range(N):
    data[i] = foo()

which also results in the following tensordict (here with N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i == 0, your empty tensordict is automatically populated with empty tensors whose leading dimension is N (the batch size). After that, updates are written in place. Note that this also works with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in order).
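The allocate-on-first-write behaviour can be illustrated with a toy container. `LazyBuffer` below is hypothetical and stores Python lists instead of tensors; it only shows the two phases described above: the first write mirrors the record's (possibly nested) structure into storage of length N, and every later write fills a slot in place, in any order.

```python
from typing import Any, Dict, Optional


class LazyBuffer:
    """Toy buffer of length `n` that allocates its storage on the first write (hypothetical)."""

    def __init__(self, n: int):
        self.n = n
        self.storage: Optional[Dict[str, Any]] = None  # nothing allocated yet

    def _allocate(self, record: Dict[str, Any]) -> Dict[str, Any]:
        # mirror the structure of the first record, with one slot per index
        return {k: self._allocate(v) if isinstance(v, dict) else [None] * self.n for k, v in record.items()}

    def _write(self, storage: Dict[str, Any], record: Dict[str, Any], i: int) -> None:
        for k, v in record.items():
            if isinstance(v, dict):
                self._write(storage[k], v, i)
            else:
                storage[k][i] = v  # in-place write into the preallocated slot

    def __setitem__(self, i: int, record: Dict[str, Any]) -> None:
        if self.storage is None:
            self.storage = self._allocate(record)
        self._write(self.storage, record, i)


buf = LazyBuffer(4)
for i in [2, 0, 3, 1]:  # writes need not be ordered
    buf[i] = {"a": i * 10, "b": {"c": -i}}
print(buf.storage)  # {'a': [0, 10, 20, 30], 'b': {'c': [0, -1, -2, -3]}}
```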

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict must be indexable under the parent tensordict, i.e. its batch size must start with the parent's batch size (it may have more dimensions).
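Put differently, the parent's batch size must be a prefix of the child's. A one-line check on plain tuples (the `can_nest` helper is hypothetical, for illustration only):

```python
def can_nest(parent_batch, child_batch) -> bool:
    """True if a child with `child_batch` can live inside a parent with `parent_batch`."""
    parent_batch, child_batch = tuple(parent_batch), tuple(child_batch)
    return child_batch[: len(parent_batch)] == parent_batch


print(can_nest([3, 4], [3, 4, 5]))  # True: longer, but starts with the parent's batch size
print(can_nest([3, 4], [3]))  # False: too few dimensions
print(can_nest([3, 4], [4, 3]))  # False: leading dimensions differ
```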

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = data.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = data["key 2", "sub-key"]
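The key-flattening round trip can be sketched on plain nested dicts. `flatten_keys` and `unflatten_keys` here are hypothetical pure-Python stand-ins for the TensorDict methods of the same names; note that the round trip is only lossless if no key already contains the separator.

```python
def flatten_keys(tree, separator="."):
    """Collapse a nested dict into a single level, joining key paths with `separator`."""
    flat = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            for sub_key, sub_value in flatten_keys(value, separator).items():
                flat[f"{key}{separator}{sub_key}"] = sub_value
        else:
            flat[key] = value
    return flat


def unflatten_keys(flat, separator="."):
    """Rebuild the nested dict by splitting each key on `separator`."""
    tree = {}
    for key, value in flat.items():
        *parents, leaf = key.split(separator)
        node = tree
        for p in parents:
            node = node.setdefault(p, {})  # create intermediate levels as needed
        node[leaf] = value
    return tree


nested = {"key 1": 1, "key 2": {"sub-key": 2}}
flat = flatten_keys(nested)
print(flat)  # {'key 1': 1, 'key 2.sub-key': 2}
assert unflatten_keys(flat) == nested
```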

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking (BC-breaking) changes may be introduced, but they should come with a warning. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.

tensordict's People

Contributors

albertbou92, alexanderlobov, apbard, dependabot[bot], goldspear, khundman, kurt-stolle, lucifer1004, matteobettini, mischab, rmax, roccajoseph, romainjln, salaxieb, se-yi, skandermoalla, smorad, sreevasthav, sugatoray, tcbegley, vmoens, wonnor-pro, xmaples, xuehaipan


tensordict's Issues

[BUG] Wrong batch_size

Describe the bug

The shape of a batch is the full length of the dataset

To Reproduce

I am using version 0.0.3.

import torch
from tensordict import MemmapTensor
from tensordict.prototype import tensorclass
from torch.utils.data import DataLoader

@tensorclass
class Data:
    images: torch.Tensor
    targets: torch.Tensor

    @classmethod
    def from_tensors(cls,  device: torch.device):
        len_dataset = 1000
        images, targets = torch.randn((len_dataset, 28, 28)), torch.randn((len_dataset, 1))
        data = cls(
            images=MemmapTensor.from_tensor(images, filename="images.dat"),
            targets=MemmapTensor.from_tensor(targets, filename="targets.dat"),
            batch_size=[len_dataset],
            device=device,
        )

        data.memmap_()

        return data
    
data = Data.from_tensors(device="cuda")

dl = DataLoader(data, batch_size=64, collate_fn=lambda x: x)

for batch in dl:
    print(batch.images.shape)

The output is shown below; the first dimension should be 64 (the batch size):

torch.Size([1000, 28, 28])
torch.Size([1000, 28, 28])
torch.Size([1000, 28, 28])
torch.Size([1000, 28, 28])
torch.Size([1000, 28, 28])
torch.Size([1000, 28, 28])
torch.Size([1000, 28, 28])
torch.Size([1000, 28, 28])
torch.Size([1000, 28, 28])
torch.Size([1000, 28, 28])
torch.Size([1000, 28, 28])
torch.Size([1000, 28, 28])
torch.Size([1000, 28, 28])
torch.Size([1000, 28, 28])
torch.Size([1000, 28, 28])
torch.Size([1000, 28, 28])

[Feature Request] `tensordict.entry_dtype(entry)`, `tensordict.entry_device(entry)`

Motivation

In #168 we introduced tensordict._entry_class. It might make sense to have the same for dtype and device, to avoid doing things like if tensordict.get(key).dtype is torch.float32.
These methods should be made public and documented.
We should find a proper way to do it with lazy tensordicts (especially lazy stacks). One option would be to take the first tensordict of the list and query that method on it, implicitly assuming that homogeneous tensordicts have been stacked. This is bug-prone if one stacks tensordicts with varying dtypes, but that behaviour is not supported anyway (as usual: do we want to apply expensive checks, or just tell users to be cautious because we don't?)

cc @tcbegley

cc @matteobettini fyi since this implies lazy-stacked tds

[Feature Request] allow inferring the batch size from a dict when creating a tensordict instance

Motivation

Suppose I have a dictionary as follows:

x = {'a':torch.randn(4,3), 'b': torch.randn(4,5)}

If I want to convert this to a tensordict, I will get an error if I don't specify the batch size:

y = TensorDict(x)  # won't run

In this case, can we support inferring the batch size from the first common dimension?

This is quite useful when integrating tensordict with other libraries. For example, HuggingFace Accelerate converts the network output to the right precision using ConvertOutputsToFp32, but this does not currently work with TensorDict, because a tensordict instance cannot be created without explicitly specifying the batch size. The specific line that errors out is this.
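One possible inference rule, sketched here on a mapping from keys to shapes, is to take the longest leading shape shared by all entries. `infer_batch_size` is hypothetical and not a tensordict function; a more conservative variant would stop after the first common dimension, since the longest prefix can absorb feature dimensions that happen to coincide (e.g. two entries of shape (4, 3)).

```python
from typing import Mapping, Sequence, Tuple


def infer_batch_size(shapes: Mapping[str, Sequence[int]]) -> Tuple[int, ...]:
    """Longest leading shape shared by every entry (empty if there is none)."""
    all_shapes = [tuple(s) for s in shapes.values()]
    if not all_shapes:
        return ()
    common = []
    for dims in zip(*all_shapes):  # zip stops at the shortest shape
        if all(d == dims[0] for d in dims):
            common.append(dims[0])
        else:
            break  # the first mismatch ends the shared prefix
    return tuple(common)


print(infer_batch_size({"a": (4, 3), "b": (4, 5)}))  # (4,)
```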

[BUG] Error in sequential indexing of TensorDict with MemmapTensor

Describe the bug

Indexing two dimensions consecutively raises an error when using MemmapTensor. If a standard Tensor is used, no error is raised.

To Reproduce

Steps to reproduce the behavior.

import tensordict, numpy, sys, torch
from tensordict import TensorDict
torch.manual_seed(4584371022706121143)

A, B = 10, 2
SAMPLE_SIZE = 1
tensordict = TensorDict({"a": torch.rand(A, B, 2), "b": torch.rand(A, B, 1)}, [A, B])
print("b: ", tensordict["b"])
tensordict.memmap_()
print(tensordict)

# sample first dimension
idx = torch.randint(0, A, (SAMPLE_SIZE,))
tensordict1 = tensordict[idx]
print(tensordict1)

# sample second dimension
step_idx = torch.randint(0, B, (SAMPLE_SIZE,))
tensordict2 = tensordict1[range(SAMPLE_SIZE), step_idx]
print(tensordict2)

tensordict2['b'] * 0.9

Output and stack traces

b:  tensor([[[0.4750],
         [0.1718]],

        [[0.0098],
         [0.7934]],

        [[0.2763],
         [0.5519]],

        [[0.2940],
         [0.8162]],

        [[0.3873],
         [0.9035]],

        [[0.6649],
         [0.1111]],

        [[0.4237],
         [0.0228]],

        [[0.2010],
         [0.4994]],

        [[0.4820],
         [0.4220]],

        [[0.0147],
         [0.4110]]])
TensorDict(
    fields={
        a: MemmapTensor(shape=torch.Size([10, 2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        b: MemmapTensor(shape=torch.Size([10, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10, 2]),
    device=None,
    is_shared=False)
TensorDict(
    fields={
        a: MemmapTensor(shape=torch.Size([1, 2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        b: MemmapTensor(shape=torch.Size([1, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([1, 2]),
    device=None,
    is_shared=False)
TensorDict(
    fields={
        a: MemmapTensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        b: MemmapTensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)
Traceback (most recent call last):
  File "td_test.py", line 27, in <module>
    tensordict2['b'] * 0.9
  File ".local/lib/python3.8/site-packages/tensordict/memmap.py", line 610, in __mul__
    return torch.mul(self, other)
  File ".local/lib/python3.8/site-packages/tensordict/memmap.py", line 418, in __torch_function__
    args = tuple(a._tensor if hasattr(a, "_tensor") else a for a in args)
  File ".local/lib/python3.8/site-packages/tensordict/memmap.py", line 418, in <genexpr>
    args = tuple(a._tensor if hasattr(a, "_tensor") else a for a in args)
  File ".local/lib/python3.8/site-packages/tensordict/memmap.py", line 431, in _tensor
    return self._load_item(self._index)
  File ".local/lib/python3.8/site-packages/tensordict/memmap.py", line 390, in _load_item
    memmap_array = self._get_item(_idx, memmap_array)
  File ".local/lib/python3.8/site-packages/tensordict/memmap.py", line 375, in _get_item
    memmap_array = memmap_array[idx]
  File ".conda/envs/testtd/lib/python3.8/site-packages/numpy/core/memmap.py", line 334, in __getitem__
    res = super().__getitem__(index)
IndexError: index 1 is out of bounds for axis 1 with size 1

Expected behavior

If the line tensordict.memmap_() is commented out, the code does not produce any error.

System info

conda create -n testtd pip python==3.8
pip install torchrl

tensordict.__version__:  0.1.2
numpy.__version__:  1.24.3
sys.version:  3.8.0 (default, Nov  6 2019, 21:49:08)
[GCC 7.3.0]
sys.platform:  linux
torch.__version__:  2.0.1+cu117

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

[Feature Request] Make the last sync with torchrl

Motivation

TorchRL's version of tensordict and modules has changed a bit since we started this work.

From now on, TensorDict will only evolve in this repo. However, we still need to do a final manual sync with the library.

[BUG] Indexing discrepancies with numpy

Describe the bug

TensorDict indexing behavior differs from numpy. In most cases, the numpy behavior seems to make more sense.

Steps to reproduce and expected behavior

import numpy as np
from tensordict import TensorDict

nd = np.random.random((3,4))
nd_5d = np.random.random((5, 3, 4, 6, 7))
td = TensorDict({}, batch_size=(3, 4))
td_5d = TensorDict({}, batch_size=(5, 3, 4, 6, 7))

# A None following an ellipsis is used as the first dimension instead of the last
nd[1, ..., None].shape  # (4, 1)
td[1, ..., None].shape  # torch.Size([1, 4])

# Same issue
nd[None, 1, ..., None].shape  # (1, 4, 1)
td[None, 1, ..., None].shape  # RuntimeError: Not enough dimensions in TensorDict for index provided.

# Multiple Nones following an ellipsis raise a RuntimeError
nd[..., :2, None, None].shape  # (3, 2, 1, 1)
td[..., :2, None, None].shape  # RuntimeError: Not enough dimensions in TensorDict for index provided.

# Tuples within tuples are handled as lists by numpy, while we raise a NotImplementedError. 
# Already documented in https://github.com/pytorch-labs/tensordict/issues/99.
nd[(0, 1), (0, 1)].shape  # (2,)
td[(0, 1), (0, 1)].shape  # NotImplementedError: batch dim cannot be computed for type <class 'tuple'>. 

# Very specific edge case. 
nd_5d[2:, [[[0, 1]]], :3, 0].shape  # (1, 1, 2, 3, 3, 7)
td_5d[2:, [[[0, 1]]], :3, 0].shape  # torch.Size([3, 1, 1, 2, 3, 7])
# However, we get the same output when using a list index rather than an integer
nd_5d[2:, [[[0, 1]]], :3, [0]].shape  # (1, 1, 2, 3, 3, 7)
td_5d[2:, [[[0, 1]]], :3, [0]].shape  # torch.Size([1, 1, 2, 3, 3, 7])

System info

Describe the characteristic of your environment:

  • Describe how the library was installed (pip, source, ...)
  • Python version
  • Versions of any other relevant libraries
import tensordict, numpy, sys, torch
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)

0.1.0+c03956a 1.24.2 3.10.8 (main, Nov 24 2022, 08:08:27) [Clang 14.0.6 ] darwin 1.13.1

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

[Feature Request] tensorclass with non-tensor data

Motivation

@tensorclass is currently restricted to input types: Tensor, MemmapTensor, KeyedJaggedTensor, TensorDictBase or other tensorclass
Some users have expressed interest in a more generic type of content.

Solution

We could store any other data but restrict the tensor-behaviours to the tensor data.

API decisions

returning a new instance

Any operation that returns a tensorclass would just copy the non-tensor data:

>>> @tensorclass
... class MyClass:
...     images: torch.Tensor
...     path: str
... 
>>> data = MyClass(torch.randn(2, 3, 64, 64), "path_to_data", batch_size=[2])
>>> print(data[0])
MyClass(images=torch.Tensor(...), path="path_to_data", batch_size=[])

The same would apply to .contiguous(), .clone(), .unbind(), etc.
For stack or cat, we would need to check that these attributes match; if they do not, an error is raised.
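The stack rule proposed above (tensor fields are stacked, non-tensor fields must agree) can be prototyped with a plain dataclass. `Record`, `TENSOR_FIELDS` and `stack_records` are hypothetical names, and Python lists stand in for tensors:

```python
from dataclasses import dataclass, fields
from typing import List

TENSOR_FIELDS = {"images"}  # hypothetical: the fields that receive the tensor treatment


@dataclass
class Record:
    images: List[float]  # stands in for a tensor field
    path: str  # non-tensor field


def stack_records(records: List[Record]) -> Record:
    out = {}
    for f in fields(Record):
        values = [getattr(r, f.name) for r in records]
        if f.name in TENSOR_FIELDS:
            out[f.name] = values  # "stack" the tensor-like field
        elif any(v != values[0] for v in values[1:]):
            raise ValueError(f"non-tensor field {f.name!r} differs across items: {values!r}")
        else:
            out[f.name] = values[0]  # all items agree: keep a single copy
    return Record(**out)


stacked = stack_records([Record([1.0, 2.0], "path_to_data"), Record([3.0, 4.0], "path_to_data")])
print(stacked)
```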

Indexing

In general, I think every non-tensor field should be treated the same way, i.e. a list would not be indexed:

>>> @tensorclass
... class MyClass:
...     images: torch.Tensor
...     captions: List[str]
... 
>>> data = MyClass(torch.randn(2, 3, 64, 64), ["a", "b"], batch_size=[2])
>>> print(data[0])
MyClass(images=torch.Tensor(...), captions=["a", "b"], batch_size=[])

If the number of distinct captions is limited, this could work:

@tensorclass
class MyClass:
    images: torch.Tensor
    _captions: torch.Tensor
    caption_classes: List[str]

    @property 
    def captions(self):
        return [self.caption_classes[idx] for idx in self._captions]

data = MyClass(images=torch.randn(5, 3, 64, 64), _captions=torch.randint(2, (5,)), caption_classes=["a", "b"], batch_size=[5])

Otherwise, if the number of captions is much larger, it would be up to the user to implement indexing

@tensorclass
class MyClass:
    images: torch.Tensor
    captions: List[str]

    def __getitem__(self, item):
        c = super().__getitem__(item)
        c.captions = self.captions[item]
        return c

setitem

With tensorclass, we can do

>>> @tensorclass
... class MyClass:
...     images: torch.Tensor
>>> data = MyClass(torch.randn(3, 4, 5), batch_size=[3, 4])
>>> data[1] = MyClass(torch.randn(4, 5), batch_size=[4])

With a non-tensor field, I would consider solving it this way:

>>> @tensorclass
... class MyClass:
...     images: torch.Tensor
...     meta_data: Optional[str] = None
>>> data = MyClass(torch.randn(3, 4, 5), "stuff", batch_size=[3, 4])
>>> data[1] = MyClass(torch.randn(4, 5), "stuff", batch_size=[4])  # match - no error
>>> data[1] = MyClass(torch.randn(4, 5), batch_size=[4])  # ignored - no error
>>> data[1] = MyClass(torch.randn(4, 5), "I am a pig, and as a pig, I have always stood out.", batch_size=[4])  # don't match - raise exception
>>> data = MyClass(torch.randn(3, 4, 5), batch_size=[3, 4])  # meta_data is None
>>> data[1] = MyClass(torch.randn(4, 5), "I am a pig, and as a pig, I have always stood out.", batch_size=[4])  # don't match - raise exception

Construction

One important question is also whether we want to allow any field to be a non-tensor, e.g.

>>> @tensorclass
... class MyClass:
...     images: torch.Tensor  # do we want to raise an exception if a string is given?
...     meta_data: Any  # will tensors be considered as tensors?
...     meta_data: Optional[str] = None  # do we want to raise an exception if a tensor is given?

cc @tcbegley @keyboardAnt

[Feature Request] Caching keys

Motivation

Getting the list of keys can be expensive, especially when we ask for the nested keys.
We could perhaps cache them.

Updating the tensordict (either using set, update, setdefault, rename, select, exclude) would trigger an update of the keys.

The major blocker is that modifying a nested tensordict should impact the parent tensordict, and implementing this may open a Pandora's box that we may prefer to keep closed (i.e. a parenting mechanism similar to the one used for replay buffers or transforms):

subtd = TensorDict({"b": [3]}, [])
tensordict = TensorDict({"a": subtd}, [])
subtd["c"] = [1]  # tensordict should be informed about this too
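A minimal sketch of the parenting mechanism under discussion, assuming each node keeps weak references to its parents and invalidates their cached keys when it changes. `Node` is a hypothetical toy class, not TensorDict; it does not handle self-nesting, where `nested_keys` would recurse forever. The early return in `_invalidate` relies on a parent's cache only ever being built after its children's caches.

```python
import weakref


class Node:
    """Toy nested mapping that caches its nested keys and notifies its parents on change."""

    def __init__(self):
        self._data = {}
        self._parents = weakref.WeakSet()  # children do not keep their parents alive
        self._keys_cache = None

    def __setitem__(self, key, value):
        self._data[key] = value
        if isinstance(value, Node):
            value._parents.add(self)  # register the parenting link
        self._invalidate()

    def _invalidate(self):
        if self._keys_cache is None:
            return  # already invalid, so any parent cache is invalid too
        self._keys_cache = None
        for parent in list(self._parents):
            parent._invalidate()

    def nested_keys(self):
        if self._keys_cache is None:
            keys = []
            for k, v in self._data.items():
                keys.append((k,))
                if isinstance(v, Node):
                    keys.extend((k,) + sub for sub in v.nested_keys())
            self._keys_cache = keys
        return list(self._keys_cache)


sub = Node()
sub["b"] = 3
root = Node()
root["a"] = sub
before = root.nested_keys()
sub["c"] = 1  # modifying the child...
after = root.nested_keys()  # ...is visible from the parent
print(before, after)
```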

cc @tcbegley

[Feature Request] key-level granularity in `skip_existing`

Currently skip_existing operates on all keys without any granularity.

This is a problem in RL when, in a loss module for example, you may want to skip existing "values" but you never want to skip the existing "memory" of a memory-based model (RNN). In other words, if you use skip_existing on memory keys, your memory will never be updated.

This is needed to support RNNs in TorchRL (issue pytorch/rl#1060).

We need a solution to make skip_existing more granular.

The proposed solution is simple: pass to the set_skip_existing function the keys we actually want to skip.

with set_skip_existing(["value", "value_target"]):
    loss(td)  # will use existing values but not the existing hidden memory

By default, if no keys are passed, the behavior remains the same as the current set_skip_existing=True.
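A minimal sketch of the proposed semantics, using a module-level flag and a context manager. `skip_existing_keys` and `should_skip` are hypothetical names chosen to avoid implying tensordict's actual API; passing no keys reproduces the all-keys behaviour. A real implementation would likely need thread-local state.

```python
from contextlib import contextmanager
from typing import Iterable, Optional

ALL = object()  # sentinel: skip every existing key
_SKIP_KEYS = None  # None: skip nothing; ALL: skip everything; otherwise a frozenset of keys


@contextmanager
def skip_existing_keys(keys: Optional[Iterable[str]] = None):
    """Hypothetical granular version of set_skip_existing; keys=None means skip every key."""
    global _SKIP_KEYS
    previous = _SKIP_KEYS
    _SKIP_KEYS = ALL if keys is None else frozenset(keys)
    try:
        yield
    finally:
        _SKIP_KEYS = previous  # restore the outer setting, even on error


def should_skip(key, data) -> bool:
    """True if `key` is already present in `data` and is covered by the active setting."""
    if key not in data or _SKIP_KEYS is None:
        return False
    return _SKIP_KEYS is ALL or key in _SKIP_KEYS


td = {"value": 1.0, "hidden": 0.5}
with skip_existing_keys(["value", "value_target"]):
    skip_value = should_skip("value", td)  # reuse the existing value
    skip_hidden = should_skip("hidden", td)  # recompute the memory
with skip_existing_keys():
    skip_all = should_skip("hidden", td)
outside = should_skip("value", td)
print(skip_value, skip_hidden, skip_all, outside)
```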

[BUG] 'distribution_kwargs' are not correctly passed to dist in 'ProbabilisticTensorDictModule'

Describe the bug

Hi All,

It seems to me that if I set custom arguments distribution_kwargs for a distribution with ProbabilisticTensorDictModule, these arguments only get assigned to the object (see tensordict/tensordict/nn/probabilistic.py line 215):

self.distribution_kwargs = (
    distribution_kwargs if distribution_kwargs is not None else {}
)

but later on, when the dist is fetched def get_dist(self, tensordict: TensorDictBase) -> d.Distribution:, they are not used (see tensordict/tensordict/nn/probabilistic.py line 225):

dist_kwargs = {
    dist_key: tensordict[td_key]
    for dist_key, td_key in self.in_keys.items()
}
dist = self.distribution_class(**dist_kwargs)

In my case, this led to the issue that a custom min/max range for the dist didn't get set correctly as can be seen in the following small example.

To Reproduce

import torch
from tensordict.nn import ProbabilisticTensorDictModule
from tensordict import TensorDict
from torchrl.modules.distributions.continuous import TanhNormal

min_dist = torch.Tensor([-10, -10])
prob_module = ProbabilisticTensorDictModule(
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "min": min_dist,
        "max": min_dist * -1,
    },
    return_log_prob=True,
)
print(f'prob_module.distribution_kwargs[\'min\']: {prob_module.distribution_kwargs["min"]}')

td = TensorDict(
    {
        'loc': torch.Tensor([0, 0]),
        'scale': torch.Tensor([0.1, 0.1]),

    }, [])

dist = prob_module.get_dist(td)
print(f'dist.min {dist.min}')

assert all(dist.min == min_dist)

Expected behavior

distribution_kwargs should be used when calling ProbabilisticTensorDictModule.get_dist().

System info

tensordict.version == 2023.01.09
numpy.version == 1.23.2
sys.version == 3.8.15 (default, Oct 12 2022, 19:15:16) \n[GCC 11.2.0]
sys.platform == linux
torch.version == 2.0.0.dev20230108+cu116

Reason and Possible fixes

In tensordict/tensordict/nn/probabilistic.py line 229, just adding **self.distribution_kwargs should fix it:

dist = self.distribution_class(**dist_kwargs, **self.distribution_kwargs)
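The proposed fix amounts to merging the per-call arguments read from the tensordict with the static distribution_kwargs. A self-contained sketch of that logic, using a plain dict in place of a tensordict and torch.distributions.Normal in place of TanhNormal (both substitutions are assumptions made to avoid extra dependencies):

```python
import torch
from torch import distributions as d


def get_dist(tensordict, in_keys, distribution_class, distribution_kwargs=None):
    # in_keys maps distribution argument names to keys in the (tensor)dict
    dist_kwargs = {
        dist_key: tensordict[td_key] for dist_key, td_key in in_keys.items()
    }
    # merge the static kwargs given at construction time, as in the proposed fix
    return distribution_class(**dist_kwargs, **(distribution_kwargs or {}))


td = {"loc": torch.zeros(2)}
scale = torch.tensor([0.1, 0.1])
dist = get_dist(td, {"loc": "loc"}, d.Normal, {"scale": scale})
```

With this one-line merge, a key present in both sources raises a TypeError (multiple values for the same keyword argument), which is arguably safer than letting one source silently override the other.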

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

[BUG] Auto-nested tensordict bugs

Describe the bug

Auto-nesting may be a desirable feature (e.g. to build graphs), but currently it is broken for multiple functions, e.g.

tensordict = TensorDict({}, [])
tensordict["self"] = tensordict
print(tensordict)  # fails
tensordict.flatten_keys()  # fails
list(tensordict.keys(include_nested=True))  # fails

Consideration

This is something that should be included in the tests. We could design a special test case in TestTensorDictsBase with a nested self.

Solution

IMO there is no single solution to this problem. For repr, we could find a way of representing a self-nested tensordict, something like

TensorDict(fields={
   "self": ...
}, 
batch_size=[])

For keys, we could avoid returning a key if a key pointing to the same value has already been returned (same for values and items).
For flatten_keys, self-nesting should be prohibited. The options are (1) leave it as is, since the maximum recursion depth error already takes care of it, or (2) build a wrapper around flatten_keys() that detects whether the same method is being entered twice (i.e. the same method called again on the same instance), something like

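import functools

def detect_self_nesting(fun):
    @functools.wraps(fun)
    def new_fun(self, *args, **kwargs):
        # re-entering the same method on the same instance means we are looping on a self-reference
        if fun in self._being_called:
            raise RuntimeError(f"{fun.__name__} was called recursively on a self-nested tensordict")
        self._being_called.append(fun)
        try:
            return fun(self, *args, **kwargs)
        finally:
            # always pop, even if fun raised
            self._being_called.pop()
    return new_fun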

There are probably other properties that I'm missing, but I'd expect them to be covered by the tests if we can design the dedicated test pipeline mentioned earlier.
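For keys, one way to avoid infinite recursion is to track the containers on the current path by id() and not descend into one that is already being visited. A minimal sketch on plain nested dicts (iter_nested_keys is a hypothetical name), yielding top-level keys as strings and nested keys as tuples:

```python
def iter_nested_keys(data, _prefix=(), _visiting=frozenset()):
    """Yield nested keys, skipping sub-containers already on the current path."""
    if id(data) in _visiting:
        return  # self-reference: stop instead of recursing forever
    visiting = _visiting | {id(data)}
    for key, value in data.items():
        full_key = _prefix + (key,)
        yield full_key if len(full_key) > 1 else key
        if isinstance(value, dict):
            yield from iter_nested_keys(value, full_key, visiting)


graph = {"a": 1}
graph["self"] = graph
print(list(iter_nested_keys(graph)))  # ['a', 'self']
```

Tracking only the current path (rather than every value returned so far) means a sub-dict shared by two branches without forming a cycle is still listed under both keys; only genuine cycles are cut.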

[Feature Request] Saving and loading memmap tensordicts / tensorclasses

Motivation

We can easily save tensordicts using torchsnapshot. However, since MemmapTensors are already backed by files on disk, it would be fairly easy to save a tensordict that contains only MemmapTensors (i.e. is_memmap() returns True).
Ideally, the saved structure would mirror that of the original tensordict (+ metadata):

tensordict = TensorDict({"a": torch.randn(3, 4), "b": {"c": torch.randn(3, 4)}}, [3, 4])
tensordict.memmap_()
save_memmap_tensordict(tensordict, "/path/to/save")

would result in

/path/to/save/metadata.pt
/path/to/save/a.memmap
/path/to/save/b/metadata.pt
/path/to/save/b/c.memmap

(we'd need metadata for each subtensordict since they may have a different device / batch_size).

Loading from such files would also be easy (and would not create a copy):

tensordict_loaded = load_from_memmap("/path/to/save/", mode="r+")
assert tensordict_loaded.is_memmap()
assert (tensordict_loaded == tensordict).all()

This should work for tensordict and tensorclass, but a little extra work may be needed for the latter.

  • TensorDict
  • tensorclass
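A minimal sketch of the proposed layout using NumPy memory maps rather than MemmapTensor, on a plain nested dict of CPU tensors. save_memmap_nested and load_memmap_nested are hypothetical names, metadata is stored as JSON instead of metadata.pt, device/batch_size metadata is left out, and zero-element tensors are not handled (np.memmap cannot map an empty file):

```python
import json
import os
import tempfile

import numpy as np
import torch


def save_memmap_nested(data, path):
    """Mirror a nested dict of tensors as a directory tree of .memmap files."""
    os.makedirs(path, exist_ok=True)
    meta = {}
    for key, value in data.items():
        if isinstance(value, dict):
            # each sub-dict gets its own directory and metadata file
            save_memmap_nested(value, os.path.join(path, key))
            continue
        array = value.detach().cpu().numpy()
        mm = np.memmap(os.path.join(path, f"{key}.memmap"), dtype=array.dtype,
                       mode="w+", shape=array.shape)
        mm[:] = array
        mm.flush()
        meta[key] = {"shape": list(array.shape), "dtype": str(array.dtype)}
    with open(os.path.join(path, "metadata.json"), "w") as f:
        json.dump(meta, f)


def load_memmap_nested(path, mode="r+"):
    """Rebuild the nested dict; tensors share memory with the files (no copy)."""
    with open(os.path.join(path, "metadata.json")) as f:
        meta = json.load(f)
    out = {}
    for key, info in meta.items():
        mm = np.memmap(os.path.join(path, f"{key}.memmap"), dtype=info["dtype"],
                       mode=mode, shape=tuple(info["shape"]))
        out[key] = torch.from_numpy(mm)
    for entry in sorted(os.listdir(path)):
        if os.path.isdir(os.path.join(path, entry)):
            out[entry] = load_memmap_nested(os.path.join(path, entry), mode)
    return out


with tempfile.TemporaryDirectory() as tmp:
    src = {"a": torch.randn(3, 4), "b": {"c": torch.randn(3, 4)}}
    save_memmap_nested(src, tmp)
    loaded = load_memmap_nested(tmp)
    print(torch.equal(loaded["b"]["c"], src["b"]["c"]))  # True
    del loaded
```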

@sreevasthav

[BUG] functorch.dim breaks split of big tensors

Describe the bug

For some reason (probably unrelated to tensordict), functorch dims fail with big tensors.

import torch
from tensordict import TensorDict

# build splits
total = 0
l = []
while total < 6400:
    l.append(torch.randint(2, 10, (1,)).item())
    total += l[-1]

# build td
tensordict = TensorDict({"x": torch.randn(6401, 1), "y": torch.randn(6401, 1)}, [])

# split
for k, x in tensordict.items():
    x.split(l, 0)

This should result in an error like this:

libc++abi: terminating with uncaught exception of type c10::Error: allocated_ <= ARENA_MAX_SIZE INTERNAL ASSERT FAILED at "/stuff/pytorch/pytorch/pytorch/functorch/csrc/dim/arena.h":227, please report a bug to PyTorch. 

Default batch_size + TensorDictModule can be surprising

TensorDictModule accepts a tensordict/tensorclass as input to its forward method, but it also supports nn.Module-style calls with tensors.
Under the hood, it builds a new tensordict from those tensors and passes it to the call method. In doing so, however, it uses the default batch_size.

I would say that using the maximum size is probably not the most common case. Can we reconsider the default or find a way to get the right size from the model itself?
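To make the concern concrete, assume "maximum size" means the longest leading shape shared by all input tensors. The sketch below (common_leading_shape and its batch_dims argument are hypothetical, not the library's code) shows how that default can count a shared feature or time dimension as a batch dimension, and how an explicit number of batch dims would restrict it:

```python
import torch


def common_leading_shape(*tensors, batch_dims=None):
    """Longest shape prefix shared by all tensors, optionally truncated."""
    prefix = []
    for dims in zip(*(t.shape for t in tensors)):
        if any(d != dims[0] for d in dims):
            break
        prefix.append(dims[0])
    if batch_dims is not None:
        prefix = prefix[:batch_dims]
    return torch.Size(prefix)


obs = torch.randn(8, 4, 3)
act = torch.randn(8, 4, 5)
print(common_leading_shape(obs, act))  # torch.Size([8, 4])
print(common_leading_shape(obs, act, batch_dims=1))  # torch.Size([8])
```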

[BUG] Conflicting keys silently collide in `tensordict.unflatten_keys`

Describe the bug

If two keys map to the same resulting key in flatten_keys or unflatten_keys, one silently overwrites the other and no error is raised. This may lead to unexpected behaviour down the line.

To Reproduce

We would expect this code to fail:

from tensordict import TensorDict
import torch
t = TensorDict({'a.a': torch.randn(3), 'a': {'a': torch.randn(3)}}, [])
print(t.unflatten_keys('.'))

as ("a", "a") already exists.
Similarly,

from tensordict import TensorDict
import torch
t = TensorDict({'a.a': torch.randn(3), 'a': {'a': torch.randn(3)}}, [])
print(t.flatten_keys('.'))

should also fail as "a.a" is also an existing key.
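A minimal sketch of a collision-aware flatten_keys on plain nested dicts (a hypothetical stand-in, not the tensordict method): every flattened entry is checked against what has already been written before it is stored.

```python
def flatten_keys(data, sep=".", _prefix=""):
    """Flatten a nested dict, raising instead of silently overwriting."""
    out = {}
    for key, value in data.items():
        full_key = f"{_prefix}{sep}{key}" if _prefix else key
        if isinstance(value, dict):
            entries = flatten_keys(value, sep, full_key).items()
        else:
            entries = [(full_key, value)]
        for flat_key, flat_value in entries:
            if flat_key in out:
                raise KeyError(f"flattened key {flat_key!r} already exists")
            out[flat_key] = flat_value
    return out


print(flatten_keys({"a": {"b": 1}, "c": 2}))  # {'a.b': 1, 'c': 2}
```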

[Feature Request] tensordict.split

Motivation

It would be nice to have a tensordict.split method that works like tensor.split.
We would also need to overload torch.split as we do for torch.stack if possible.

Example:

tensordict = TensorDict({}, [10])
tensordict.split(5, 0)  # results in 2 tensordicts of batch_size [5] and [5]
tensordict.split([5, 3, 2], 0)  # results in 3 tensordicts of batch_size [5], [3] and [2]
torch.split(tensordict, [5, 3, 2], 0)  # same as above
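A sketch of the split logic on a flat dict of tensors plus an explicit batch size (split_nested is a hypothetical helper). The chunk sizes are derived from the batch size rather than from a leaf, so even an empty container like the one above can be split:

```python
import torch


def split_nested(data, batch_size, split_size_or_sections, dim=0):
    """Split every leaf along a batch dim; return (dict, batch_size) pairs."""
    # derive chunk sizes from the batch size so empty containers split too
    sizes = [len(c) for c in torch.arange(batch_size[dim]).split(split_size_or_sections)]
    leaves = {key: value.split(sizes, dim) for key, value in data.items()}
    out = []
    for i, size in enumerate(sizes):
        chunk_bs = list(batch_size)
        chunk_bs[dim] = size
        out.append(({key: chunks[i] for key, chunks in leaves.items()}, torch.Size(chunk_bs)))
    return out


for chunk, bs in split_nested({"a": torch.randn(10, 3)}, torch.Size([10]), [5, 3, 2]):
    print(bs, chunk["a"].shape)
```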

[BUG] Incorrect output batch_size with functorch.vmap

Describe the bug

When used with functorch.vmap, TensorDictModule does not return the correct batch_size when in_dims and out_dims have non-zero entries.

To Reproduce

A minimal example to reproduce:

import torch
import torch.nn as nn
import functorch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, make_functional

a = TensorDict({
    "a": torch.rand(1024, 3, 64),
    "b": torch.rand(1024, 3, 32),
}, batch_size=[1024, 3])

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Linear(64, 128)
        self.b = nn.Linear(32, 64)
    
    def forward(self, a, b):
        return self.a(a), self.b(b)

model = TensorDictModule(Model(), in_keys=["a", "b"], out_keys=["out.a", "out.b"])

# option 1
fmodel, params = functorch.make_functional(model)
functorch.vmap(fmodel, in_dims=(None, 1), out_dims=1)(params, a)

# option 2
params = make_functional(model)
functorch.vmap(model, in_dims=(1, None), out_dims=1)(a, params)

# option 3
functorch.vmap(model, 1, 1)(a)

Expected behavior

The expected output batch_size is [1024, 3]. But all three options give a batch_size of [3, 1024] (however, the entries' shapes are correct):

TensorDict(
    fields={
        a: Tensor(torch.Size([1024, 3, 64]), dtype=torch.float32),
        b: Tensor(torch.Size([1024, 3, 32]), dtype=torch.float32),
        out.a: Tensor(torch.Size([1024, 3, 128]), dtype=torch.float32),
        out.b: Tensor(torch.Size([1024, 3, 64]), dtype=torch.float32)},
    batch_size=torch.Size([3, 1024]),
    device=None,
    is_shared=False)
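For comparison, this is the layout torch.vmap produces on plain tensors: mapping over dim 1 with out_dims=1 keeps the mapped dimension in place, which is what the batch_size should reflect. A sketch assuming PyTorch 2.0 or later (where torch.vmap is available), with a single nn.Linear in place of the two-headed model above and smaller sizes:

```python
import torch
import torch.nn as nn

layer = nn.Linear(64, 128)
x = torch.rand(16, 3, 64)
# map over dim 1 of the input and put the mapped dim back at position 1
out = torch.vmap(layer, in_dims=1, out_dims=1)(x)
print(out.shape)  # torch.Size([16, 3, 128])
```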

System info

tensordict: 0.0.1c
torch: 1.13.0

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

ImportError: cannot import name 'TensorDict' from partially initialized module 'tensordict'

Describe the bug

You have a circular import somewhere

To Reproduce

from tensordict import TensorDict
import torch
 from tensordict import TensorDict
ImportError: cannot import name 'TensorDict' from partially initialized module 'tensordict' (most likely due to a circular import) (/home/ubuntu/francesco/playground-with-torchdata/tensordict.py)

Describe the characteristic of your environment:

I installed tensordict with pip install tensordict-nightly

Collecting environment information...
PyTorch version: 1.13.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.8.10 (default, Nov 14 2022, 12:59:47)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-1026-aws-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 10.1.243
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla V100-SXM2-16GB
Nvidia driver version: 470.161.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] torch==1.13.1
[pip3] torchdata==0.5.1
[pip3] torchvision==0.14.1
[conda] Could not collect

[Feature Request] Support for TensorDictBase.masked_select inplace

Motivation

Add an in-place variant of .masked_select() to TensorDictBase.

Solution

Provide a method named .masked_select_() that behaves like .masked_select() but modifies the TensorDictBase in place.

Main steps to achieve this:

  1. Iterate key-values and collect masked tensors for values with type leaf tensor
  2. Iterate key-values with type of nested TensorDict, and call recursively .masked_select_()
  3. Update the batch_size accordingly

Examples:

td = TensorDict(source={'a': torch.zeros(3, 4)},
    batch_size=[3])
mask = torch.tensor([True, False, False])
td.masked_select_(mask)
td.get("a")
#output: tensor([[0., 0., 0., 0.]])
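The three steps can be sketched on a hypothetical minimal container (TinyTD is an illustration, not TensorDictBase). Boolean indexing collapses the masked leading dims into a single dim of length mask.sum(), which gives the new batch_size:

```python
import torch


class TinyTD:
    """Hypothetical stand-in for TensorDictBase: a dict of tensors plus a batch size."""

    def __init__(self, source, batch_size):
        self._dict = dict(source)
        self.batch_size = torch.Size(batch_size)

    def get(self, key):
        return self._dict[key]

    def masked_select_(self, mask):
        for key, value in self._dict.items():
            if isinstance(value, TinyTD):
                value.masked_select_(mask)  # step 2: recurse into nested containers
            else:
                self._dict[key] = value[mask]  # step 1: mask leaf tensors
        # step 3: the masked dims collapse into one dim of length mask.sum()
        self.batch_size = torch.Size([int(mask.sum()), *self.batch_size[mask.ndim:]])
        return self


td = TinyTD({"a": torch.zeros(3, 4)}, batch_size=[3])
td.masked_select_(torch.tensor([True, False, False]))
print(td.get("a"))  # tensor([[0., 0., 0., 0.]])
```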

[BUG] Improvements on MemmapTensor

Describe the bug

We need to do some cleanup work on MemmapTensor:

  • Remove the ownership mechanism and replace it with a check on the number of times the array is opened (for temporary files only)
  • Solve a bug where calling MemmapTensor(memmap_tensor[idx]) returns the wrong view of the MemmapTensor #193
  • Improve MemmapTensor creation args, including creating a Memmap at a specific location (#189)
  • Improve the way we copy a Memmap from one place to another, both when we want to change the filename and when we don't
  • Simplify the logic of MemmapTensor so that extracting a tensor is less confusing (i.e. check as_tensor, _tensor, _load_item and how they interact).
  • The device attribute is confusing, as the data is never really on another device but gets loaded there if needed. The intended usage is that one sets the device attribute as a destination, then moves the Memmap from process to process, and when the data is accessed it is on the right device. However, this may give the false impression that the content of the tensor is stored on that device...
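For the first item, one option is to replace ownership with a reference count per file: each handle increments the count, and the temporary file is deleted when the last handle is released. A process-local sketch (RefCountedTempFile is a hypothetical name; sharing counts across processes would need extra machinery):

```python
import os
import tempfile


class RefCountedTempFile:
    """Hypothetical: delete a temporary file when its last user releases it."""

    _counts = {}  # filename -> number of live handles in this process

    def __init__(self, filename=None):
        if filename is None:
            fd, filename = tempfile.mkstemp(suffix=".memmap")
            os.close(fd)
        self.filename = filename
        counts = RefCountedTempFile._counts
        counts[filename] = counts.get(filename, 0) + 1

    def release(self):
        counts = RefCountedTempFile._counts
        counts[self.filename] -= 1
        if counts[self.filename] == 0:
            # last handle gone: the temporary file can be removed
            del counts[self.filename]
            os.remove(self.filename)
```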

cc @tcbegley

[Feature Request] Nested keys compatibility improvement

Motivation

We need to improve the compatibility with nested keys. This should include:

  • TensorDict.get(nested_key)
  • TensorDict.set(nested_key, value) (+ creation of the sub-td if the original key is missing)
  • TensorDict.set_(nested_key, value)
  • TensorDict.keys()
  • TensorDict.select(*keys)
  • TensorDict.exclude(*keys)
  • TensorDict.set_default(key, value)
  • TensorDictModule
  • TensorDictSequential

To be covered once implemented:

  • TensorDict.pop(key)

Contentious methods / unresolved issues:

  • items: we could return the first level by default and return nested ones if a flag is set to true?
  • values: same reason
  • meta-tensor: building the features in parallel for meta-tensor will double the amount of work (?)
  • empty sub-tensordict behaviour: default behaviour would return only leaves, but one could ask for every level OR to return leaf-tensordicts.
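One possible answer to the items/values questions is a pair of flags. A sketch on plain nested dicts (nested_items, include_nested and leaves_only are hypothetical names here) where the default returns only the first level, include_nested=True walks sub-dicts, and leaves_only=True drops intermediate (including empty) sub-dicts:

```python
def nested_items(data, include_nested=False, leaves_only=False, _prefix=()):
    """Yield (key, value) pairs; nested keys are tuples, top-level keys plain strings."""
    for key, value in data.items():
        full_key = _prefix + (key,) if _prefix else key
        is_node = isinstance(value, dict)
        if not (leaves_only and is_node):
            yield full_key, value
        if include_nested and is_node:
            yield from nested_items(value, include_nested, leaves_only, _prefix + (key,))


d = {"a": 1, "b": {"c": 2, "e": {}}}
print([k for k, _ in nested_items(d, include_nested=True, leaves_only=True)])
# ['a', ('b', 'c')]
```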

cc @tcbegley

[Feature Request] More arguments for `td.memmap_()`

Motivation

Since TensorDict.memmap_() copies the content of a tensordict to memory-mapped files, it would be nice to have some more options:

  • MemmapTensors are temporary files. It could be nice to disable this when calling memmap_() (e.g. temporary=False), so that one knows the content of the memmap will be kept on disk after the process exits.
  • We should also be able to pass the path where the files have to be saved (path="/path/to/my_tensordict"). This would work nicely with #176
  • Similarly, it could be neat to be able to control whether a memmap_ tensor can be written or not (by passing mode="r+" for instance). Of course, memmap_() would always write the content, but would then modify the permissions accordingly to make sure the content is locked against in-place modifications.
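At the file-mapping level, "always write, then lock" can be sketched with NumPy memory maps: write through a mode="w+" map at a user-chosen path, then reopen the same file with the requested mode. write_then_lock is a hypothetical helper, and this is not the MemmapTensor API:

```python
import os
import tempfile

import numpy as np


def write_then_lock(array, filename, mode="r"):
    """Hypothetical helper: write `array` to a persistent memmap, then reopen it with `mode`."""
    writer = np.memmap(filename, dtype=array.dtype, mode="w+", shape=array.shape)
    writer[:] = array
    writer.flush()
    del writer  # drop the writable mapping
    return np.memmap(filename, dtype=array.dtype, mode=mode, shape=array.shape)


path = os.path.join(tempfile.mkdtemp(), "a.memmap")  # a location that outlives the process
locked = write_then_lock(np.arange(4, dtype=np.float32), path)
try:
    locked[0] = 1.0
except ValueError as err:
    print("read-only:", err)
```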

[BUG] Unexpected shape change during pre-allocation

Describe the bug

The use case is to implement a lazily initialized rollout buffer for RL. However, when setting a nested TensorDict with a finer-grained batch_size at index 0, the preallocated nested entry of the buffer gets a truncated batch_size, and the batch_size of the original TensorDict's nested entry is changed as well.

To Reproduce

import torch
from tensordict import TensorDict

buffer = TensorDict({}, batch_size=[500, 1024])

td_0 = TensorDict({
    "env.time": torch.rand(1024, 1),
    "agent.obs": TensorDict({ # assuming 3 agents in a multi-agent setting
        "image": torch.rand(1024, 3, 64),
        "state": torch.rand(1024, 3, 3, 32, 32)
    }, batch_size=[1024, 3])
}, batch_size=[1024])

td_1 = td_0.clone()
buffer[0] = td_0
buffer[1] = td_1

This would trigger

RuntimeError: indexed destination TensorDict batch size is torch.Size([1024]) (batch_size = torch.Size([500, 1024]), index=1), which differs from the source batch size torch.Size([1024, 3])

at buffer[1] = td_1. The batch sizes of both buffer["agent.obs"] and td_0["agent.obs"] also become [1024].

Expected behavior

The batch sizes of both buffer["agent.obs"] and td_0["agent.obs"] should remain [1024, 3], and buffer[1] = td_1 should succeed.

System info

tensordict: 0.0.1c

Reason and Possible fixes

The cause looks to be around https://github.com/pytorch-labs/tensordict/blob/main/tensordict/tensordict.py#L1973 and https://github.com/pytorch-labs/tensordict/blob/main/tensordict/tensordict.py#L3393 where the target shape is unexpectedly considered to be incongruent:

# L1973, TensorDictBase.__setitem__
def __setitem__(
    self, index: INDEX_TYPING, value: Union[TensorDictBase, dict]
) -> None:
    ...
    if value.batch_size != indexed_bs:
        raise RuntimeError(
            f"indexed destination TensorDict batch size is {indexed_bs} "
            f"(batch_size = {self.batch_size}, index={index}), "
            f"which differs from the source batch size {value.batch_size}"
        )
    ...

#L3393, SubTensorDict.set
...
def set(
    self,
    key: NESTED_KEY,
    tensor: Union[dict, COMPATIBLE_TYPES],
    inplace: bool = False,
    _run_checks: bool = True,
) -> TensorDictBase:
    ...
    if isinstance(tensor, TensorDictBase) and tensor.batch_size != self.batch_size:
        tensor.batch_size = self.batch_size
    ...

A possible fix would be to change these conditions to if value.batch_size[: len(indexed_bs)] != indexed_bs and if isinstance(tensor, TensorDictBase) and tensor.batch_size[: len(self.batch_size)] != self.batch_size, respectively. All existing tests still pass after this modification.
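The relaxed condition only requires the source batch size to start with the destination batch size. A small sketch of that predicate (leading_dims_match is a hypothetical name):

```python
import torch


def leading_dims_match(source_bs, target_bs):
    """True if source_bs starts with target_bs (the relaxed check proposed above)."""
    source_bs, target_bs = torch.Size(source_bs), torch.Size(target_bs)
    return source_bs[: len(target_bs)] == target_bs


# the buffer slot has batch size [1024]; the nested entry carries [1024, 3]
print(leading_dims_match([1024, 3], [1024]))  # True: extra trailing dims are allowed
print(leading_dims_match([1024], [1024, 3]))  # False: source has too few dims
print(leading_dims_match([512, 3], [1024]))   # False: leading dim differs
```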

It's really a minor problem and later I found torchrl.data.TensorDictReplayBuffer with LazyTensorStorage to be a working alternative. But it's also interesting to see they are implemented differently.

I'm glad to fix it if you would like.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

[BUG] Error in `__repr__` of heterogeneous tensordicts

Describe the bug

An error is thrown in __repr__ of a tensordict when rolling out an env that returns heterogeneous lazy stacks.

To Reproduce

   from torchrl.envs.libs.vmas import VmasEnv
   print(VmasEnv("simple_crypto", num_envs=32).rollout(10, return_contiguous=False))
Traceback (most recent call last):
  File "/Users/Matteo/PycharmProjects/VectorizedMultiAgentSimulator/vmas/examples/torch_rl.py", line 59, in <module>
    print(env.rollout(10, return_contiguous=False))
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/tensordict.py", line 1725, in __repr__
    fields = _td_fields(self)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/tensordict.py", line 5699, in _td_fields
    sorted([_make_repr(key, item, td) for key, item in td.items_meta()])
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/tensordict.py", line 5699, in <listcomp>
    sorted([_make_repr(key, item, td) for key, item in td.items_meta()])
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/tensordict.py", line 5691, in _make_repr
    return f"{key}: {repr(tensordict.get(key))}"
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/tensordict.py", line 1725, in __repr__
    fields = _td_fields(self)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/tensordict.py", line 5699, in _td_fields
    sorted([_make_repr(key, item, td) for key, item in td.items_meta()])
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/tensordict.py", line 5699, in <listcomp>
    sorted([_make_repr(key, item, td) for key, item in td.items_meta()])
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/tensordict.py", line 859, in items_meta
    yield k, self._get_meta(k)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/tensordict.py", line 570, in _get_meta
    return self._dict_meta[key]
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/utils.py", line 267, in __missing__
    value = self.fun(key)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/tensordict.py", line 4261, in _make_meta
    return torch.stack(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/metatensor.py", line 327, in __torch_function__
    return META_HANDLED_FUNCTIONS[func](*args, **kwargs)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/metatensor.py", line 496, in stack_meta
    return _stack_meta(
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/metatensor.py", line 464, in _stack_meta
    shape = list(shape)
TypeError: 'NoneType' object is not iterable

[BUG] Performance drops after long time

Hello,

thank you for creating TensorDict, it helps me a lot.

I use TensorDict as a rolling replay buffer in a Reinforcement Learning application. I need lots of storage since my data involves images, so the memmap feature is very handy. Recently, I started seeing sharp performance drops during training after running for quite some time. This usually indicates that some tensors are leaking and not being garbage-collected.

After days of investigation, I was able to track this issue down to my replay buffer implementation, which you can find below. The problem occurs when running many iterations. The performance drop seems to be caused more by buffer.get() than by buffer.put(), although buffer.put() also causes smaller performance drops over time. This happens for both small and large dicts.

There does not seem to be any issue when writing static data once (on iteration 0) and reading afterwards. Results seem to be cached and reading is super fast with no drop (at least I do not observe any).

import os
import shutil
import sys
import uuid

import numpy as np
import psutil
import torch
from tensordict import TensorDict, MemmapTensor
from tqdm import tqdm


class TensorDictReplayBuffer:
    def __init__(self, max_size, prefix):
        super().__init__()
        self.max_size = max_size
        self.current_index = 0
        self.current_size = 0
        self.prefix = prefix

        self.cache = None
        self.separator = "."

    @property
    def size(self):
        return self.current_size

    def setup_cache(self, data: TensorDict):
        os.makedirs(self.prefix, exist_ok=True)
        template = {
            k: MemmapTensor(
                self.max_size, *v.shape,
                dtype=v.dtype, filename=os.path.join(self.prefix, k)
            )
            for k, v in data[0].to_tensordict().items()
        }
        self.cache = TensorDict(template, batch_size=[self.max_size])

    def put(self, data: TensorDict) -> None:
        data = data.flatten_keys(self.separator).clear_device_()
        if self.cache is None:
            self.setup_cache(data)

        n = len(data)
        indices = torch.arange(self.current_index, self.current_index + n) % self.max_size
        self.cache[indices] = data
        self.current_size = min(self.max_size, self.current_size + n)

    def get(self, batch_size: int) -> TensorDict:
        indices = np.random.choice(self.current_size, batch_size, replace=False)
        return self.cache[indices].unflatten_keys(self.separator)


def data():
    n = 10
    return TensorDict(
        dict(
            dummy=torch.ones(n, n)
        ),
        batch_size=[]
    )

shutil.rmtree("data", ignore_errors=True)
os.makedirs("data", exist_ok=True)
buffer = TensorDictReplayBuffer(max_size=1_000_000, prefix=f"data/{uuid.uuid4()}")
n_reads = 10000
n_per_iteration = 5
batch_size = 128

for i in range(1000):
    with tqdm(total=n_reads, file=sys.stdout, desc=f"Iteration {i}") as pbar:
        for _ in range(n_reads // n_per_iteration):

            trajectory = torch.stack([data() for _ in range(n_per_iteration)])

            # if i == 0:
            buffer.put(trajectory)

            for _ in range(n_per_iteration):
                if buffer.size >= batch_size:
                    batch = buffer.get(batch_size=batch_size)

                pbar.update(1)

        pbar.set_postfix(dict(
            gpu_mem_mb=torch.cuda.memory_allocated() / 1e6,
            cpu_mem_mb=psutil.Process().memory_info().rss / 1e6
        ))

Is this something that can be explained given my implementation? Is this even an intended use-case for mem-mapped dicts? Or is there any other way to accomplish this?

Any pointers are appreciated. Thanks a lot.

[Feature Request] Benchmarks in GH actions

Motivation

We could create a benchmarking workflow in GH Actions and log the results after every merge.

This would probably involve turning the benchmarks into pytest tests?

The alert is super cool too!

Ideally we'd like to have one workflow on CPU and one on GPU.

Resources:

@tcbegley pinging you for context. I will create a BC task with this, but feel free to add anything you think is relevant.

[Feature Request] Improve first-class dimensions support

Motivation

#5 added support for first-class dimensions. There is however at least one missing feature and some ways we could consider improving this:

Missing Feature

It should be possible to index / order with tuples of first-class dimensions for convenient flattening / reshaping operations. Consider the following example from the torchdim repo.

import torch
from functorch.dim import dims

i, j, k = dims(3)
j.size = 2
A = torch.rand(6, 4)
a = A[(i, j), k] # split dim 0 into i,j
print(i.size, j.size, k.size)
> 3 2 4

r = a.order(i, (j, k)) # flatten j and k
print(r.shape)
> torch.Size([3, 8])

Currently this will lead to an error if attempted on a TensorDict.

Possible improvement

One possible enhancement would be to allow the user to specify first-class dimensions in the batch_size argument when instantiating a TensorDict. E.g.

d1, d2 = dims(2)
td = TensorDict({}, batch_size=[d1, d2, 3, 4])

# roughly equivalent to
td = TensorDict({}, batch_size=[1, 2, 3, 4])[d1, d2]

The difficulty here is likely to be that if the tensordict is empty, we do not know the size of the first class dimensions and they will be unbound until something is added to the tensordict. Assuming that can be worked around it might be a nice convenience?

[Feature Request] `get` and `set` tuple support

Motivation

Like for __getitem__ and __setitem__, we should support get and set with tuple keys. If the key is not present, set should create an empty TensorDict with the same batch size and device and populate it with the desired value.

Cons:

This may cause some overhead. I would not use a try/except, which is expensive, but a regular isinstance check.
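A sketch of tuple-key get/set on plain nested dicts, dispatching with isinstance as suggested (nested_get and nested_set are hypothetical helpers). A real TensorDict would create each missing level with the parent's batch size and device; here a plain dict is created instead:

```python
def nested_set(data, key, value):
    """Set data[key]; tuple keys address nested levels, created on demand."""
    if isinstance(key, str):  # cheap type check instead of try/except
        data[key] = value
        return
    *parents, leaf = key
    node = data
    for part in parents:
        # a TensorDict would create this level with the parent's batch size and device
        node = node.setdefault(part, {})
    node[leaf] = value


def nested_get(data, key):
    """Get data[key]; tuple keys walk down the nested levels."""
    if isinstance(key, str):
        return data[key]
    node = data
    for part in key:
        node = node[part]
    return node


d = {}
nested_set(d, ("agent", "obs"), 1)
print(d)  # {'agent': {'obs': 1}}
```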

Cc @tcbegley

[Feature Request] Operations between TensorDicts

Motivation

It would be really useful to be able to perform operations such as addition, multiplication, etc. between two TensorDicts (key-wise). Is this a feature you plan to introduce in the future?

Right now, I noticed that this is not possible.
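Key-wise operations can be sketched as a recursive walk over two containers with identical keys, applying a binary operator at each leaf. keywise is a hypothetical helper on plain nested dicts, not a tensordict method:

```python
import operator

import torch


def keywise(op, a, b):
    """Apply a binary op leaf by leaf to two nested dicts with identical keys."""
    if a.keys() != b.keys():
        raise KeyError(f"key mismatch: {set(a) ^ set(b)}")
    return {
        key: keywise(op, a[key], b[key]) if isinstance(a[key], dict) else op(a[key], b[key])
        for key in a
    }


x = {"obs": torch.ones(2), "nested": {"reward": torch.full((2,), 2.0)}}
total = keywise(operator.add, x, x)
print(total["nested"]["reward"])  # tensor([4., 4.])
```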

Checklist

  • [ X] I have checked that there is no similar issue in the repo (required)

[Feature Request] TensorDict to support data classes

Motivation

Keeping track of which keys go into a tensordict can be challenging across a large codebase, especially given that tensordict keys are mutable and different modules may return tensordicts with different key-naming conventions.

Solution

  1. A statically typed tensordict backed by a dataclass, where every member of the dataclass must be either a tensordict or a PyTorch tensor.
  2. Being able to create tensordicts with lists of named tuples/dataclasses.
  3. Being able to index into the tensordict and get back the data class you created the tensordict with.
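Points 2 and 3 can be sketched with a standard dataclass whose fields are tensors: stack builds one batched instance from a list of per-sample instances, and indexing returns the same dataclass type. Batch and its field names are hypothetical:

```python
from dataclasses import dataclass, fields

import torch


@dataclass
class Batch:
    obs: torch.Tensor
    reward: torch.Tensor

    def __getitem__(self, index):
        # index every field and rebuild the same dataclass type
        return type(self)(**{f.name: getattr(self, f.name)[index] for f in fields(self)})

    @classmethod
    def stack(cls, items):
        # build one batched instance from a list of per-sample instances
        return cls(**{f.name: torch.stack([getattr(it, f.name) for it in items]) for f in fields(cls)})


items = [Batch(obs=torch.randn(3), reward=torch.tensor(float(i))) for i in range(4)]
batch = Batch.stack(items)
print(batch[1:3].reward)  # tensor([1., 2.])
```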

Alternatives

I haven't thought too much about alternatives.

Additional context

This is heavily inspired by Jax's data class style approach for functional programming. It seems to me that at a high level, tensordicts are trying to achieve a similar thing to Jax's usage of pytrees except specifically with tensors.

Checklist

  • [ x] I have checked that there is no similar issue in the repo (required)

[Feature Request] More informative shape in `LazyStackedTensorDict` with heterogeneous tensors

Motivation

Currently, when creating a stack of tensordicts with different shapes, the shape of the heterogeneous components is reported as *.

It would be nice to have a more informative shape printed out, in the following way:

td1 = TensorDict({"a": torch.randn(3, 4, 3, 255, 256)}, [3, 4])
td2 = TensorDict({"a": torch.randn(3, 4, 3, 254, 256)}, [3, 4])
td = torch.stack([td1, td2], 0)

Before:

print(td)
LazyStackedTensorDict(
    fields={
        a: Tensor(*, dtype=torch.float32)},
    batch_size=torch.Size([2, 3, 4]),
    device=None,
    is_shared=False)

proposed change

print(td)
LazyStackedTensorDict(
    fields={
        a: Tensor([2,3,4,3,*,256], dtype=torch.float32)},
    batch_size=torch.Size([2, 3, 4]),
    device=None,
    is_shared=False)
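The proposed repr can be computed by comparing the stacked shapes dim by dim and inserting the stack size at the stack dim. Note that "a" has five dims in each tensordict, so the stacked shape has six entries. stacked_shape_repr is a hypothetical helper:

```python
def stacked_shape_repr(shapes, dim=0):
    """Format the shape of a lazy stack, using * for dims that differ."""
    if len({len(s) for s in shapes}) != 1:
        raise ValueError("all stacked tensors must have the same number of dims")
    dims = [str(d[0]) if all(x == d[0] for x in d) else "*" for d in zip(*shapes)]
    dims.insert(dim, str(len(shapes)))
    return "[" + ",".join(dims) + "]"


print(stacked_shape_repr([(3, 4, 3, 255, 256), (3, 4, 3, 254, 256)]))  # [2,3,4,3,*,256]
```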

As mentioned in #135

[BUG] repr of SubTensorDict with nested items seems wrong

Describe the bug

When a SubTensorDict contains nested tensordicts, the representation of those tensordicts doesn't take into account the slicing done by SubTensorDict so the shape information is a bit confusing.

To Reproduce

import torch
from tensordict import TensorDict

td = TensorDict({"a": TensorDict({"b": torch.rand(1, 2)}, [1, 2])}, [1])
std = td.get_sub_tensordict(0)

Now compare

print(std)
# SubTensorDict(
#     fields={
#         a: TensorDict(
#             fields={
#                 b: Tensor(torch.Size([1, 2, 1]), dtype=torch.float32)},
#             batch_size=torch.Size([1, 2]),
#             device=None,
#             is_shared=False)},
#     batch_size=torch.Size([]),
#     device=None,
#     is_shared=False)

whereas

print(std["a"])
# TensorDict(
#     fields={
#         b: Tensor(torch.Size([2, 1]), dtype=torch.float32)},
#     batch_size=torch.Size([2]),
#     device=None,
#     is_shared=False)

which is different from what is listed under the fields of std above.

[Bug] Indexing i:i+1 not handled correctly

from tensordict import tensorclass
import torch

@tensorclass
class MyData:
    x: torch.Tensor
a = MyData(x=torch.ones((3,10)), batch_size=[3])
b = MyData(x=torch.zeros((10)), batch_size=[])
# this works
a[1] = b
# this raises but should not
a[1:2] = b
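For reference, plain tensors accept the same pattern: a slice assignment broadcasts a source without the leading dim, so a tensorclass mirroring tensor semantics would be expected to do the same.

```python
import torch

x = torch.ones(3, 10)
x[1] = torch.zeros(10)    # destination shape (10,)
x[1:2] = torch.zeros(10)  # destination shape (1, 10): the (10,) source is broadcast
print(x.sum(dim=1))  # tensor([10.,  0., 10.])
```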

[BUG] Partially instantiated sub-tensordict silently collides with flattened keys when unflattened

Describe the bug

When a tensordict is created with mixed nested and flattened keys, a call to unflatten_keys silently overwrites one with the other.

To Reproduce

from tensordict import TensorDict
tensordict = TensorDict({"a": [1, 2], "c.a": [1, 2], "c": TensorDict({"b": [1, 2]}, [])}, [])
print(tensordict.unflatten_keys("."))

This produces the following output

TensorDict(
    fields={
        a: Tensor(torch.Size([2]), dtype=torch.int64),
        c: TensorDict(
            fields={
                a: Tensor(torch.Size([2]), dtype=torch.int64)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

where the "c.b" key has disappeared.

Expected behavior

Both "c.b" and "c.a" should be there.
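A sketch of the expected behavior on plain dicts: each flat key is expanded into a small nested dict and merged recursively into the result, so "c.a" and the pre-existing "c" sub-dict are combined, while a genuine conflict raises. unflatten_keys and _merge_into are hypothetical helpers, not the tensordict implementation:

```python
def _merge_into(dest, src, path=()):
    """Recursively merge src into dest, raising on conflicting leaves."""
    for key, value in src.items():
        if isinstance(value, dict):
            sub = dest.setdefault(key, {})
            if not isinstance(sub, dict):
                raise KeyError(f"{path + (key,)} is both a leaf and a sub-dict")
            _merge_into(sub, value, path + (key,))
        elif key in dest:
            raise KeyError(f"{path + (key,)} is set twice")
        else:
            dest[key] = value


def unflatten_keys(data, sep="."):
    out = {}
    for flat_key, value in data.items():
        *parents, leaf = flat_key.split(sep)
        nested = {leaf: value}
        for parent in reversed(parents):
            nested = {parent: nested}
        _merge_into(out, nested)  # merge instead of overwriting
    return out


print(unflatten_keys({"a": [1, 2], "c.a": [1, 2], "c": {"b": [1, 2]}}))
# {'a': [1, 2], 'c': {'a': [1, 2], 'b': [1, 2]}}
```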

[Feature Request] clone instead of to_tensordict

Motivation

The behaviour of clone is rather poorly defined. We should use to_tensordict instead.
Some clone methods rely on copy, which is problematic for tensordict classes because it keeps pointers to the previous metadata.
Having a clone that only returns pure TensorDict would be cleaner.
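The pointer-sharing problem is ordinary shallow-copy behaviour in Python: copy.copy creates a new object whose attributes still reference the same mutable objects. A generic illustration (Container is a hypothetical class, not tensordict code):

```python
import copy


class Container:
    def __init__(self):
        self.meta = {"batch_size": [3]}


original = Container()
shallow = copy.copy(original)
shallow.meta["batch_size"] = [2]  # mutates the dict shared with `original`
print(original.meta["batch_size"])  # [2]

deep = copy.deepcopy(original)
deep.meta["batch_size"] = [5]  # deepcopy gives `deep` its own dict
print(original.meta["batch_size"])  # [2]
```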

Fastest way to load TensorDict data?

Hi there! First of all thanks for your work, I really enjoy tensordicts!

I have a question about performance. What is the best way to load TensorDicts?
I ran some tests on possible combinations of dataloaders / collate_fn:


Simple benchmarking code

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from tensordict import TensorDict

a = torch.rand(1000, 50, 2)
td = TensorDict({"a": a}, batch_size=1000)

Case 1: store data as tensors, create TensorDicts on the fly

class SimpleDataset(Dataset):
    def __init__(self, data):
        # We split into a list since it is faster to dataload (fair comparison vs others)
        self.data = [d for d in data]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
dataset = SimpleDataset(td['a'])
dataloader = DataLoader(dataset, batch_size=32, collate_fn=torch.stack)
x = TensorDict({'a': next(iter(dataloader))}, batch_size=32)
%timeit for x in dataloader: TensorDict({'a': x}, batch_size=x.shape[0])

520 µs ± 833 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Case 2: store data as TensorDicts and directly load them

class TensorDictDataset(Dataset):
    def __init__(self, data):
        self.data = [d for d in data]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
data  = TensorDictDataset(td)
# use collate_fn=torch.stack to avoid StopIteration error
dataloader = DataLoader(data, batch_size=32, collate_fn=torch.stack)
x = next(iter(dataloader))
%timeit for x in dataloader: pass

1.72 ms ± 5.57 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Case 3: store TensorDict data as dictionaries and create TensorDicts on the fly with collate_fn

class CustomTensorDictDataset(Dataset):
    def __init__(self, data):
        self.data = [
            {key: value[i] for key, value in data.items()}
            for i in range(data.shape[0])
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class CustomTensorDictCollate(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, batch):
        return TensorDict(
            {key: torch.stack([b[key] for b in batch]) for key in batch[0].keys()},
            batch_size=len(batch),
        )
    
data = CustomTensorDictDataset(td)
dataloader = DataLoader(data, batch_size=32, collate_fn=CustomTensorDictCollate())
x = next(iter(dataloader))
%timeit for x in dataloader: pass

567 µs ± 924 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
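The timings above come from IPython's %timeit magic, which only works in a notebook or IPython shell. A minimal sketch for taking a comparable measurement in a plain script uses the standard timeit module instead; the helper name time_loader and its defaults are illustrative assumptions, and the numbers it reports are similar to, not identical with, what %timeit prints.

```python
import statistics
import timeit


def time_loader(loader, number=100, repeat=7):
    """Hypothetical helper: time full passes over `loader`, like `%timeit for x in loader: pass`.

    Runs `repeat` rounds of `number` passes each and returns
    (mean, standard deviation) in seconds per pass.
    """
    def one_pass():
        # Exhaust the iterable once; this is the statement being timed.
        for _ in loader:
            pass

    totals = timeit.repeat(one_pass, number=number, repeat=repeat)
    per_pass = [total / number for total in totals]
    return statistics.mean(per_pass), statistics.stdev(per_pass)
```

For example, `mean, std = time_loader(dataloader)` times any of the loaders above; `repeat` must be at least 2 for the standard deviation to be defined.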


Apparently, storing raw tensors (Case 1) or plain dictionaries (Case 3) and creating TensorDicts on the fly is roughly 3× faster than stacking pre-split TensorDicts (Case 2). But why is it not faster to just index TensorDicts instead? And is there a better way?
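One alternative worth benchmarking (not measured here) is to let the dataset index a whole batch at once, so the stored data is indexed once per batch instead of being split and re-stacked per sample. The sketch relies on the DataLoader's non-automatic-batching mode: with batch_size=None, each element the sampler yields (here, a list of indices from a BatchSampler) is passed to `__getitem__` in a single call. The names BatchIndexedDataset and make_batched_loader are hypothetical, and the sketch assumes the stored object (a tensor in the test, or a TensorDict) supports indexing with a tensor of indices along its first dimension.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class BatchIndexedDataset(Dataset):
    """Hypothetical dataset that returns a whole batch per __getitem__ call."""

    def __init__(self, data):
        # Any object indexable along dim 0 with a tensor of indices.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idxs):
        # idxs is a list of indices produced by the BatchSampler.
        return self.data[torch.as_tensor(idxs)]


def _identity(batch):
    # Module-level function (not a lambda) so it can be pickled for worker processes.
    return batch


def make_batched_loader(data, batch_size=32):
    sampler = BatchSampler(
        SequentialSampler(range(len(data))), batch_size=batch_size, drop_last=False
    )
    # batch_size=None disables automatic batching: each list of indices from the
    # sampler goes to __getitem__ as is, and the identity collate_fn keeps the result.
    return DataLoader(
        BatchIndexedDataset(data), sampler=sampler, batch_size=None, collate_fn=_identity
    )
```

With `td` from the benchmark, `make_batched_loader(td)` should yield batches of 32 (plus a final batch of 8); replacing SequentialSampler with RandomSampler gives shuffled batches.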

[Feature Request] Support for nested tensors

Motivation

Creating a tensordict of nested tensors should work like torch.stack(tensordicts, 0).
(We can't currently override torch.nested.as_nested_tensor, so we'll need to write a custom op.)
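Until such a custom op exists, the intended behaviour can be sketched with plain dicts standing in for tensordicts: stack a list of them key by key with torch.nested.nested_tensor, which, unlike torch.stack, accepts tensors whose sizes differ (they must still have the same number of dimensions). The helper name nested_stack is hypothetical, not a tensordict or PyTorch API.

```python
import torch


def nested_stack(dicts):
    """Hypothetical nested-tensor analogue of torch.stack(tensordicts, 0).

    Every dict must have the same keys; the tensors stored under one key must
    have the same number of dims, but their sizes may differ.
    """
    keys = list(dicts[0].keys())
    # One nested tensor per key, with one component per input dict.
    return {k: torch.nested.nested_tensor([d[k] for d in dicts]) for k in keys}
```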

We don't need to subclass TensorDict for this, but there are a few caveats:

  1. Reshaping won't be allowed
  2. Shape can't be accessed
  3. Indexing can only be done on the first dim

For instance, populating a tensordict with a nested tensor won't work as we can't access the shape.

  • If nested tensors are found in a tensordict, all of its tensors should be nested.

  • Indexing such a tensordict along the second dimension will require:

  1. splitting the tensors that are nested
  2. indexing those tensors
  3. re-nesting them
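The three steps above can be sketched on a single nested tensor with Tensor.unbind (which nested tensors support along dim 0) and torch.nested.nested_tensor. The helper names are hypothetical, a plain dict stands in for the tensordict, and the index must be valid for every component (for instance, smaller than the length of the shortest one).

```python
import torch


def index_nested_dim1(nt, idx):
    """Hypothetical helper: index a nested tensor along its second dimension."""
    # 1. split the nested tensor into its component tensors
    parts = nt.unbind(0)
    # 2. index each component along its own first dim (dim 1 of the nested tensor)
    parts = [p[idx] for p in parts]
    # 3. re-nest the indexed components
    return torch.nested.nested_tensor(parts)


def index_nested_dict_dim1(tensors, idx):
    # Apply the same steps to every entry of a dict of nested tensors.
    return {k: index_nested_dim1(v, idx) for k, v in tensors.items()}
```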

[Feature Request] data class like behaviour

Motivation

Some users have pointed out that data classes offer advantages over dict-like classes:

  • predictability: one knows in advance what fields to expect, the code is transparent although less generic
  • autocomplete: IDEs can keep track of the attributes through type hinting
  • dedicated behaviour: a class can have a set of dedicated methods that act on its attributes.

Proposed API

We could have a mixed approach where a data class builds a tensordict and interfaces with it. Attributes would point to keys, and the tensordict methods (reshape, indexing, etc.) would be shared across all such classes.

@tensordict_data  # better name?
class MyData:
    float_tensor: torch.Tensor
    sparse_tensor: torch.Tensor

    def stuff(self):
        return self.float_tensor + self.sparse_tensor

data = MyData(float_tensor=a, sparse_tensor=b, batch_size=[3, 4])
data.reshape(-1)  # returns a new MyData with shape / batch_size of [12]
data.batch_size  # picks the batch size of the tensordict
data.missing_key  # raises a dedicated error since this key is not expected

Challenges

Any tensordict method called on such a data class (reshape, split, etc.) should return a new instance of the same class.
Upon creation (at loading time), one should check that the data class attributes do not collide with the tensordict ones.
Nesting with data classes will be slightly harder than with tensordict (see this thread).
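A minimal sketch of the proposed decorator, assuming a plain dict stands in for the underlying tensordict and that only reshape is forwarded. It checks field names against reserved names when the class is decorated, validates the batch size, raises a dedicated AttributeError for unexpected keys, and makes reshape return a new instance of the same class. tensordict_data and _RESERVED are hypothetical names from this proposal, not an existing tensordict API.

```python
import torch

# Names used by the wrapper itself; data fields must not reuse them (hypothetical list).
_RESERVED = frozenset({"batch_size", "reshape", "_data"})


def tensordict_data(cls):
    """Hypothetical decorator: annotated class attributes become batched tensor fields."""
    fields = tuple(getattr(cls, "__annotations__", {}))
    clash = _RESERVED.intersection(fields)
    if clash:
        # Collision check runs once, when the class is decorated.
        raise TypeError(f"{cls.__name__} fields collide with reserved names: {sorted(clash)}")

    def __init__(self, *, batch_size, **tensors):
        if set(tensors) != set(fields):
            raise TypeError(f"expected fields {fields}, got {tuple(tensors)}")
        batch_size = torch.Size(batch_size)
        for name, tensor in tensors.items():
            # Leading dims of every tensor must match batch_size, as in a tensordict.
            if tensor.shape[: len(batch_size)] != batch_size:
                raise ValueError(
                    f"{name!r} of shape {tuple(tensor.shape)} does not match "
                    f"batch_size {tuple(batch_size)}"
                )
        self._data = dict(tensors)
        self.batch_size = batch_size

    def __getattr__(self, name):
        # Only reached when normal lookup fails, i.e. for field names and unknown keys.
        data = object.__getattribute__(self, "_data")
        if name in data:
            return data[name]
        raise AttributeError(
            f"{type(self).__name__} has no field {name!r}; expected one of {fields}"
        )

    def reshape(self, *shape):
        # Reshape the batch dims only and return a new instance of the same class.
        if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
            shape = tuple(shape[0])
        n = len(self.batch_size)
        # A meta tensor resolves -1 entries without allocating memory.
        new_batch = torch.empty(self.batch_size, device="meta").reshape(shape).shape
        new_data = {k: v.reshape(*new_batch, *v.shape[n:]) for k, v in self._data.items()}
        return type(self)(batch_size=new_batch, **new_data)

    cls.__init__ = __init__
    cls.__getattr__ = __getattr__
    cls.reshape = reshape
    return cls
```

Decorating a class with annotated fields `float_tensor` and `bool_tensor` and building it from tensors of shape [3, 4, 5] with batch_size=[3, 4], `reshape(-1)` returns an instance of the same class with batch_size torch.Size([12]), while accessing `missing_key` raises an AttributeError listing the expected fields. Dense tensors are used for simplicity.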
