ordinal-log-loss's People

Contributors

castafra, charles-glanceable

Forkers

beomi, kasakh

ordinal-log-loss's Issues

"requires_grad=True" in loss function

Hello!
I've read the articles on ordinal log loss and was very impressed by the idea of using distances between classes.
I tried to use OLL in my study and came across some code I don't understand.

import torch
import torch.nn.functional as F

def compute_loss(self, model, inputs, return_outputs=False):
    # model.module assumes the model is wrapped (e.g. DataParallel/DDP)
    num_classes = model.module.num_labels
    dist_matrix = model.module.dist_matrix
    labels = inputs["labels"]
    outputs = model(**inputs)
    logits = outputs.logits
    probas = F.softmax(logits, dim=1)
    # Pair each sample's gold label with every class id
    true_labels = [num_classes*[labels[k].item()] for k in range(len(labels))]
    label_ids = len(labels)*[[k for k in range(num_classes)]]
    # distances[j][i] = dist_matrix[gold label of sample j][i]
    distances = [[float(dist_matrix[true_labels[j][i]][label_ids[j][i]]) for i in range(num_classes)] for j in range(len(labels))]
    distances_tensor = torch.tensor(distances, device='cuda:0', requires_grad=True)
    # Penalty -log(1 - p_i) weighted by the squared distance to the gold label
    err = -torch.log(1-probas)*abs(distances_tensor)**2
    loss = torch.sum(err, dim=1).mean()
    return (loss, outputs) if return_outputs else loss
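Reading the snippet above (this is an interpretation of the code, not a formula quoted from the articles), the loss it computes for a batch of N = len(labels) samples over C = num_classes classes is, with p_{n,i} = probas[n][i], y_n = labels[n] and d(y_n, i) = dist_matrix[y_n][i]:

```latex
\mathcal{L}_{\mathrm{OLL}\text{-}\alpha}
  = \frac{1}{N} \sum_{n=1}^{N} \sum_{i=0}^{C-1}
    -\log\bigl(1 - p_{n,i}\bigr)\, \bigl|d(y_n, i)\bigr|^{\alpha}
```

Here α = 2 (from abs(distances_tensor)**2); the OLL1Trainer further down uses α = 1. Assuming dist_matrix has a zero diagonal, the gold class contributes nothing, so only probability placed on other classes is penalized, and more heavily the farther those classes are from the gold label.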

In this code, consider the following line:

distances_tensor = torch.tensor(distances,device='cuda:0', requires_grad=True)

As far as I can tell, distances_tensor is never updated; it is simply derived from the label information. So I think distances_tensor does not need "requires_grad=True". Even after I removed that argument, loss.backward() seems to work properly.

Is there any reason distances_tensor has "requires_grad=True"?

Thanks in advance.
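As a rough check of the reasoning in this question, here is a dependency-free sketch (plain Python standing in for PyTorch autograd; all function names are hypothetical) that computes one sample's α = 1 ordinal log loss and its gradient with respect to the logits. The distances enter only as constant weights, and the analytic gradient matches a finite-difference estimate, so the training signal reaches the model through probas alone. This suggests, though the repository authors have not confirmed it here, that requires_grad=True on distances_tensor is not needed for training; with it set, autograd merely also computes a gradient for that leaf tensor, which the optimizer never uses.

```python
import math


def softmax(z):
    """Numerically stable softmax of a list of floats."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def oll_loss(z, d):
    """OLL for one sample: -sum_i d_i * log(1 - p_i), where p = softmax(z)."""
    p = softmax(z)
    return -sum(di * math.log(1.0 - pi) for di, pi in zip(d, p))


def oll_grad(z, d):
    """Analytic gradient dL/dz_k = w_k - p_k * sum_i w_i, with w_i = d_i p_i / (1 - p_i).

    The distances d appear only as fixed weights; nothing is differentiated w.r.t. d.
    """
    p = softmax(z)
    w = [di * pi / (1.0 - pi) for di, pi in zip(d, p)]
    w_sum = sum(w)
    return [wk - pk * w_sum for wk, pk in zip(w, p)]


def numeric_grad(z, d, eps=1e-6):
    """Central finite-difference estimate of dL/dz, used to check oll_grad."""
    grad = []
    for k in range(len(z)):
        z_plus = list(z)
        z_minus = list(z)
        z_plus[k] += eps
        z_minus[k] -= eps
        grad.append((oll_loss(z_plus, d) - oll_loss(z_minus, d)) / (2 * eps))
    return grad


# Hypothetical sample: 3 ordinal classes, gold label 1, distances |1 - i|.
logits = [0.5, 1.5, -0.3]
dists = [1.0, 0.0, 1.0]

analytic = oll_grad(logits, dists)
numeric = numeric_grad(logits, dists)
print(analytic)
print(numeric)
```

The gold class (distance 0) receives a negative gradient, so gradient descent raises its logit, and the gradient components sum to zero, as expected for inputs to a softmax.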

Could you give an example of the parameters in the loss function?

Hi,

Thanks for the code you shared.
I would like to use the code in my project, but I ran into some errors.
Could you give an example of the values below (such as logits, probas, true_labels, dist_matrix, distances)?

Thanks in advance.

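import torch
import torch.nn.functional as F
from transformers import Trainer

class OLL1Trainer(Trainer):
    # Overrides Trainer.compute_loss with the ordinal log loss (alpha = 1)
    def compute_loss(self, model, inputs, return_outputs=False):
        # Unwrap DataParallel/DDP if present; plain models have no .module
        base_model = getattr(model, "module", model)
        num_classes = base_model.num_labels
        dist_matrix = base_model.dist_matrix
        labels = inputs["labels"]
        outputs = model(**inputs)
        logits = outputs.logits
        probas = F.softmax(logits, dim=1)
        # distances[j][i] = dist_matrix[gold label of sample j][i]
        true_labels = [num_classes*[labels[k].item()] for k in range(len(labels))]
        label_ids = len(labels)*[[k for k in range(num_classes)]]
        distances = [[float(dist_matrix[true_labels[j][i]][label_ids[j][i]]) for i in range(num_classes)] for j in range(len(labels))]
        # Create the tensor on the logits' device instead of hardcoding 'cuda:0'
        distances_tensor = torch.tensor(distances, device=logits.device, requires_grad=True)
        # Penalty -log(1 - p_i) weighted by the distance to the gold label
        err = -torch.log(1-probas)*distances_tensor
        loss = torch.sum(err, dim=1).mean()
        return (loss, outputs) if return_outputs else loss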
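To illustrate the intermediate values asked about above, here is a plain-Python walkthrough of the list comprehensions in OLL1Trainer.compute_loss on a toy batch. The distance matrix |i - j|, the labels and the logits are made-up assumptions, not values from the repository, and Python lists stand in for tensors (so labels[k] replaces labels[k].item()).

```python
import math

num_classes = 3
# Hypothetical ordinal distance matrix: d(i, j) = |i - j|.
dist_matrix = [[abs(i - j) for j in range(num_classes)] for i in range(num_classes)]
labels = [0, 2]                               # hypothetical gold labels for a batch of 2
logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]  # hypothetical model outputs

# Mirrors the comprehensions in compute_loss.
true_labels = [num_classes * [labels[k]] for k in range(len(labels))]
label_ids = len(labels) * [[k for k in range(num_classes)]]
distances = [[float(dist_matrix[true_labels[j][i]][label_ids[j][i]])
              for i in range(num_classes)] for j in range(len(labels))]

# Same result: just take the row of dist_matrix for each gold label.
row_lookup = [[float(v) for v in dist_matrix[y]] for y in labels]


def softmax(row):
    """Numerically stable softmax of one row of logits."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


probas = [softmax(row) for row in logits]
# alpha = 1, as in OLL1Trainer: err = -log(1 - p) * d, summed per sample, then averaged.
err = [[-math.log(1.0 - p) * d for p, d in zip(prow, drow)]
       for prow, drow in zip(probas, distances)]
loss = sum(sum(row) for row in err) / len(err)

print(true_labels)  # [[0, 0, 0], [2, 2, 2]]
print(label_ids)    # [[0, 1, 2], [0, 1, 2]]
print(distances)    # [[0.0, 1.0, 2.0], [2.0, 1.0, 0.0]]
print(probas)
print(loss)
```

The match with row_lookup suggests the nested comprehension only selects, for each sample, the row of dist_matrix that belongs to its gold label. If dist_matrix were a torch tensor, indexing it as dist_matrix[labels] should perform the same lookup in one step.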
