
s4's People

Contributors

ad12, albertfgu, hongyuhe, jchia, krandiash, rogerni, telmop, trellixvulnteam

s4's Issues

Comparison of S4 with stateful transformers

Hi,

Wu et al. recently published a paper on Memorizing Transformers (transformers with state/memory), which extends their receptive field to unbounded contexts (https://www.youtube.com/watch?v=5AoOpFFjW28&list=PL0NRmB0fnLJQJ3fuIk3yVULtm6_JnQ_zI, https://arxiv.org/abs/2203.08913). I am curious what you think about how S4/SaShiMi might compare with this new transformer model. My hunch is that S4 might be theoretically similar if you use the exponential measure density.

Training on .mat files

Hi, I am trying to train the Sashimi model on .mat files. How should I go about doing this?
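One possible route, sketched under the assumption that the SaShiMi audio loaders read folders of .wav files (as for SC09; check the dataset config you use), is to convert each .mat file to 16-bit mono WAV first. The variable name inside the .mat file (`key`) and the 16 kHz default sample rate are assumptions, and both helper names are hypothetical.

```python
import wave

import numpy as np


def mat_to_array(mat_path, key):
    """Load one variable from a .mat file (requires SciPy; `key` is the variable name)."""
    from scipy.io import loadmat  # imported lazily so the writer below works without SciPy
    return np.asarray(loadmat(mat_path)[key], dtype=np.float64).squeeze()


def array_to_wav(signal, wav_path, sample_rate=16000):
    """Write a 1-D float signal in [-1, 1] as 16-bit mono PCM (sample rate is an assumption)."""
    signal = np.clip(np.asarray(signal, dtype=np.float64), -1.0, 1.0)
    pcm = np.round(signal * 32767).astype("<i2")  # little-endian int16, as WAV expects
    with wave.open(str(wav_path), "wb") as f:
        f.setnchannels(1)
        f.setsampwidth(2)
        f.setframerate(sample_rate)
        f.writeframes(pcm.tobytes())
```

Note that `scipy.io.loadmat` does not read MATLAB v7.3 files (those are HDF5 and need a library such as h5py).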

Can't compile the custom cauchy kernel

Dear all,
Sorry for the silly question. I'm having trouble installing the custom CUDA Cauchy kernel.
I ran python setup.py install like this:

~/LongSeq/state-spaces/extensions/cauchy$ python setup.py install
running install
~/miniconda3/envs/lightning/lib/python3.9/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
  warnings.warn(
~/miniconda3/envs/lightning/lib/python3.9/site-packages/setuptools/command/easy_install.py:156: EasyInstallDeprecationWarning: easy_install command is deprecated. Use build and pip and other standards-based tools.
  warnings.warn(
running bdist_egg
running egg_info
creating cauchy_mult.egg-info
writing cauchy_mult.egg-info/PKG-INFO
writing dependency_links to cauchy_mult.egg-info/dependency_links.txt
writing top-level names to cauchy_mult.egg-info/top_level.txt
writing manifest file 'cauchy_mult.egg-info/SOURCES.txt'
reading manifest file 'cauchy_mult.egg-info/SOURCES.txt'
writing manifest file 'cauchy_mult.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
warning: install_lib: 'build/lib' does not exist -- no Python modules to install

creating build
creating build/bdist.linux-x86_64
creating build/bdist.linux-x86_64/egg
creating build/bdist.linux-x86_64/egg/EGG-INFO
copying cauchy_mult.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO
copying cauchy_mult.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying cauchy_mult.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying cauchy_mult.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
zip_safe flag not set; analyzing archive contents...
creating dist
creating 'dist/cauchy_mult-0.0.0-py3.9.egg' and adding 'build/bdist.linux-x86_64/egg' to it
removing 'build/bdist.linux-x86_64/egg' (and everything under it)
Processing cauchy_mult-0.0.0-py3.9.egg
Copying cauchy_mult-0.0.0-py3.9.egg to ~/miniconda3/envs/lightning/lib/python3.9/site-packages
Adding cauchy-mult 0.0.0 to easy-install.pth file

Installed ~/miniconda3/envs/lightning/lib/python3.9/site-packages/cauchy_mult-0.0.0-py3.9.egg
Processing dependencies for cauchy-mult==0.0.0
Finished processing dependencies for cauchy-mult==0.0.0

but I cannot import cauchy_mult when running test_cauchy.py:

~/LongSeq/state-spaces/extensions/cauchy$ python test_cauchy.py 
Traceback (most recent call last):
  File "~/LongSeq/state-spaces/extensions/cauchy/test_cauchy.py", line 8, in <module>
    from cauchy import cauchy_mult_torch, cauchy_mult_keops, cauchy_mult
  File "~/LongSeq/state-spaces/extensions/cauchy/cauchy.py", line 5, in <module>
    from cauchy_mult import cauchy_mult_fwd, cauchy_mult_bwd, cauchy_mult_sym_fwd, cauchy_mult_sym_bwd
ModuleNotFoundError: No module named 'cauchy_mult'

I also tried running python -m pip install and switching to Python 3.8, but neither worked.
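The install log contains "warning: install_lib: 'build/lib' does not exist -- no Python modules to install", which suggests the egg was created with metadata only and no compiled module. A quick diagnostic (hypothetical helper, standard library only) is to ask the same interpreter where, if anywhere, `cauchy_mult` would be imported from and whether that file is a compiled extension:

```python
import importlib.machinery
import importlib.util


def describe_module(name):
    """Report where `name` would be imported from and whether it is a compiled extension."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None  # not importable from the current environment
    origin = spec.origin or ""
    is_extension = origin.endswith(tuple(importlib.machinery.EXTENSION_SUFFIXES))
    return {"origin": origin, "compiled_extension": is_extension}


print(describe_module("cauchy_mult"))
```

If this prints None, the build step most likely never produced the extension, so the compiler output of the build (rather than the install step) is the place to look.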

Resuming suspended training

Hey folks,

Congrats on the amazing and inspiring work!!

I have a quick question: how do I resume training and wandb logging if training was terminated before completion? E.g. say the original command was CUDA_VISIBLE_DEVICES=0 python -m train experiment=s4-lra-pathx loader.batch_size=16 trainer.accumulate_grad_batches=2 and training was suspended halfway through. What command should I run to resume both the training and the wandb logging?

Thanks in advance,
Ankit
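A minimal sketch for the first half of the problem, assuming checkpoints were written as *.ckpt files somewhere under the run's output directory (Lightning's ModelCheckpoint uses that extension, and writes last.ckpt only when save_last is enabled). The returned path could be passed to Lightning's `Trainer.fit(model, ckpt_path=...)`, and wandb can continue an existing run with `wandb.init(id=<run id>, resume="must")`. Whether the repo's train.py exposes either through a Hydra override is not confirmed here; `find_resume_checkpoint` is a hypothetical helper.

```python
from pathlib import Path


def find_resume_checkpoint(ckpt_dir):
    """Return a checkpoint to resume from: prefer last.ckpt, else the newest *.ckpt, else None."""
    ckpt_dir = Path(ckpt_dir)
    last = ckpt_dir / "last.ckpt"
    if last.is_file():
        return last
    # Search subdirectories too, ordering by modification time.
    candidates = sorted(ckpt_dir.rglob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    return candidates[-1] if candidates else None
```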

RNN-style train and eval for S4/S4D

Excellent idea and great paper!
Could you please provide a concrete example of how to both train and evaluate using the stateful RNN version of S4/S4D? I only found an evaluation example in the SaShiMi code, not one for training.
Thank you!
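For intuition about why the two views are interchangeable, here is a NumPy sketch (not the repo's API) of a diagonal, S4D-style SSM computed both as a convolution and as a step-by-step recurrence. It uses zero-order-hold discretization and keeps only Re(·), omitting S4D's conjugate-pair bookkeeping. Commonly the convolutional view is used for training and the recurrent view for autoregressive generation; training through the recurrence gives the same outputs but is much slower.

```python
import numpy as np


def discretize_zoh(A, B, dt):
    """Zero-order-hold discretization of a diagonal SSM x' = A x + B u (A complex, shape (N,))."""
    dA = np.exp(dt * A)
    dB = (dA - 1.0) / A * B
    return dA, dB


def conv_mode(u, A, B, C, dt):
    """Build K[l] = Re(sum_n C_n dA_n^l dB_n), then apply it as a causal convolution."""
    dA, dB = discretize_zoh(A, B, dt)
    L = len(u)
    powers = dA[:, None] ** np.arange(L)[None, :]  # (N, L) Vandermonde matrix
    K = np.real((C * dB) @ powers)                 # (L,) convolution kernel
    return np.array([np.dot(K[: k + 1][::-1], u[: k + 1]) for k in range(L)])


def recurrent_mode(u, A, B, C, dt):
    """x_k = dA x_{k-1} + dB u_k and y_k = Re(C x_k), one step at a time."""
    dA, dB = discretize_zoh(A, B, dt)
    x = np.zeros_like(A)
    ys = []
    for u_k in u:
        x = dA * x + dB * u_k
        ys.append(np.real(np.dot(C, x)))
    return np.array(ys)
```

For example, with A = -0.5 + 1j * np.pi * np.arange(N) (the S4D-Lin initialization), both functions return the same outputs up to floating-point error.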

Error while using DataParallel

Hello,
Thanks for sharing this awesome work. When I try to run example.py I get the following error:

CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running python setup.py install. This should speed up end-to-end training by 10-50%
Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency.
cuda
==> Preparing cifar10 data..
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
==> Building model..
Optimizer group 0 | 28 tensors
0it [00:11, ?it/s] | 0/200 [00:00<?, ?it/s]
Epoch: 0: 0%| | 0/200 [00:12<?, ?it/s]
Traceback (most recent call last):
  File "example.py", line 373, in <module>
    train()
  File "example.py", line 312, in train
    outputs = model(inputs)
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/_utils.py", line 425, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "example.py", line 196, in forward
    z, _ = layer(z)
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/data1/ameen/state-spaces/src/models/sequence/ss/standalone/s4.py", line 842, in forward
    k = self.kernel(L=L) # (C H L) (B C H L)
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/data1/ameen/state-spaces/src/models/sequence/ss/standalone/s4.py", line 752, in forward
    k = self.kernel(L=L)
  File "/media/data2/ameenali/anaconda3/envs/ameen/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/data1/ameen/state-spaces/src/models/sequence/ss/standalone/s4.py", line 427, in forward
    C = _r2c(self.C)
RuntimeError: Output 3 of BroadcastBackward is a view and its base or another view of its base has been modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.
Any idea how to solve this?

can't run train.py w/o compiling Cauchy kernel for CUDA

Dear all,

I am having trouble compiling the Cauchy kernel, and although I have installed pykeops, running train.py always results in errors like this:

RuntimeError: [KeOps] This KeOps shared object has been compiled without cuda support:

  1. to perform computations on CPU, simply set tagHostDevice to 0
  2. to perform computations on GPU, please recompile the formula with a working version of cuda.

The only thing that fixed the issue for me was commenting out the following try/except block. Without that change (sorry for its ugliness...) the code never fell back to the slow kernel. Now it does, but this is certainly not the right way to go about it ;)

I wonder whether the try/except block should check that the kernel actually runs, not just that it can be imported?

# Original block, disabled by wrapping it in a string literal:
'''
try:
    import pykeops
    from src.models.functional.cauchy import cauchy_conj
    has_pykeops = True
except ImportError:
    has_pykeops = False
    from src.models.functional.cauchy import cauchy_conj_slow
    if not has_cauchy_extension:
        log.error(
            "Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency."
        )
'''

# Replacement that always uses the slow kernel:
has_pykeops = False
from src.models.functional.cauchy import cauchy_conj_slow
if not has_cauchy_extension:
    log.error(
        "Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency."
    )
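The check proposed in this issue (run the kernel once instead of only importing it) could look roughly like the following. All names are hypothetical; in the repo, the pykeops loader would import cauchy_conj from src.models.functional.cauchy and call it on a tiny input, so that a KeOps "compiled without cuda support" RuntimeError disqualifies the backend just like an ImportError.

```python
import logging

log = logging.getLogger(__name__)


def select_backend(candidates, fallback):
    """Return the first (name, fn) whose loader succeeds; otherwise return `fallback`.

    `candidates` is a list of (name, loader) pairs. Each loader imports its backend,
    runs it once on a tiny input, and returns the kernel function.
    """
    for name, loader in candidates:
        try:
            return name, loader()
        except Exception as exc:  # broad on purpose: "importable" is not the same as "usable"
            log.warning("Cauchy backend %r unusable: %s", name, exc)
    log.error("Falling back on slow Cauchy kernel.")
    return fallback
```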

How to visualize the results of prediction?

Hi,

Thank you for sharing.

I am running some experiments on the ETT dataset using S4.
I wonder how to load the best checkpoint, test it on new data, and then visualize the predictions against the ground truth.

Is there an easy way to do this with PyTorch Lightning?

Cheers,
Max
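For the visualization step, a minimal sketch assuming you already have NumPy arrays of the context window, the targets, and the model outputs, each shaped (time, channels); collecting them (e.g. from a test loop after restoring the checkpoint) depends on the repo's LightningModule and is not shown. `plot_forecast` is a hypothetical helper using matplotlib.

```python
import matplotlib
import numpy as np

matplotlib.use("Agg")  # render to a file; no display needed
import matplotlib.pyplot as plt


def plot_forecast(history, truth, pred, path, channel=0):
    """Plot one channel: context window, ground truth, and prediction over the horizon."""
    t_hist = np.arange(len(history))
    t_future = np.arange(len(history), len(history) + len(truth))
    fig, ax = plt.subplots(figsize=(8, 3))
    ax.plot(t_hist, history[:, channel], color="gray", label="context")
    ax.plot(t_future, truth[:, channel], label="truth")
    ax.plot(t_future, pred[:, channel], linestyle="--", label="prediction")
    ax.legend()
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)
```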

Training on own data

I'd like to train a SaShiMi model on my own data. Could you please let me know what structure the data needs to have to be compatible with the dataloader?
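This does not answer which layout the repo's loaders expect. As a generic sketch, assuming the audio is stored as 16-bit mono .wav files under one folder, fixed-length training examples could be prepared like this; the chunk length and the file layout are assumptions, and `load_wav_chunks` is hypothetical.

```python
import wave
from pathlib import Path

import numpy as np


def load_wav_chunks(folder, chunk_len):
    """Read every 16-bit mono .wav under `folder` and split it into chunk_len-sample pieces.

    Trailing samples that do not fill a whole chunk are dropped.
    """
    chunks = []
    for path in sorted(Path(folder).rglob("*.wav")):
        with wave.open(str(path), "rb") as f:
            if f.getsampwidth() != 2 or f.getnchannels() != 1:
                raise ValueError(f"{path}: expected 16-bit mono PCM")
            audio = np.frombuffer(f.readframes(f.getnframes()), dtype="<i2")
        audio = audio.astype(np.float32) / 32768.0  # scale to [-1, 1)
        n = len(audio) // chunk_len
        chunks.extend(audio[: n * chunk_len].reshape(n, chunk_len))
    return np.stack(chunks) if chunks else np.zeros((0, chunk_len), np.float32)
```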

Conditioning on the diffusion step

Hello,

I'm considering trying your SaShiMi model as the backbone of a diffusion model for audio generation. There is a detail I couldn't find in the paper or in the code (maybe I didn't look hard enough): how do you condition the model on the diffusion timestep? Do you do the same as in DiffWave (element-wise addition of the diffusion-step embedding at each layer), or use something similar to a FiLM layer, as in WaveGrad?
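For reference, the two options the question mentions can be sketched in NumPy as follows; this is not the SaShiMi code, whose choice is exactly what is being asked. The embedding follows DiffWave's sinusoidal formula. In a real model it would pass through learned layers to produce a per-layer bias (additive conditioning) or a per-channel scale and shift (FiLM); those learned projections are omitted here.

```python
import numpy as np


def diffusion_step_embedding(t, dim=128):
    """DiffWave-style sinusoidal embedding of diffusion step t:
    [sin(10^(4i/(d/2-1)) t) for i < d/2] followed by the matching cosines."""
    half = dim // 2
    freqs = 10.0 ** (4.0 * np.arange(half) / (half - 1))
    angles = t * freqs
    return np.concatenate([np.sin(angles), np.cos(angles)])


def add_step_bias(h, bias):
    """Additive conditioning: broadcast a per-channel bias over time. h: (channels, length)."""
    return h + bias[:, None]


def film(h, gamma, beta):
    """FiLM conditioning: per-channel scale and shift of features h with shape (channels, length)."""
    return gamma[:, None] * h + beta[:, None]
```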

S4D Memory Requirements

Hey, I wanted to give S4D a quick try in my research as a drop-in replacement for S4 (which, as far as I gathered, should be a good way to start), but I'm running into some hard memory limitations. As a first experiment I'm training the DiffWave version of SaShiMi, and the memory requirements increase significantly when I replace S4 with an equivalent S4D layer (with default settings). In my case the model actually goes OOM, so I don't have precise measurements, but it's at least a 20% increase in overall memory consumption. I use the parameters discussed in #46. Is this something you'd expect?
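One plausible contributor, assuming the S4D layer computes its kernel through the simple PyTorch Vandermonde path rather than a pykeops or CUDA kernel: that path materializes an intermediate of shape (H, N/2, L) in complex64, which autograd may keep for the backward pass. A back-of-envelope calculator (hypothetical helper; the H, N, L values are illustrative, not SaShiMi's defaults):

```python
def vandermonde_kernel_bytes(H, N, L, bytes_per_complex=8):
    """Size of one (H, N/2, L) complex64 intermediate, as a naive Vandermonde kernel
    computation would materialize it. Counts a single layer, forward pass only."""
    return H * (N // 2) * L * bytes_per_complex


print(vandermonde_kernel_bytes(H=64, N=64, L=16000) / 2**20, "MiB")  # 250.0 MiB
```

This grows linearly in L per layer, so long audio sequences make it significant.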

SaShiMi generation script errors out with own models

Hey, first of all, great work on the repository. I don't think I've ever worked with a paper repository this extensive and well-structured.

I'm currently trying to train the SaShiMi model on my own dataset (following your guide here: #23), and I run into some issues when trying to generate samples with the trained model.
In case this is relevant: I'm running inference from the checkpoint files, and I changed the number of layers (model.n_layers) to 4 to accommodate the memory limitations of my GPU. Apart from that, I have made no changes to the training or model code except for switching the dataset to my own.
When I try to call the generation.py script now, I run into a range of errors:

  • The config overrides cause some errors: the hurwitz parameter no longer exists, and the setup_step methods don't seem to accept the mode argument (or rather, don't pass it downstream). I "fixed" this by removing the hurwitz override and by adding the mode argument to all module.setup_step() methods, passing it downstream as required.
  • Additionally, setting model.layer.postact=null causes the state_dict to not load successfully anymore, giving me the following error:
Missing key(s) in state_dict: "model.c_layers.0.layer.output_linear.weight", "model.c_layers.0.layer.output_linear.bias", "model.c_layers.2.layer.output_linear.weight", "model.c_layers.2.layer.output_linear.bias", "model.c_layers.4.layer.output_linear.weight", "model.c_layers.4.layer.output_linear.bias", "model.c_layers.6.layer.output_linear.weight", "model.c_layers.6.layer.output_linear.bias", "model.u_layers.0.1.layer.output_linear.weight", "model.u_layers.0.1.layer.output_linear.bias", "model.u_layers.0.3.layer.output_linear.weight", "model.u_layers.0.3.layer.output_linear.bias", "model.u_layers.0.5.layer.output_linear.weight", "model.u_layers.0.5.layer.output_linear.bias", "model.u_layers.0.7.layer.output_linear.weight", "model.u_layers.0.7.layer.output_linear.bias", "model.u_layers.1.1.layer.output_linear.weight", "model.u_layers.1.1.layer.output_linear.bias", "model.u_layers.1.3.layer.output_linear.weight", "model.u_layers.1.3.layer.output_linear.bias", "model.u_layers.1.5.layer.output_linear.weight", "model.u_layers.1.5.layer.output_linear.bias", "model.u_layers.1.7.layer.output_linear.weight", "model.u_layers.1.7.layer.output_linear.bias". 
Unexpected key(s) in state_dict: "model.c_layers.0.layer.output_linear.0.weight", "model.c_layers.0.layer.output_linear.0.bias", "model.c_layers.2.layer.output_linear.0.weight", "model.c_layers.2.layer.output_linear.0.bias", "model.c_layers.4.layer.output_linear.0.weight", "model.c_layers.4.layer.output_linear.0.bias", "model.c_layers.6.layer.output_linear.0.weight", "model.c_layers.6.layer.output_linear.0.bias", "model.u_layers.0.1.layer.output_linear.0.weight", "model.u_layers.0.1.layer.output_linear.0.bias", "model.u_layers.0.3.layer.output_linear.0.weight", "model.u_layers.0.3.layer.output_linear.0.bias", "model.u_layers.0.5.layer.output_linear.0.weight", "model.u_layers.0.5.layer.output_linear.0.bias", "model.u_layers.0.7.layer.output_linear.0.weight", "model.u_layers.0.7.layer.output_linear.0.bias", "model.u_layers.1.1.layer.output_linear.0.weight", "model.u_layers.1.1.layer.output_linear.0.bias", "model.u_layers.1.3.layer.output_linear.0.weight", "model.u_layers.1.3.layer.output_linear.0.bias", "model.u_layers.1.5.layer.output_linear.0.weight", "model.u_layers.1.5.layer.output_linear.0.bias", "model.u_layers.1.7.layer.output_linear.0.weight", "model.u_layers.1.7.layer.output_linear.0.bias".

Does this mean that I should rename those keys manually (there's a fairly clear correspondence) to make it work after changing the activation?

  • Finally, even when I pass through the mode parameter in module.setup_step(), I still get this error:
Traceback (most recent call last):
  File "/home/debaumas/state-spaces/sashimi/generation.py", line 192, in main
    module.setup_step(mode='dense')
  File "/home/debaumas/state-spaces/src/models/sequence/ss/kernel.py", line 1038, in setup_step
    self.kernel.setup_step(mode=mode)
  File "/home/debaumas/state-spaces/src/models/sequence/ss/kernel.py", line 515, in setup_step
    dC = torch.linalg.solve(
torch._C._LinAlgError: linalg.solve: (Batch element 0): The diagonal element 1 is zero, the solve could not be completed because the input matrix is singular.

Do you have any idea what might be causing this and maybe an idea about how to fix/circumvent this?

It'd be awesome if you could help point me in the right direction with this.

Best,
Stefan
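On the state_dict question: the missing and unexpected keys differ only by a ".0" after output_linear, which is consistent with output_linear having been an nn.Sequential (linear layer plus activation) at training time and a bare linear layer with postact=null. If the shapes match, renaming could let the checkpoint load, e.g. `model.load_state_dict(strip_sequential_index(ckpt["state_dict"]))`, but the model would then run without the post-activation it was trained with, so outputs may differ; keeping postact the same as in training avoids the mismatch. A sketch of the renaming (hypothetical helper):

```python
import re


def strip_sequential_index(state_dict, module_name="output_linear"):
    """Rename keys like '...output_linear.0.weight' to '...output_linear.weight'.

    Only meaningful if element 0 of the saved nn.Sequential is the same linear layer
    the new model expects (shapes must match).
    """
    pattern = re.compile(rf"({re.escape(module_name)})\.0\.(weight|bias)$")
    return {pattern.sub(r"\1.\2", k): v for k, v in state_dict.items()}
```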

Experiment reproduction issue with updated modules

Hi, I was trying to reproduce some of your results using the SaShiMi model by running the command

python -m train experiment=sashimi-sc09 wandb=null

but I get the error

TypeError: __init__() got an unexpected keyword argument 'pool'

because the DownPool class no longer takes a pool parameter in its constructor.

Are there any plans to fix these issues so that the experiments work with the current implementations of the different modules?

Should the A, B, C, dt parameters be set to be trainable?

Hi!

While reading the code, I found that in most of the S4 experiment configurations the A, B, C, dt parameters are not set to be trainable, whereas I thought these parameters were trained in S4 as presented in the paper. If I understand correctly, isn't this setting similar to HiPPO when these parameters are not trained? Or do you have any empirical findings on this?

Thanks!

Model cannot converge on the LRA Pathfinder

Hi,

Thanks for the great work! When I ran your code on the LRA Pathfinder dataset (using your config), it did not converge even by the end of the 200th epoch, as shown in the following log: loss=0.693, val/accuracy=0.499, val/loss=0.693, test/accuracy=0.495, test/loss=0.693, train/accuracy=0.501, train/loss=0.693. The loss stays at 0.693 throughout training.

Do you have any thoughts on this? Thanks!
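A useful reference point when reading this log: 0.693 is ln 2, the cross-entropy of a model that assigns probability 1/2 to both classes. Together with ~50% accuracy, it indicates the model stayed at chance level for the whole run rather than converging to a poor solution.

```python
import math


def cross_entropy(p_correct):
    """Cross-entropy (in nats) when the model assigns probability p_correct to the true class."""
    return -math.log(p_correct)


print(round(cross_entropy(0.5), 3))  # 0.693
```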

S4 Module Distribution

I think it would be quite useful if one could install this repository, either directly from git or preferably using pip.

I envision a workflow where people only have to:

  1. pip install state-spaces
  2. from state_spaces import S4
  3. replace use of nn.LSTM or nn.Transformer with S4

There is a package on PyPI (I'm not sure whether you published it), but its latest version is older than the current code.
https://pypi.org/project/state-space/#history

sashimi: LinearActivation initialization gives TypeError for weight_norm argument

Hi,

I ran into an issue while running the sashimi.py script: the LinearActivation function in ./standalone/s4.py does not accept the weight_norm argument that the DownPool and UpPool classes pass through **kwargs.

The error trace:

File ".\PycharmProjects\state-spaces\sashimi\sashimi.py", line 450, in <module>
    model = Sashimi(n_layers=2).cuda()

File ".\PycharmProjects\state-spaces\sashimi\sashimi.py", line 298, in __init__
    d_layers.append(DownPool(H, expand, p))

File ".\PycharmProjects\state-spaces\sashimi\sashimi.py", line 28, in __init__
    self.linear = LinearActivation(

File ".\PycharmProjects\state-spaces\src\models\sequence\ss\standalone\s4.py", line 137, in LinearActivation
    linear = linear_cls(d_input, d_output, bias=bias, **kwargs)

TypeError: __init__() got an unexpected keyword argument 'weight_norm'

Looking a bit further, I noticed that neither nn.Conv1d (chosen in DownPool because transposed = True) nor nn.Linear, the two layers LinearActivation() can construct, accepts a weight_norm argument in PyTorch.

Am I overlooking something?

Python: 3.9
PyTorch: 1.12 (latest stable release)

Thanks a lot for publishing your code with the papers!

Cheers,
Bavo
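One way to make such a call succeed, sketched under the assumption that weight_norm is meant as a boolean flag: consume the flag in the helper instead of forwarding it, and wrap the constructed layer with `torch.nn.utils.weight_norm` (available in 1.12; newer PyTorch versions deprecate it in favour of `torch.nn.utils.parametrizations.weight_norm`). `linear_activation` is a hypothetical stand-in, not the repo's LinearActivation, and it omits the activation handling.

```python
import torch.nn as nn


def linear_activation(d_input, d_output, bias=True, transposed=False, weight_norm=False, **kwargs):
    """Linear layer (or 1x1 Conv1d if transposed) that handles `weight_norm` itself
    instead of passing it to the PyTorch constructor, which does not accept it."""
    if transposed:
        linear = nn.Conv1d(d_input, d_output, kernel_size=1, bias=bias, **kwargs)
    else:
        linear = nn.Linear(d_input, d_output, bias=bias, **kwargs)
    if weight_norm:
        linear = nn.utils.weight_norm(linear)  # splits weight into weight_g and weight_v
    return linear
```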

Shape '[]' is invalid for input of size

When I run:
python -m train wandb=null experiment=s4-cifar

I get this error:
  File "/root/anaconda3/envs/statespaces/lib/python3.8/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/root/anaconda3/envs/statespaces/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
  File "/root/anaconda3/envs/statespaces/lib/python3.8/site-packages/torch/autograd/function.py", line 199, in apply
    return user_fn(self, *args)
  File "/root/anaconda3/envs/statespaces/lib/python3.8/site-packages/pykeops/torch/generic/generic_red.py", line 263, in backward
    grad = grad.reshape(
RuntimeError: shape '[1024, 2, 2, 32, 2]' is invalid for input of size 4202496

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

Faster Cauchy Kernels?

Breathtaking work, absolutely amazing application of linear algebra. Beautiful.

A few questions.

The appendix of the S4 paper mentions "... implementation of S4 uses the naive O(NL) algorithm ..." and README.md mentions custom kernels.

Question 0. Did you benchmark the naive O(NL) algorithm against the custom kernel?

Question 1. Is the 60x speedup in Table 8 obtained with the naive O(NL) algorithm or with the custom kernel? Or is the custom kernel only used during training?

Question 2. What percentage of compute is spent on S4, compared to MLP/LayerNorm/other layers, in generation mode?

Apologies for any misunderstandings

standalone S4 module usage

Hi! I really want to try S4D in my research, but first I want to make a little proof of concept for myself, without going into too much theory or detail.

Do I understand correctly that the standalone models can be used out of the box? What bothers me: in your publications, you stress the importance of model initialization and of tuning optimizer parameters. Is proper initialization taken into account in standalone S4D? Is there an example of how to properly set up the optimizer for standalone S4D?

Thanks for your research!
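On the optimizer question: the repository's example.py appears to build optimizer groups from hyperparameters that the standalone layers attach to some of their parameters. The sketch below assumes these are stored in an `_optim` dict on each parameter (check the standalone s4.py/s4d.py for the actual attribute), and the default lr and weight decay are illustrative only. Parameters carrying `_optim` get their own group with those overrides; all others use the global settings.

```python
import torch
import torch.nn as nn


def setup_optimizer(model, lr=1e-2, weight_decay=0.01):
    """AdamW with one extra param group per distinct `_optim` dict found on parameters."""
    params = [p for p in model.parameters() if p.requires_grad]
    plain = [p for p in params if not hasattr(p, "_optim")]
    special = [p for p in params if hasattr(p, "_optim")]
    groups = [{"params": plain}] if plain else []
    # dict.fromkeys keeps first-seen order while de-duplicating the hyperparameter sets.
    for hp in dict.fromkeys(tuple(sorted(p._optim.items())) for p in special):
        group = [p for p in special if tuple(sorted(p._optim.items())) == hp]
        groups.append({"params": group, **dict(hp)})
    return torch.optim.AdamW(groups, lr=lr, weight_decay=weight_decay)
```

For example, tagging a parameter with `p._optim = {"lr": 1e-3, "weight_decay": 0.0}` gives it a smaller learning rate and no weight decay.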

S4 for Seq2seq tasks?

Hello,
You have shown S4's strong performance and efficiency on classification and unconditional generation (WikiText-103) tasks. I am wondering whether S4 can be applied to conditional generation tasks such as summarization and machine translation. A simple idea is to reformulate these tasks as language modeling, but I am not sure whether generation quality would suffer.
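The "reformulate as language modeling" idea usually means concatenating source and target with a separator and computing the loss only on target positions. A generic sketch (hypothetical helper; token ids are illustrative), which says nothing about the resulting generation quality:

```python
def to_lm_example(src_ids, tgt_ids, sep_id, eos_id):
    """Pack a (source, target) pair into one sequence for a causal model.

    Returns (input_ids, labels, loss_mask): labels are the inputs shifted left by one,
    and loss_mask is 1 only where the label is a target token or EOS, so the model
    learns to generate the target conditioned on the source.
    """
    seq = list(src_ids) + [sep_id] + list(tgt_ids) + [eos_id]
    input_ids, labels = seq[:-1], seq[1:]
    n_masked = len(src_ids)  # labels that are still source tokens or the separator
    loss_mask = [0] * n_masked + [1] * (len(labels) - n_masked)
    return input_ids, labels, loss_mask
```

For example, source [10, 11] and target [20, 21] with separator 1 and EOS 2 give labels [11, 1, 20, 21, 2] and mask [0, 0, 1, 1, 1].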

ValueError when running on Pathfinder

Hi, I am getting the following error when trying to train S4 on the pathfinder dataset. Any help would be greatly appreciated.

Traceback (most recent call last):
  File "/data/al451/state-spaces/train.py", line 553, in main
    train(config)
  File "/data/al451/state-spaces/train.py", line 498, in train
    trainer.fit(model)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 768, in fit
    self._call_and_handle_interrupt(
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 721, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 809, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1172, in _run
    self._call_setup_hook() # allow user to setup lightning_module in accelerator environment
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1492, in _call_setup_hook
    self._call_lightning_module_hook("setup", stage=fn)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1593, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/data/al451/state-spaces/train.py", line 56, in setup
    self.dataset.setup()
  File "/data/al451/state-spaces/src/dataloaders/datasets.py", line 1234, in setup
    dataset = PathFinderDataset(self.data_dir, transform=self.default_transforms())
  File "/data/al451/state-spaces/src/dataloaders/datasets.py", line 1130, in __init__
    path_list = sorted(
  File "/data/al451/state-spaces/src/dataloaders/datasets.py", line 1132, in <lambda>
    key=lambda path: int(path.stem),
ValueError: invalid literal for int() with base 10: '._142'

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
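The offending name '._142' looks like a macOS AppleDouble metadata file, which macOS tools can add when an archive is created or copied on a Mac; such files break the int(path.stem) sort key. Deleting them from the data directory is one fix; another is to skip them when listing files, sketched here with a hypothetical helper (not the repo's code):

```python
from pathlib import Path


def list_pathfinder_images(folder):
    """List files sorted by integer stem, skipping macOS AppleDouble files ('._*')."""
    paths = [p for p in Path(folder).iterdir() if p.is_file() and not p.name.startswith("._")]
    return sorted(paths, key=lambda p: int(p.stem))
```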

Time-series experiments

The experiments using datasets like etth, ettm, and ecl don't run (despite the definitions in /config). I think it's because the datasets cannot be downloaded automatically like the LRA experiments.

Would it be possible to add this, or to explain how these experiments can be run?

A has become time-invariant compared to HiPPO

Hi,

Thanks for your excellent work!
In the HiPPO paper, the transition matrix A of the SSM is time-variant, but it has become time-invariant in the S4 model. I didn't find any theory or experiments that discuss why it still works.
Can you kindly share your thoughts?

Thanks again,
Ziwei
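For reference, the distinction the question points at, written in HiPPO's sign convention with A and B the HiPPO-LegS matrices (a restatement of the two systems, not an answer to why the second works):

```latex
% HiPPO-LegS: time-varying, because the 1/t factor rescales the dynamics over time
\frac{d}{dt} c(t) = -\frac{1}{t} A\, c(t) + \frac{1}{t} B\, f(t)

% S4: the same matrices without the 1/t factor, i.e. a linear time-invariant system
% (the S4 paper absorbs the minus sign into its A and writes x' = Ax + Bu)
x'(t) = -A\, x(t) + B\, u(t), \qquad y(t) = C\, x(t)
```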

PyTorch only cauchy kernel for easier test

Hello,

I loved Structured State Spaces, and obtained fantastic performance compared to LSTM/SRU/Transformers.

I want to introduce S4 to some researchers and students, and the self-contained S4 layer is super great! However, it requires the Cauchy kernel.

There are two versions of the Cauchy kernel: a CUDA version and a pykeops version.

However, extensions/cauchy requires CUDA, pykeops does not support Windows, and the target audience uses a very diverse range of environments.

Would it be possible for the HazyResearch team to implement a self-contained S4 layer that includes a simpler, pure-PyTorch Cauchy kernel?
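The training code mentions a slow fallback (cauchy_conj_slow) whose exact form is not shown here, so the following is an independent sketch of the operation such a kernel computes, using only PyTorch broadcasting. It materializes an (..., N, L) intermediate, so memory grows as O(N·L) per batch element; avoiding that is what the CUDA and pykeops kernels are for.

```python
import torch


def cauchy_naive(v, z, w):
    """Pure-PyTorch Cauchy matrix-vector product: out[..., l] = sum_n v[..., n] / (z[l] - w[..., n]).

    v, w: complex tensors of shape (..., N); z: complex tensor of shape (L,).
    """
    return (v.unsqueeze(-1) / (z - w.unsqueeze(-1))).sum(dim=-2)
```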

Error when running example.py

Dear all, many thanks for this extremely interesting code (and maths)!

Running
example.py --grayscale
with Python 3.8 gives the following error on my system:
"RuntimeError: view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original."

Maybe some requirements need specific versions for the example to run?
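The error message itself suggests a workaround: in recent PyTorch versions, `x.conj()` can return a lazily conjugated view, which `torch.view_as_real` rejects until `resolve_conj()` materializes it. A sketch of a drop-in wrapper for whichever call site raises (hypothetical name); pinning the PyTorch version the example was written against is the other option.

```python
import torch


def view_as_real_safe(x):
    """torch.view_as_real that first materializes a lazily conjugated tensor.

    resolve_conj() returns x itself if no conjugation is pending, otherwise a copy,
    so the result may not alias the original tensor.
    """
    return torch.view_as_real(x.resolve_conj())
```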

load_dataset process getting Killed

dataset = load_dataset(
    "csv",
    data_files={
        "train": str(self.data_dir / "basic_train.tsv"),
        "val": str(self.data_dir / "basic_val.tsv"),
        "test": str(self.data_dir / "basic_test.tsv"),
    },
    delimiter="\t",
    keep_in_memory=True,
)

This piece of code in src/dataloaders/datasets.py fails, and the data-loading process is aborted.

Error Running Basic Test Script with v2 Tag (Works With v1)

Hi there,

I recently tried upgrading my S4 setup / environment to be on the v2 tag but ran into the following issue when running the basic test script:

(base) ray@test-python:~/state-spaces$ python -m train wandb=null pipeline=mnist model=s4
CONFIG
├── train
│   └── seed: 0
│       interval: epoch
│       monitor: val/accuracy
│       mode: max
│       ema: 0.0
│       test: false
│       debug: false
│       ignore_warnings: false
│       state:
│         mode: null
│         chunk_len: null
│         overlap_len: null
│         n_context: 0
│         n_context_eval: 0
│       sweep: null
│       group: null
│       benchmark_step: false
│       benchmark_step_k: 1
│       benchmark_step_T: 1
│       checkpoint_path: null
│       visualizer: filters
│       disable_dataset: false
│
├── wandb
│   └── None
├── trainer
│   └── gpus: 1
│       accumulate_grad_batches: 1
│       max_epochs: 200
│       gradient_clip_val: 0.0
│       log_every_n_steps: 10
│       limit_train_batches: 1.0
│       limit_val_batches: 1.0
│       weights_summary: top
│       progress_bar_refresh_rate: 1
│       track_grad_norm: -1
│       resume_from_checkpoint: null                                                                                                                                                                   
│                                                                                                                                                                                                      
├── loader
│   └── batch_size: 50                                                                                                                                                                                 
│       num_workers: 4                                                                                                                                                                                 
│       pin_memory: true                                                                                                                                                                               
│       drop_last: true                                                                                                                                                                                
│       train_resolution: 1                                                                                                                                                                            
│       eval_resolutions:                                                                                                                                                                              
│       - 1                                                                                                                                                                                            
│                                                                                                                                                                                                      
├── dataset
│   └── _name_: mnist                                                                                                                                                                                  
│       permute: true                                                                                                                                                                                  
│       val_split: 0.1                                                                                                                                                                                 
│       seed: 42                                                                                                                                                                                       
│                                                                                                                                                                                                      
├── task
│   └── _name_: base                                                                                                                                                                                   
│       loss: cross_entropy                                                                                                                                                                            
│       metrics:                                                                                                                                                                                       
│       - accuracy                                                                                                                                                                                     
│       torchmetrics: null                                                                                                                                                                             
│                                                                                                                                                                                                      
├── optimizer
│   └── _name_: adamw                                                                                                                                                                                  
│       lr: 0.001                                                                                                                                                                                      
│       weight_decay: 0.0                                                                                                                                                                              
│                                                                                                                                                                                                      
├── scheduler
│   └── _name_: plateau                                                                                                                                                                                
│       mode: max                                                                                                                                                                                      
│       factor: 0.2                                                                                                                                                                                    
│       patience: 20                                                                                                                                                                                   
│       min_lr: 0.0                                                                                                                                                                                    
│                                                                                                                                                                                                      
├── encoder
│   └── linear                                                                                                                                                                                         
├── decoder
│   └── _name_: sequence                                                                                                                                                                               
│       mode: pool                                                                                                                                                                                     
│                                                                                                                                                                                                      
├── model
│   └── layer:                                                                                                                                                                                         
│         _name_: s4                                                                                                                                                                                   
│         d_state: 64                                                                                                                                                                                  
│         channels: 1                                                                                                                                                                                  
│         bidirectional: false                                                                                                                                                                         
│         activation: gelu                                                                                                                                                                             
│         postact: null                                                                                                                                                                                
│         hyper_act: null                                                                                                                                                                              
│         dropout: 0.0                                                                                                                                                                                 
│         measure: legs                                                                                                                                                                                
│         rank: 1                                                                                                                                                                                      
│         dt_min: 0.001                                                                                                                                                                                
│         dt_max: 0.1                                                                                                                                                                                  
│         trainable:                                                                                                                                                                                   
│           dt: true                                                                                                                                                                                   
│           A: true                                                                                                                                                                                    
│           P: true                                                                                                                                                                                    
│           B: true                                                                                                                                                                                    
│         lr: 0.001                                                                                                                                                                                    
│         length_correction: true                                                                                                                                                                      
│         tie_state: true                                                                                                                                                                              
│         hurwitz: true                                                                                                                                                                                
│         resample: false                                                                                                                                                                              
│         deterministic: false                                                                                                                                                                         
│         l_max: 784                                                                                                                                                                                   
│         verbose: false                                                                                                                                                                               
│       _name_: model                                                                                                                                                                                  
│       prenorm: false                                                                                                                                                                                 
│       transposed: true                                                                                                                                                                               
│       n_layers: 4                                                                                                                                                                                    
│       d_model: 256                                                                                                                                                                                   
│       residual: R                                                                                                                                                                                    
│       pool:                                                                                                                                                                                          
│         _name_: sample                                                                                                                                                                               
│         pool: 1                                                                                                                                                                                      
│         expand: 1                                                                                                                                                                                    
│       norm: layer                                                                                                                                                                                    
│       dropout: 0.0                                                                                                                                                                                   
│                                                                                                                                                                                                      
└── callbacks
    └── learning_rate_monitor:                                                                                                                                                                         
          logging_interval: epoch                                                                                                                                                                      
        timer:                                                                                                                                                                                         
          step: true                                                                                                                                                                                   
          inter_step: false                                                                                                                                                                            
          epoch: true                                                                                                                                                                                  
          val: true                                                                                                                                                                                    
        params:                                                                                                                                                                                        
          total: true                                                                                                                                                                                  
          trainable: true                                                                                                                                                                              
          fixed: true                                                                                                                                                                                  
        model_checkpoint:                                                                                                                                                                              
          monitor: val/accuracy                                                                                                                                                                        
          mode: max                                                                                                                                                                                    
          save_top_k: 1                                                                                                                                                                                
          save_last: true                                                                                                                                                                              
          dirpath: checkpoints/                                                                                                                                                                        
          filename: val/accuracy                                                                                                                                                                       
          auto_insert_metric_name: false                                                                                                                                                               
          verbose: true                                                                                                                                                                                
                                                                                                                                                                                                       
Global seed set to 0
[2022-05-25 13:40:50,814][__main__][INFO] - Instantiating callback <src.callbacks.timer.Timer>
[2022-05-25 13:40:50,815][__main__][INFO] - Instantiating callback <src.callbacks.params.ParamsLog>
[2022-05-25 13:40:50,816][__main__][INFO] - Instantiating callback <pytorch_lightning.callbacks.ModelCheckpoint>
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
[2022-05-25 13:40:50,848][torch.distributed.nn.jit.instantiator][INFO] - Created a temporary directory at /tmp/tmpm51hqe7x
[2022-05-25 13:40:50,849][torch.distributed.nn.jit.instantiator][INFO] - Writing /tmp/tmpm51hqe7x/_remote_module_non_sriptable.py
Error executing job with overrides: ['wandb=null', 'pipeline=mnist', 'model=s4']
Traceback (most recent call last):
  File "/home/ray/state-spaces/train.py", line 553, in main
    train(config)
  File "/home/ray/state-spaces/train.py", line 498, in train
    trainer.fit(model)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 768, in fit
    self._call_and_handle_interrupt(
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 721, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 809, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1172, in _run
    self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1492, in _call_setup_hook
    self._call_lightning_module_hook("setup", stage=fn)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1593, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/ray/state-spaces/train.py", line 74, in setup
    self.model = utils.instantiate(registry.model, self.hparams.model)
  File "/home/ray/state-spaces/src/utils/config.py", line 99, in instantiate
    return obj()
  File "/home/ray/state-spaces/src/models/sequence/model.py", line 69, in __init__
    block = SequenceResidualBlock(d, l+1, prenorm=prenorm, dropout=dropout, layer=layer, residual=residual, norm=norm, pool=pool)
  File "/home/ray/state-spaces/src/models/sequence/block.py", line 36, in __init__
    self.layer = utils.instantiate(registry.layer, layer, d_input)
  File "/home/ray/state-spaces/src/utils/config.py", line 99, in instantiate
    return obj()
  File "/home/ray/state-spaces/src/models/sequence/ss/s4.py", line 86, in __init__
    self.kernel = HippoSSKernel(self.h, N=self.n, L=l_max, channels=channels, verbose=verbose, **kernel_args)
  File "/home/ray/state-spaces/src/models/sequence/ss/kernel.py", line 712, in __init__
    self.kernel = SSKernelNPLR(
  File "/home/ray/state-spaces/src/models/sequence/ss/kernel.py", line 217, in __init__
    self.C = nn.Parameter(_c2r(_resolve_conj(C)))
RuntimeError: view_as_real doesn't work on unresolved conjugated tensors.  To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.

Is this something you've seen before? I'd be happy to provide more details about my package versions, system architecture, etc., if you let me know what would help get to the bottom of this bug.

Best,
Matthew
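The error message itself names the fix. On PyTorch 1.10 and later, `Tensor.conj()` returns a lazy view that only sets a conjugate bit, and `torch.view_as_real` rejects such views. Below is a minimal sketch of the generic pattern, which materializes the conjugate with `resolve_conj()` before viewing it as real. It does not use the repository's own `_resolve_conj` / `_c2r` helpers, whose definitions are not shown here, and the helper name `conj_to_real` is hypothetical. The traceback shows that `C` already goes through `_resolve_conj`, so the problem may lie in how that helper behaves on the installed PyTorch version. That is an assumption worth checking, not a confirmed diagnosis.

```python
import torch


def conj_to_real(x: torch.Tensor) -> torch.Tensor:
    """Return conj(x) as a real tensor with a trailing (real, imag) dimension of size 2.

    Hypothetical helper. x.conj() only sets a lazy conjugate bit (PyTorch >= 1.10),
    and torch.view_as_real refuses such tensors, so resolve_conj() first copies
    the conjugated values into a new tensor that can be viewed as real.
    """
    return torch.view_as_real(x.conj().resolve_conj())


C = torch.randn(4, 3, dtype=torch.cfloat)
print(C.conj().is_conj())  # True: the conjugate is still lazy
C_real = conj_to_real(C)
print(C_real.shape)  # torch.Size([4, 3, 2])
```

As the error message warns, the resolved tensor does not alias the original. When initializing a parameter, that is usually what you want anyway, since the parameter should own its own storage.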

n_epoch_double

Hello, thank you for the great library. I was wondering about the n_epoch_double flag and the best way to use it. If you use that flag, won't it just cause a CUDA OOM error when it doubles? Or is the intended usage to start with a very small batch size so that there is room to double without going OOM? When running the wt103 model with L_max above 4096 and a batch size of 1 on a 4xV100 32GB machine, I ran out of memory when it reached the evaluation stage of the first epoch. I have a language modeling use case that requires very long sequence lengths (8192+) and want to try S4 on it, because the results look good at shorter sequence lengths. Any help would be appreciated. Thank you!

decoding LM vocab

Hello, I trained a model on a setup similar to the wt103 task and modified the Sashimi generation script to generate text like a causal LM (CLM). Conditioned on a text string, it generates the next N words one at a time in the same loop the Sashimi generation script uses. I believe it works, but I don't know which word in the vocabulary each integer output corresponds to. Is there a lookup table or similar structure that stores the vocabulary somewhere easily accessible? Sorry, I can't find any obvious place where it would reside. Thank you for your help.
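Decoding model outputs only needs the index-to-token list of the vocabulary that produced the training ids. Transformer-XL-style vocabulary objects typically keep such a list (often named `idx2sym`, with a `sym2idx` dictionary for the reverse direction). Where the wt103 dataloader in this repository exposes its vocabulary is an assumption you should verify. The minimal sketch below builds a toy vocabulary with hypothetical helper names to show the round trip. Real WikiText-103 vocabularies are built differently (for example, ordered by token frequency), so always decode ids with the same vocabulary object that was used for training.

```python
from typing import Dict, List, Sequence, Tuple


def build_vocab(tokens: Sequence[str]) -> Tuple[List[str], Dict[str, int]]:
    """Hypothetical toy vocabulary: ids are assigned in first-seen order."""
    idx2sym: List[str] = []
    sym2idx: Dict[str, int] = {}
    for tok in tokens:
        if tok not in sym2idx:
            sym2idx[tok] = len(idx2sym)
            idx2sym.append(tok)
    return idx2sym, sym2idx


def encode(text: str, sym2idx: Dict[str, int]) -> List[int]:
    """Map whitespace-separated words to integer ids."""
    return [sym2idx[w] for w in text.split()]


def decode(ids: Sequence[int], idx2sym: Sequence[str]) -> str:
    """Map integer ids (e.g., sampled model outputs) back to words."""
    return " ".join(idx2sym[i] for i in ids)


idx2sym, sym2idx = build_vocab("the cat sat on the mat".split())
ids = encode("the mat sat on the cat", sym2idx)
print(ids)                   # [0, 4, 2, 3, 0, 1]
print(decode(ids, idx2sym))  # the mat sat on the cat
```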

how to save and load checkpoints

Hi,
I'm running experiments with default settings using this command:

python -m train wandb=null experiment=sashimi-sc09

I'd like to get the checkpoint of the best model from training and then load it to generate .wav files. However, the only checkpoint generated is last.ckpt under the directory outputs/yyyy-mm-dd/xx-xx-xx/checkpoints, and I have no idea how to load it, since the checkpoint loading in sashimi/generation.py expects a .pt file.
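PyTorch Lightning `.ckpt` files are dictionaries whose `"state_dict"` entry holds the LightningModule's weights. Each key is prefixed with the attribute name of the submodule it belongs to. Below is a minimal sketch of converting such a file into a plain `.pt` state dict. It assumes the backbone weights sit under a `model.` prefix. The prefix, the helper names, and whether `sashimi/generation.py` expects exactly this layout are all assumptions, so print `ckpt["state_dict"].keys()` first to confirm.

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = "model.") -> "OrderedDict[str, Any]":
    """Keep only keys that start with `prefix`, with the prefix removed."""
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


def ckpt_to_pt(ckpt_path: str, pt_path: str, prefix: str = "model.") -> None:
    """Hypothetical converter: Lightning .ckpt -> plain PyTorch state dict (.pt)."""
    import torch  # imported here so strip_prefix can be used without torch installed

    # On PyTorch >= 2.6 you may need torch.load(..., weights_only=False) for Lightning checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Lightning stores the LightningModule weights under the "state_dict" key.
    torch.save(strip_prefix(ckpt["state_dict"], prefix), pt_path)


# Example usage (paths are placeholders):
# ckpt_to_pt("outputs/yyyy-mm-dd/xx-xx-xx/checkpoints/last.ckpt", "sashimi_sc09.pt")
```

If the keys show other prefixes (for example, separate encoder or decoder modules), those weights need the same treatment.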

Test Accuracy

Hello,

Sorry for another silly question. While training on s4-lra-cifar (on an A100), I get the following final log:

Epoch 94, global step 85499: val/accuracy was not in top 1
Epoch 96:  80% 966/1200 [01:16<00:18, 12.63it/s, loss=0.144, v_num=xyj4, val/accuracy=0.866, val/loss=0.527, test/accuracy=0.865, test/loss=0.533, train/accuracy=0.946, train/loss=0.151]
Epoch 95, global step 86399: val/accuracy reached 0.86580 (best 0.86580), saving model to "blah/checkpoints/val/accuracy.ckpt" as top 1
Epoch 97:  78% 938/1200 [01:15<00:21, 12.38it/s, loss=0.142, v_num=xyj4, val/accuracy=0.864, val/loss=0.536, test/accuracy=0.864, test/loss=0.539, train/accuracy=0.945, train/loss=0.156]
Epoch 96, global step 87299: val/accuracy was not in top 1
Epoch 98:  75% 900/1200 [01:14<00:24, 12.06it/s, loss=0.143, v_num=xyj4, val/accuracy=0.866, val/loss=0.529, test/accuracy=0.865, test/loss=0.535, train/accuracy=0.945, train/loss=0.153]
Epoch 97, global step 88199: val/accuracy was not in top 1
Epoch 99:  79% 945/1200 [01:15<00:20, 12.45it/s, loss=0.137, v_num=xyj4, val/accuracy=0.863, val/loss=0.534, test/accuracy=0.864, test/loss=0.539, train/accuracy=0.946, train/loss=0.151]
Epoch 98, global step 89099: val/accuracy was not in top 1

Epoch 99, global step 89999: val/accuracy was not in top 1
Saving latest checkpoint...
  1. Does this mean that the test accuracy measured at the checkpoint with the best validation accuracy is 86.5?
  2. If not, what command should I use to measure the test accuracy you report in the paper? (I am assuming that the accuracy reported in the paper is the test accuracy at the checkpoint with the best validation accuracy.)
  3. You report 87.26 on lra-cifar; is a gap like this normal? I hope I'm not misinterpreting the metrics.

I would be grateful if you could help with these questions and apologize in advance if they seem too basic.

Thank you again for sharing your wonderful repo,
Ankit
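Question 1 comes down to reading the test metric from the epoch at which the monitored validation metric peaked. With `save_top_k: 1`, that is the epoch `ModelCheckpoint` keeps. The metrics on a progress-bar line appear to come from the most recent completed validation run. For example, the `Epoch 96` line shows `val/accuracy=0.866`, matching the 0.86580 reported for epoch 95. Below is a minimal sketch of that selection over per-epoch metrics. The function name is hypothetical, and the `history` values are made up for illustration rather than taken from the log above.

```python
from typing import Dict, List, Tuple


def metric_at_best_val(
    history: List[Dict[str, float]],
    val_key: str = "val/accuracy",
    test_key: str = "test/accuracy",
    mode: str = "max",
) -> Tuple[int, float]:
    """Return (epoch, test metric) at the epoch with the best validation metric.

    For ties, max()/min() return the first such epoch.
    """
    pick = max if mode == "max" else min
    best = pick(history, key=lambda row: row[val_key])
    return int(best["epoch"]), best[test_key]


# Hypothetical per-epoch metrics, for illustration only.
history = [
    {"epoch": 0, "val/accuracy": 0.80, "test/accuracy": 0.79},
    {"epoch": 1, "val/accuracy": 0.85, "test/accuracy": 0.84},
    {"epoch": 2, "val/accuracy": 0.83, "test/accuracy": 0.85},
]
print(metric_at_best_val(history))  # (1, 0.84)
```

The best test accuracy (epoch 2 here) is not the number this selection reports. Reporting the test score at the best-validation epoch avoids choosing the epoch based on test data.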

Multi GPU training

Hi,

Is there any way to train S4 on multiple GPUs? I have 2 GPUs, but only one of them is being used.

Thanks

Inconsistent results of forward (training) and step (inference)

Hi, I ran a simple test comparing forward and step (mode="dense") on a single unidirectional S4 layer. On a random input sequence, the absolute error between the two outputs is around 1e-2 and the squared error is around 1e-4. I suspect these results indicate a problem. My check follows test_step() in //src/models/sequence/ss/kernel.py. Do you have an example that clearly compares the two modes? Thanks :)
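For a linear time-invariant SSM, the convolutional (forward) view and the recurrent (step) view compute the same map. With state update x[k] = A x[k-1] + B u[k] and readout y[k] = C x[k], the output is the causal convolution of u with the kernel K[j] = C A^j B. The toy NumPy sketch below checks this on a random, stable, real-valued SSM in float64, which gives a baseline for what round-off-level agreement looks like. It is not S4's NPLR kernel or its `step` implementation, and all names are hypothetical. S4 also discretizes a continuous-time system, so any comparison against the real layer must use the same discretized matrices in both modes.

```python
import numpy as np


def ssm_kernel(A: np.ndarray, B: np.ndarray, C: np.ndarray, L: int) -> np.ndarray:
    """Convolution kernel K[j] = C A^j B for j = 0..L-1 (single input, single output)."""
    K = np.empty(L)
    x = B.copy()
    for j in range(L):
        K[j] = C @ x  # C A^j B
        x = A @ x
    return K


def conv_mode(K: np.ndarray, u: np.ndarray) -> np.ndarray:
    """Causal convolution y[t] = sum_{j<=t} K[j] u[t-j] (the 'forward' view)."""
    return np.convolve(u, K)[: len(u)]


def recurrent_mode(A: np.ndarray, B: np.ndarray, C: np.ndarray, u: np.ndarray) -> np.ndarray:
    """Recurrence x[k] = A x[k-1] + B u[k], y[k] = C x[k] (the 'step' view)."""
    x = np.zeros(A.shape[0])
    y = np.empty(len(u))
    for k, u_k in enumerate(u):
        x = A @ x + B * u_k
        y[k] = C @ x
    return y


rng = np.random.default_rng(0)
N, L = 8, 256
A = rng.standard_normal((N, N))
A *= 0.9 / np.max(np.abs(np.linalg.eigvals(A)))  # spectral radius 0.9 -> stable
B = rng.standard_normal(N)
C = rng.standard_normal(N)
u = rng.standard_normal(L)

y_conv = conv_mode(ssm_kernel(A, B, C, L), u)
y_step = recurrent_mode(A, B, C, u)
max_err = np.max(np.abs(y_conv - y_step))
print(f"max abs error: {max_err:.1e}")  # round-off level in float64
```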

request to upload logs for experiments

Hi, can you please upload the logs of the experiments that were reported in the paper?

I tried to reproduce the WikiText-103 experiment but had to change some configurations due to hardware constraints. Even though the changes were minor, the results were not what I expected. I think the logs from the original experiments would help me reproduce the results more easily.

Thank you so much!

minGPT-like training

Hi! Very impressive results!

I'd like to try applying the S4 model to a toy NLP example such as text generation (replicating the examples in minGPT by @karpathy). I'm not very familiar with state-space models, so I have a few questions:

  1. As far as I understand, such a model doesn't need positional encodings/embeddings. Is that right?
  2. How do I properly train such a model in causal mode, i.e., so that it doesn't look into the future? Is there an equivalent of masking in Transformers, or is this the default behavior out of the box (as in a vanilla RNN)?

Thanks!
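On question 2: as far as we understand the S4 layer, a unidirectional layer (`bidirectional: false` in the config dump earlier on this page) applies its SSM as a convolution of the input with a length-L kernel. That convolution is computed with FFTs over a zero-padded length of 2L and then truncated back to L, so output t depends only on inputs 0..t and no attention-style mask is needed. The NumPy sketch below uses a random kernel as a stand-in for the SSM kernel, and its names are hypothetical. It shows why the padding matters and checks causality directly by perturbing future inputs.

```python
import numpy as np


def fft_causal_conv(u: np.ndarray, K: np.ndarray) -> np.ndarray:
    """Causal convolution y[t] = sum_{j<=t} K[j] u[t-j] via FFT.

    Zero-padding to n = 2L makes the FFT's circular convolution equal to the
    ordinary linear convolution (length 2L - 1), so nothing from the end of the
    sequence wraps around into early outputs. Keeping the first L outputs
    gives the causal result.
    """
    L = len(u)
    n = 2 * L
    y = np.fft.irfft(np.fft.rfft(u, n=n) * np.fft.rfft(K, n=n), n=n)
    return y[:L]


rng = np.random.default_rng(0)
L = 64
u = rng.standard_normal(L)
K = rng.standard_normal(L)  # stand-in for an SSM convolution kernel

y = fft_causal_conv(u, K)

# Perturb only the "future" inputs from position t0 onward.
t0 = 40
u_future = u.copy()
u_future[t0:] += 10.0
y_future = fft_causal_conv(u_future, K)

print(np.allclose(y, np.convolve(u, K)[:L]))  # True: matches direct causal convolution
print(np.allclose(y[:t0], y_future[:t0]))    # True: outputs before t0 are unchanged
```

For next-token training, the targets are then the input sequence shifted by one position, as in minGPT.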

WikiText-103

Hi, I'm interested in recreating your WikiText-103 LM experiment. Is it possible you could make that easier for me? Thanks! CJ

Error when training on youtubemix

When running CUDA_VISIBLE_DEVICES=1,2,3,4,7 python -m train wandb=null experiment=sashimi-youtubemix dataset=youtubemix, I get the following error:

Traceback (most recent call last):
  File "/data/al451/state-spaces/train.py", line 553, in main
    train(config)
  File "/data/al451/state-spaces/train.py", line 498, in train
    trainer.fit(model)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 740, in fit
    self._call_and_handle_interrupt(
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1199, in _run
    self._dispatch()
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1279, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1289, in run_stage
    return self._run_train()
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1311, in _run_train
    self._run_sanity_check(self.lightning_module)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1371, in _run_sanity_check
    self._evaluation_loop._reload_evaluation_dataloaders()
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 170, in _reload_evaluation_dataloaders
    self.trainer.reset_val_dataloader()
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py", line 551, in reset_val_dataloader
    self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py", line 508, in _reset_eval_dataloader
    if has_len_all_ranks(dataloader, self.training_type_plugin, module)
  File "/home/al451/.local/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py", line 118, in has_len_all_ranks
    raise MisconfigurationException(
pytorch_lightning.utilities.exceptions.MisconfigurationException: Total length of Dataloader across ranks is zero. Please make sure that it returns at least 1 batch.

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
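The exception means that the validation DataLoader reported zero batches on every rank. The hypothetical helper below reproduces the usual length arithmetic, assuming PyTorch's DistributedSampler with its defaults (each rank gets ceil(N / world_size) examples) feeding a DataLoader with the given batch_size and drop_last. It points at two things worth checking, neither confirmed for this run: whether the validation split is empty (for example, because the data files were not found where the config expects them), and whether drop_last=True is combined with fewer examples per rank than the batch size.

```python
import math


def batches_per_rank(n_examples: int, world_size: int, batch_size: int, drop_last: bool) -> int:
    """Approximate len(DataLoader) on one rank (hypothetical diagnostic helper).

    Assumes a DistributedSampler with default settings, which pads so that
    every rank receives ceil(n_examples / world_size) examples.
    """
    per_rank = math.ceil(n_examples / world_size)
    if drop_last:
        # Incomplete final batches are discarded.
        return per_rank // batch_size
    # The last, possibly smaller, batch is kept.
    return math.ceil(per_rank / batch_size)
```

For example, an empty validation split gives zero batches on every rank regardless of the other settings, which matches the message above.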

Conceptual Questions regarding S4/HiPPO

Dear authors/contributors,

First of all, thank you so much for publishing such great work. I find it really inspiring, and I think we will see this model (or its variants) deployed to solve a variety of real-world problems in the coming years.

I have gone through your most recent papers, starting from HiPPO, and I would like to ask some conceptual questions to deepen my understanding. Since I couldn't find sources other than your papers (plus a couple of your recorded talks on YouTube and the Annotated S4), I think this is an appropriate place to ask them. If you prefer another discussion platform, please let me know.

PS: These questions turned out to be a bit longer than I intended, but I don't expect you to clarify them all at once :)

  1. A matrix <-> polynomial basis: As I understand your HiPPO paper, you derive the A matrix for various measures and polynomial bases. Therefore, for a given A (and hence a given polynomial basis), we know how to reconstruct the original signal u(t) from the state/coefficients x(t). My question is: what does the model learn when we initialize A as HiPPO but train it (i.e. when A is not fixed)? In other words, how does the polynomial basis change, and how can the model still reconstruct the original signal u(t) as A varies?
  2. Learning the step size: In the Annotated S4, the step size is another parameter learned during training. (I am not sure whether you do the same, as I haven't gone through your code yet.)
    • Could you share the intuition for learning the step size, and what its potential effects are? For instance, if we use a measure that decays exponentially over time, can we say that a larger step size prioritizes more recent history, while a smaller step size gives more weight to the distant past (because its weight decays more slowly)?
    • If we work on a signal with a natural sense of time (e.g. an ECG signal), should we still make the step size trainable (in the first and all intermediate layers), given that the formulation (to my understanding) has no notion of the step size's units (e.g. seconds or days)?
  3. Irregular sampling of time series: The continuous-time view of S4 convinces me that it can naturally handle irregularly sampled time series of underlying continuous dynamics. However, I am confused by the discretization step, where we use convolution for training and recurrence for fast inference. If I have an irregularly sampled time series, how can I train S4?
    • Small comment: I think that even if the training data is regularly sampled, we can still handle irregular time series during real-time inference by mapping A_bar, B_bar, etc. back to their continuous-time equivalents via the bilinear transform. Is that true?
  4. The effect of "deep" S4 layers: Figure 2 of your paper "Efficiently Modeling Long Sequences with Structured State Spaces" visualizes the kernels of the first and last layers for the Path-X task. It shows that the first layers mostly capture local context, while the last layers capture more global context. Why is this the case if HiPPO offers continuous-time memorization? In other words, why can't the first layers memorize the distant past, and why are more stacked layers needed to aggregate context from the past? I assume it is related to the chosen measure and/or the step size itself, but I am really curious about your opinion.
    • For deep CNNs, the usual explanation is that the receptive field grows as more layers are stacked (exponentially with dilated convolutions, as in TCNs, and linearly for some other types). Is there an analogous explanation for S4?
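For question 2, the role of the step size can be seen in a toy example. The sketch below applies the bilinear discretization used in the S4 paper to a hypothetical one-state system with A = -1 (not a HiPPO matrix). With the continuous-time A held fixed, a larger dt shrinks A_bar, so the discrete kernel decays faster per step and the memory spans fewer time steps, which is consistent with the intuition in the first bullet. It is an illustration under those assumptions, not the repository's kernel computation.

```python
import numpy as np


def discretize_bilinear(A, B, dt):
    """Bilinear discretization of x'(t) = A x(t) + B u(t) with step dt:
    A_bar = (I - dt/2 A)^{-1} (I + dt/2 A),  B_bar = (I - dt/2 A)^{-1} dt B."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - dt / 2 * A)
    return inv @ (I + dt / 2 * A), inv @ (dt * B)


def ssm_kernel(A, B, C, dt, length):
    """Convolution kernel K_bar[k] = C A_bar^k B_bar for k = 0 .. length-1."""
    A_bar, B_bar = discretize_bilinear(A, B, dt)
    K, x = [], B_bar
    for _ in range(length):
        K.append((C @ x).item())
        x = A_bar @ x  # advance one discrete step
    return np.array(K)


# Hypothetical 1-state toy system with exponentially decaying memory.
A = np.array([[-1.0]])
B = np.array([[1.0]])
C = np.array([[1.0]])
K_small = ssm_kernel(A, B, C, dt=0.1, length=50)  # per-step factor 0.95/1.05, about 0.905
K_large = ssm_kernel(A, B, C, dt=1.0, length=50)  # per-step factor 0.5/1.5, about 0.333
```

Measured in continuous time (step k corresponds to time k * dt), both kernels approximate the same decay; measured in steps, the larger dt forgets much sooner.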

It is a great pleasure to learn more about your exciting work. Many thanks in advance. I would also be happy to hear about any other resources you can suggest.
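Regarding question 3, the tension can be made concrete: convolution mode needs a single dt shared by every step (so the kernel is time-invariant), while recurrent mode can in principle re-discretize at every step with its own dt_k. The sketch below shows that idea together with the inverse bilinear transform mentioned in the small comment. It assumes a small dense toy A rather than S4's structured parameterization, is one possible approach rather than what the authors necessarily recommend, and all function names are hypothetical.

```python
import numpy as np


def discretize_bilinear(A, B, dt):
    """A_bar = (I - dt/2 A)^{-1} (I + dt/2 A),  B_bar = (I - dt/2 A)^{-1} dt B."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - dt / 2 * A)
    return inv @ (I + dt / 2 * A), inv @ (dt * B)


def continuous_from_bilinear(A_bar, B_bar, dt):
    """Invert the bilinear transform:
    A = (2/dt) (A_bar - I)(A_bar + I)^{-1},  B = (1/dt) (I - dt/2 A) B_bar."""
    I = np.eye(A_bar.shape[0])
    A = (2.0 / dt) * (A_bar - I) @ np.linalg.inv(A_bar + I)
    B = (I - dt / 2 * A) @ B_bar / dt
    return A, B


def run_irregular(A, B, C, u, dts):
    """Recurrent mode with a per-step discretization:
    x_k = A_bar(dt_k) x_{k-1} + B_bar(dt_k) u_k,  y_k = C x_k."""
    x = np.zeros((A.shape[0], 1))
    ys = []
    for u_k, dt_k in zip(u, dts):
        A_bar, B_bar = discretize_bilinear(A, B, dt_k)  # re-discretize for this gap
        x = A_bar @ x + B_bar * u_k
        ys.append((C @ x).item())
    return np.array(ys)
```

With a constant dt the recurrence reproduces the convolution with K_bar[k] = C A_bar^k B_bar, so the two modes agree; with varying dt_k only the recurrent form applies.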

Exception when running the language model

When I tried to run the language model (wt103) with the command

HYDRA_FULL_ERROR=1 python3 -m train wandb=null experiment=s4-wt103

an exception was raised.

In the adaptive_softmax.py module, l_idx is greater than r_idx, so the nn.Parameter(torch.zeros(...)) call at line 116 fails. I wrapped the size in abs() to get past it: nn.Parameter(torch.zeros(abs(r_idx - l_idx))).
Then another exception arises at self.out_layers_weights (line 182): list index out of range, so I wrapped those lines in a try-except to skip them...

You can find the full error in the attached file:
error.txt
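An l_idx larger than r_idx suggests that a cutoff is larger than the vocabulary size the module received, or that the cutoffs are not increasing; patching with abs() may hide that mismatch rather than fix it. The hypothetical check below assumes the module follows the Transformer-XL convention cutoff_ends = [0] + cutoffs + [n_token], with each cluster spanning [l_idx, r_idx); whether adaptive_softmax.py builds its ranges exactly this way is an assumption to verify.

```python
def check_cutoffs(cutoffs, n_token):
    """Return the (l_idx, r_idx) range of every cluster, or raise if one is empty.

    Assumes the Transformer-XL convention cutoff_ends = [0] + cutoffs + [n_token]
    (hypothetical diagnostic helper, not the repository's API).
    """
    cutoff_ends = [0] + list(cutoffs) + [n_token]
    ranges = list(zip(cutoff_ends[:-1], cutoff_ends[1:]))
    for i, (l_idx, r_idx) in enumerate(ranges):
        if r_idx <= l_idx:
            raise ValueError(
                f"cluster {i}: r_idx={r_idx} <= l_idx={l_idx}; cutoffs {list(cutoffs)} "
                f"must be strictly increasing and below n_token={n_token}"
            )
    return ranges
```

Running this with the configured cutoffs and the vocabulary size actually produced by the data pipeline would show whether the two disagree.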

Unused parameters in training

Hi! I'm running some experiments using your code. For my use case, I'm using torch.nn.DistributedDataParallel, which automatically detects unused parameters, i.e., parameters that get no gradients.

The unused parameters are:

  • D (from the S4 module)
  • output_linear.weight and output_linear.bias (from the S4 module). These are instances of the TransposedLinear layer.
  • kernel.C (from SSKernelNPLR).

I have manually confirmed these parameters don't get gradients by running the following code after computing the loss:

for name, param in model.named_parameters():
    if param.grad is None:
        print(name)

Usually, this means the parameters are instantiated but never used. Surprisingly, in this case all of them are used in the forward method, but none through "vanilla" PyTorch ops: D, output_linear.weight and output_linear.bias are used through opt_einsum.contract, and kernel.C through your Cauchy GPU op.

Can you confirm the issue on your end? These parameters all look important for the model.
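One way to narrow this down is a self-contained version of the check above on a small module. In the hypothetical `ToyBlock` below, `D` and `W` are used only through `torch.einsum`, yet they still receive gradients because `torch.einsum` is differentiable, while a registered-but-unused `nn.Linear` is reported. If the real model still reports `None` for parameters that do appear in `forward`, it may be worth checking whether any op along that path returns a tensor detached from the autograd graph; this is a debugging suggestion under those assumptions, not a confirmed diagnosis.

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for an S4-style block: a skip term D and an
    output projection W applied through torch.einsum instead of nn.Linear."""

    def __init__(self, h: int = 4, use_extra: bool = False):
        super().__init__()
        self.D = nn.Parameter(torch.randn(h))
        self.W = nn.Parameter(torch.randn(h, h))
        self.extra = nn.Linear(h, h)  # only used when use_extra=True
        self.use_extra = use_extra

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        # u: (batch, h, length)
        y = u * self.D.unsqueeze(-1)                # skip term via D
        y = torch.einsum("oh,bhl->bol", self.W, y)  # projection via einsum
        if self.use_extra:
            y = self.extra(y.transpose(1, 2)).transpose(1, 2)
        return y


def params_without_grad(model: nn.Module, u: torch.Tensor) -> list:
    """Run one forward/backward pass and list parameters whose .grad is None."""
    model.zero_grad(set_to_none=True)
    loss = model(u).pow(2).mean()
    loss.backward()
    return [name for name, p in model.named_parameters() if p.grad is None]
```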

Replication of Diffusion Results

Hi, I'm trying to replicate your results for applying SaShiMi in a diffusion context, and have run into some questions about implementation details along the way. It'd be awesome if you could help me out with them.

  1. I have found the diffusion version of the SaShiMi model at https://github.com/HazyResearch/state-spaces/blob/diffwave/sashimi/sashimi.py. I assume that is the reference implementation. If so, which parameters did you use? Just bidirectional=True, unet=True, diffwave=True, with the rest set to the values specified in Appendix C.2.2 of the paper and to their defaults otherwise?
  2. In the original (autoregressive) model, you use mu-law quantization. Do you also use it in the diffusion implementation? And do you use an embedding encoder and sequence decoder as in the AR model? If so, how do you implement this setup, e.g. with respect to the additive noise?

Best,
Stefan
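For reference on question 2, standard 8-bit mu-law companding (as popularized by WaveNet) maps audio in [-1, 1] through F(x) = sign(x) ln(1 + mu|x|) / ln(1 + mu) with mu = 255 and then quantizes the result to 256 levels. The sketch below is a generic NumPy version; whether the SaShiMi code uses exactly this binning and rounding is an assumption to check against the repository.

```python
import numpy as np

MU = 255  # 8-bit mu-law: 256 quantization levels


def mu_law_encode(x, mu=MU):
    """Map audio samples in [-1, 1] to integer bins 0..mu."""
    x = np.clip(x, -1.0, 1.0)
    y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)  # companded, still in [-1, 1]
    return np.floor((y + 1) / 2 * mu + 0.5).astype(np.int64)  # round to nearest bin


def mu_law_decode(q, mu=MU):
    """Map integer bins back to approximate samples in [-1, 1]."""
    y = 2 * np.asarray(q, dtype=np.float64) / mu - 1
    return np.sign(y) * np.expm1(np.abs(y) * np.log1p(mu)) / mu  # ((1 + mu)^|y| - 1) / mu
```

The logarithmic companding spends more bins near zero, where most audio amplitude lies, so the round-trip error is smallest for quiet samples and largest near +/-1.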

Dropout2d and residual

Dear authors and contributors,

There is an observation that I would be happy to get your confirmation on :-)
Throughout the model hierarchy (SequenceModel, SequenceResidualBlock and S4), you use Dropout2d, which zeros along the batch dimension, i.e. drops entire samples. Without a residual link, stacking multiple layers makes the probability that a given sample survives all of them negligible. Consequently, the model does not see the inputs and will not train!
In the SequenceResidualBlock, the dropout is applied only if a residual link is present. The residual link of SequenceResidualBlock also takes care of the dropout from S4.
So my issue is two-fold:

  • When using dropout > 0, we should never set residual = None in the parameters of SequenceResidualBlock, right? Could a check be added at initialization to catch this misconfiguration?
  • The dropinp argument of SequenceModel should not be used, since there is no residual link there. All the configs I've seen set dropinp: 0.0, so why is it there at all?

Thanks and regards,
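The arithmetic behind the first point: if each of n stacked layers independently drops a whole sample with probability p and nothing carries the signal around the dropped branch, a sample survives all layers with probability (1 - p)^n; for example, p = 0.25 and n = 16 gives about 1%. The NumPy simulation below checks this and shows that a residual connection keeps every sample alive. It assumes sample-level masks, as described above for Dropout2d, and uses the identity as a stand-in for each layer, which, like a bias-free linear layer, maps a zeroed input to zero.

```python
import numpy as np


def surviving_fraction(p, n_layers, residual, n_samples=100_000, seed=0):
    """Fraction of samples whose signal is still nonzero after n_layers.

    Each layer drops a whole sample with probability p (rescaling kept ones
    by 1/(1-p), as dropout does); the layer itself is the identity.
    """
    rng = np.random.default_rng(seed)
    x = np.ones(n_samples)  # one scalar "signal" per sample
    for _ in range(n_layers):
        keep = rng.random(n_samples) >= p   # sample-level mask
        branch = x * keep / (1.0 - p)       # stand-in for dropout(f(x))
        x = x + branch if residual else branch
    return np.mean(x != 0)
```

With residual=False the result tracks (1 - p)^n, while with residual=True the skip path x + branch can never be zero, so every sample reaches the output.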

GPU Out of Memory

I was wondering which parameters I could change to run it on a GPU with limited memory. I tried reducing the number of layers to 4, which did not help. Also, the batch size seems to be set to 1 by default. I am using 4x TITAN RTX 24GB.
