saf's Issues

The provided code is incomplete

While running your code in test mode, the program cannot proceed because the compute_topk function in utils/metric.py is clearly incorrect:

Traceback (most recent call last):
  File "./train.py", line 240, in <module>
    main(args)
  File "./train.py", line 206, in main
    ) = test(test_loader, network, args, unique_image, epoch)
  File "/data1/jd/projects/reid/SAF/test.py", line 64, in test
    result, score = compute_topk(
  File "/data1/jd/projects/reid/SAF/utils/metric.py", line 348, in compute_topk
    result.extend(topk(score, target_gallery, target_query, k, dim=1))
  File "/data1/jd/projects/reid/SAF/utils/metric.py", line 363, in topk
    correct = pred_labels.eq(target_query.view(1, -1).expand_as(pred_labels))
RuntimeError: The expanded size of the tensor (11) must match the existing size (3074) at non-singleton dimension 1.  Target sizes: [10, 11].  Tensor sizes: [1, 3074]
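The error comes from expand_as, which can only broadcast dimensions of size 1. In topk, pred_labels has shape [maxk, number of rows in score] after the transpose, so target_query.view(1, -1) can only be expanded to it if target_query holds exactly one label per row of the score matrix. A minimal sketch reproducing the mismatch with the shapes from the traceback (all-zero labels, for illustration only):

```python
import torch

maxk, score_rows, num_queries = 10, 11, 3074

# Shape of pred_labels after the .t() in topk(): [maxk, score_rows]
pred_labels = torch.zeros(maxk, score_rows, dtype=torch.long)
# One label per query, as in target_query
target_query = torch.zeros(num_queries, dtype=torch.long)

try:
    # [1, 3074] cannot be expanded to [10, 11]: dim 1 is not a singleton
    pred_labels.eq(target_query.view(1, -1).expand_as(pred_labels))
except RuntimeError as err:
    print("mismatch:", err)

# With one score row per query label, the same expression works
pred_labels_ok = torch.zeros(maxk, num_queries, dtype=torch.long)
correct = pred_labels_ok.eq(target_query.view(1, -1).expand_as(pred_labels_ok))
print(correct.shape)  # torch.Size([10, 3074])
```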

Your original code:

def compute_topk(
    query, gallery, target_query, target_gallery, k=[1, 10], reverse=False
):
    result = []
    query0 = query[:,0] / query[:,0].norm(dim=1, keepdim=True)
    gallery0 = gallery[:,0] / gallery[:,0].norm(dim=1, keepdim=True)
    sim_cosine = torch.matmul(query0, gallery0.t())
    score = torch.zeros((sim_cosine.shape))
    for i in range(query.size(1)):
        query0 = query[:,i] / query[:,i].norm(dim=1, keepdim=True)
        gallery0 = gallery[:,i] / gallery[:,i].norm(dim=1, keepdim=True)
        sim_cosine = torch.matmul(query0, gallery0.t())
        score+=sim_cosine
    result.extend(topk(score, target_gallery, target_query, k, dim=1))
    if reverse:
        result.extend(topk(score, target_query, target_gallery, k, dim=0))
    return result, score



def topk(sim, target_gallery, target_query, k=[1, 10], dim=1):
    result = []
    maxk = max(k)
    size_total = len(target_query)
    _, pred_index = sim.topk(maxk, dim, True, True)  # indices of the top-k most similar items
    pred_labels = target_gallery[pred_index]
    if dim == 1:
        pred_labels = pred_labels.t()
    correct = pred_labels.eq(target_query.view(1, -1).expand_as(pred_labels))

    for topk in k:
        correct_k = torch.sum(correct[:topk], dim=0)
        correct_k = torch.sum(correct_k > 0).float()
        result.append(correct_k * 100 / size_total)
    return result
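For reference, topk works once sim has one row per query and one column per gallery item. Below is a toy run with made-up labels and similarity values; the function is copied from above, with the loop variable renamed so it no longer shadows the function name:

```python
import torch

def topk(sim, target_gallery, target_query, k=[1, 10], dim=1):
    result = []
    maxk = max(k)
    size_total = len(target_query)
    # indices of the maxk most similar gallery items for each query (dim=1)
    _, pred_index = sim.topk(maxk, dim, True, True)
    pred_labels = target_gallery[pred_index]  # [num_query, maxk]
    if dim == 1:
        pred_labels = pred_labels.t()  # [maxk, num_query]
    correct = pred_labels.eq(target_query.view(1, -1).expand_as(pred_labels))
    for kk in k:
        # a query is a hit if any of its first kk predictions has the right label
        correct_k = torch.sum(correct[:kk], dim=0)
        correct_k = torch.sum(correct_k > 0).float()
        result.append(correct_k * 100 / size_total)
    return result

# Toy data: 3 queries, 4 gallery items
target_query = torch.tensor([0, 1, 2])
target_gallery = torch.tensor([0, 1, 2, 3])
sim = torch.tensor([
    [0.9, 0.1, 0.0, 0.2],  # query 0: label 0 ranked 1st -> hit at k=1
    [0.8, 0.7, 0.1, 0.0],  # query 1: label 1 ranked 2nd -> hit at k=2 only
    [0.1, 0.2, 0.3, 0.9],  # query 2: label 2 ranked 2nd -> hit at k=2 only
])

# Recall@1 = 1/3 of the queries, Recall@2 = 3/3 of the queries
r1, r2 = topk(sim, target_gallery, target_query, k=[1, 2])
print(r1.item(), r2.item())
```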

This code cannot compute the query --> gallery similarity over the 11 parts and derive the top-k results from it. The line

score+=sim_cosine

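only produces an 11x11 similarity matrix, so the computation cannot continue. My guess is that an early version of the test code was uploaded to GitHub. I hope the authors can release code that actually runs.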
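One possible way to obtain the intended [num_query, num_gallery] score, sketched under an assumption: the traceback (11 rows in score, 3074 query labels) suggests the features are stacked as [num_parts, num_samples, dim] with 11 parts, so query[:, i] selects sample i rather than part i. If that layout is right, each part feature can be L2-normalized and the per-part cosine similarities summed over the part axis. part_similarity and its parts_first flag are hypothetical names, not part of the repository:

```python
import torch
import torch.nn.functional as F

def part_similarity(query, gallery, parts_first=True):
    """Sum of per-part cosine similarities (hypothetical helper).

    parts_first=True:  query [num_parts, num_query, dim],
                       gallery [num_parts, num_gallery, dim]
    parts_first=False: query [num_query, num_parts, dim],
                       gallery [num_gallery, num_parts, dim]
    Returns a [num_query, num_gallery] score matrix.
    """
    # L2-normalize every part feature so dot products are cosine similarities
    q = F.normalize(query, dim=-1)
    g = F.normalize(gallery, dim=-1)
    if parts_first:
        # per-part q_p @ g_p^T, summed over the part axis p
        return torch.einsum("pqd,pgd->qg", q, g)
    return torch.einsum("qpd,gpd->qg", q, g)

# Toy check with random features: 11 parts, 5 queries, 7 gallery items
torch.manual_seed(0)
num_parts, dim = 11, 16
query = torch.randn(num_parts, 5, dim)
gallery = torch.randn(num_parts, 7, dim)

score = part_similarity(query, gallery)
print(score.shape)  # torch.Size([5, 7])

# Same result as an explicit loop over the part axis
loop_score = torch.zeros(5, 7)
for p in range(num_parts):
    q_p = query[p] / query[p].norm(dim=1, keepdim=True)
    g_p = gallery[p] / gallery[p].norm(dim=1, keepdim=True)
    loop_score += q_p @ g_p.t()
print(torch.allclose(score, loop_score, atol=1e-5))  # True
```

The resulting score has the shape topk(score, target_gallery, target_query, k, dim=1) expects. Whether the authors intended a plain sum, a weighted sum, or another way of fusing the 11 parts is not stated in the issue.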
