antiberty-pytorch's Issues

How to load AntiBERTy in Jupyter?

Hello everyone,
I am quite new to machine learning and not very familiar with Python, so my question may be naive.
Is the pre-trained AntiBERTy (or AntiBERTa) model available on Hugging Face? How can I load the model in a Jupyter notebook?
In the past months I have used the pre-trained ProtBert model in Jupyter with a line like
`model = BertForSequenceClassification.from_pretrained('Rostlab/prot_bert_bfd', ...)` and then trained a simple sequence-classification model.
I was looking for AntiBERTy because I have antibody sequences, but I could not find it in the Models section of Hugging Face; I only found a dataset under the corresponding name. What should I do?
Thanks for any suggestion/feedback.
Best,
VB
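
For what it's worth, AntiBERTy does not appear to be published on the Hugging Face Hub; it is distributed as a pip package that downloads the pre-trained weights. A minimal sketch of using it in a notebook (the sequence below is a placeholder, not real data):

```python
# In a Jupyter cell, after running `pip install antiberty`:
from antiberty import AntiBERTyRunner

antiberty = AntiBERTyRunner()

# Placeholder heavy-chain fragment; substitute your own antibody sequences.
sequences = ["EVQLVESGGGLVQPGGSLRLSCAAS"]

# embed() returns one embedding tensor per input sequence.
embeddings = antiberty.embed(sequences)
print(embeddings[0].shape)
```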

Issues in `AntiBERTyRunner.py`

Hey there,

I attempted to re-run the new v3.0.x of IGFold with OpenMM on my system last night. After updating and upgrading the packages, I ran the notebook and found the following error thrown from `AntiBERTyRunner.py`:

File "/xxx/yyy/anaconda3/envs/igfold2/lib/python3.10/site-packages/antiberty/AntiBERTyRunner.py", line 88, in embed
    embeddings[i] = embeddings[i][:, a == 1]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)

To debug this, I checked which devices the `embeddings` and `attention_mask` variables live on.

Both are created on the GPU, and only `embeddings` is detached and moved to the CPU. So I made the following change:

  • Detached both to the CPU and converted both to lists.

```python
# gather embeddings
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)
embeddings = embeddings.detach().cpu().tolist()

for i, a in enumerate(attention_mask.detach().cpu().tolist()):
    embeddings[i] = embeddings[i][:, a == 1]
```

It threw the following error:

File "/home/randd/anaconda3/envs/igfold2/lib/python3.10/site-packages/antiberty/AntiBERTyRunner.py", line 88, in embed
    embeddings[i] = embeddings[i][:, a == 1]
TypeError: list indices must be integers or slices, not tuple
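
This TypeError is expected: `.tolist()` converts a tensor into nested plain Python lists, and lists only accept integer or slice indices, not the tuple/boolean-mask indexing that `[:, a == 1]` performs. A minimal reproduction:

```python
import torch

x = torch.randn(2, 3)
mask = torch.tensor([True, False, True])

print(x[:, mask].shape)  # torch.Size([2, 2]): tensors support boolean-mask indexing
x.tolist()[:, mask]      # TypeError: list indices must be integers or slices, not tuple
```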

To understand the core problem, I printed `embeddings` and `attention_mask`:

```python
# gather embeddings
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)

print(f"embeddings: {embeddings} length:{embeddings.shape}")
print(f"attention_mask: {attention_mask.detach().cpu()} length:{attention_mask.shape}")

embeddings = embeddings.detach().cpu().tolist()

for i, a in enumerate(attention_mask.detach().cpu().tolist()):
    embeddings[i] = embeddings[i][:, a == 1]
```

Details (abbreviated):

```
embeddings: tensor([[[[ 2.8550e-01, -9.1305e-01, -3.2084e-01,  ..., -2.6094e-01,
           -1.1316e-01,  6.2532e-02],
          ...
          [ 1.0556e+00,  9.7592e-01, -1.2148e+00,  ..., -4.7689e-02,
           -1.4709e-02,  2.9145e-01]]]], device='cuda:0') length:torch.Size([2, 9, 120, 512])
attention_mask: tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0]]) length:torch.Size([2, 120])
```

Then I made this change, keeping `embeddings` as a tensor on the CPU:

```python
# gather embeddings
embeddings = outputs.hidden_states
embeddings = torch.stack(embeddings, dim=1)

print(f"embeddings: {embeddings} length:{embeddings.shape}")
print(f"attention_mask: {attention_mask.detach().cpu()} length:{attention_mask.shape}")

embeddings = embeddings.detach().cpu()

for i, a in enumerate(attention_mask.detach().cpu()):
    embeddings[i] = embeddings[i][:, a == 1]
```

So, finally, with both kept as tensors, I tried the replacement again, but it threw a tensor-dimension mismatch error, as one would expect:

File "/home/randd/anaconda3/envs/igfold2/lib/python3.10/site-packages/antiberty/AntiBERTyRunner.py", line 88, in embed
    embeddings[i] = embeddings[i][:, a == 1]
RuntimeError: The expanded size of the tensor (120) must match the existing size (109) at non-singleton dimension 1.  Target sizes: [9, 120, 512].  Tensor sizes: [9, 109, 512]

This is because `embeddings` has size [2, 9, 120, 512] while `attention_mask` has size [2, 120], and the second sequence has only 109 real (unpadded) tokens.
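
The mismatch follows directly: masking drops the 11 padding positions from the second sequence, producing a [9, 109, 512] tensor, and assigning that back into the fixed [9, 120, 512] slot of a stacked tensor cannot broadcast. A toy reproduction:

```python
import torch

x = torch.zeros(2, 9, 120, 512)
trimmed = x[1][:, :109]  # shape [9, 109, 512], as if padding had been dropped
x[1] = trimmed           # RuntimeError: expanded size (120) must match existing size (109)
```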

What is the end goal of this snippet? Why does it throw an error? Please help me resolve this.
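
For context, the goal of the snippet appears to be trimming padded positions from each sequence's per-layer embeddings, so that each sequence keeps only its real tokens. Since the trimmed lengths differ between sequences, the result has to be a Python list of tensors rather than one stacked tensor, and the boolean mask must live on the same device as the tensor it indexes. A minimal device-safe sketch, assuming `embeddings` is the stacked [batch, layers, max_len, dim] tensor and `attention_mask` is [batch, max_len]:

```python
# Split the stacked tensor into a list of per-sequence tensors so each
# entry can end up with its own (unpadded) length after masking.
embeddings = list(embeddings.detach())

for i, a in enumerate(attention_mask):
    keep = (a == 1).to(embeddings[i].device)  # mask on the same device as the tensor
    embeddings[i] = embeddings[i][:, keep]    # drop padded positions for sequence i
```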
