
Comments (14)

yong-wang commented on June 16, 2024

Thanks for reporting the problem. It seems to be caused by a combination of factors.

Although the SIFT-1M dataset is stored in float32 format, the values are actually integers in the range [0, 218].
So when distances are computed in half precision, many of them overflow to inf.
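The mechanism is easy to reproduce outside the kernel. A minimal numpy sketch (synthetic vectors, not the actual IVF-PQ code) showing a squared L2 distance overflowing the half range:

```python
import numpy as np

# Two synthetic 128-dim "SIFT-like" vectors with integer values in [0, 218].
a = np.full(128, 218, dtype=np.float16)
b = np.zeros(128, dtype=np.float16)

# Accumulating the squared L2 distance in float16 overflows:
# 128 * 218^2 = 6,083,072, far above float16's maximum of ~65504.
dist_half = np.sum((a - b) ** 2, dtype=np.float16)
dist_float = np.sum((a.astype(np.float32) - b.astype(np.float32)) ** 2)

print(dist_half)   # inf
print(dist_float)  # 6083072.0
```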

For example, for one query, I counted the number of inf values in the input to the radix select kernel:
internalDistanceDtype=half, smemLutDtype=half: 25733
internalDistanceDtype=float, smemLutDtype=half: 16434
internalDistanceDtype=float, smemLutDtype=float: 16408
(I don't know the implementation details, should there be overflow in the float case?)

The total number of input values is 25800.
So in the half case, there are 25800 - 25733 = 67 values that are not inf, which means the top-129 results will contain inf; in other words, the 129th value is inf.

The last step of radix select filters the values less than or equal to the 129th value. Because the 129th value is inf, every inf value is a candidate in this step, so there are lots of candidates. In the current setting, only one thread block is used for this step, hence radix select becomes slow.
The float case is fine, because its 129th value is not inf.
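The blow-up is easy to see with a toy model of the last filtering step (plain numpy, not the kernel code): when only 67 of 25800 values are finite, the k-th smallest value is inf, so every input passes the `<= kth` filter and the single filtering block has to process all of them.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, n_finite = 25800, 129, 67

# Toy distance array: 67 finite values, the rest overflowed to inf.
dists = np.full(n, np.inf, dtype=np.float32)
dists[:n_finite] = rng.random(n_finite).astype(np.float32)

kth = np.partition(dists, k - 1)[k - 1]      # the k-th smallest value: inf
candidates = np.count_nonzero(dists <= kth)  # inputs the last filter must consider

print(np.isinf(kth))  # True
print(candidates)     # 25800 -- inf <= inf, so every value is a candidate
```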

One possible (but not ideal) solution:
The latest radix select implementation has a bool parameter fused_last_filter that controls how this filtering step is implemented. When fused_last_filter=true, only one thread block is used, as described above. When fused_last_filter=false, a standalone kernel is launched with more thread blocks, so it stays fast even when the number of candidates is large. I tried this approach, and it was indeed helpful.
However, if the distance values are not distorted like this, the number of candidates in the filtering step should be small, and using fused_last_filter=true (the current setting) saves one kernel launch and is slightly faster. So I think the current setting is reasonable.

I feel the root problem is that the SIFT dataset is stored as float32 when it should be uint8.

from raft.

achirkin commented on June 16, 2024

Wow, thanks for the insight!

Perhaps we can try some form of normalization on the IVF-PQ side to remove the unwanted infinities.

However, I've just come to realize that many of the infs are by design: when the non-fused version of the kernel is used, it fills all distances between the actual data and the max_samples bound with infinities. Hence I think it's an important edge case for select_k to handle infinities. Can we maybe do some sort of pre-filtering on the first pass of the radix sort? I would switch to warp-sort, but it does not support k > 256, so we have to stick with radix sort.

achirkin commented on June 16, 2024

Also, if the filtering functionality is used in ANN methods, it will output even more infinities in place of the filtered-out data. So we definitely need to make select_k handle infinities fast.

achirkin commented on June 16, 2024

I've come up with this solution: #1742, but it seems to reduce the recall even further on the sift-128 data (though the QPS is great! :-D). I also feel there must be a better solution (at least, limiting the filtering to the zero-th pass only).

With this solution, I'm not really sure what happens if there are fewer non-inf values than k?

yong-wang commented on June 16, 2024

Since it's expected that there will be lots of infs, I think setting fused_last_filter=false for radix select is the easy and suitable fix. This parameter is intended exactly for scenarios where we know in advance that the data can be unevenly distributed. The cost is an extra kernel launch, which is usually quite small and should not be significant for IVF-PQ.

The alternative is changing radix select to treat inf as a special case. However, that would make radix select less general.
Currently, radix select treats inf as a normal value: it's just the largest float value. It's legal to have inf in the input, and inf can appear in the top-k result.
If we treat inf specially, e.g. by keeping it out of the result, it might surprise other client code.
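For context on why inf is "just the largest float value" here: float-keyed radix methods typically reinterpret the bits as an order-preserving unsigned integer, and under the standard sign-flip mapping +inf gets the largest key. A small numpy sketch of that mapping (the usual trick, not RAFT's exact code):

```python
import numpy as np

def float_to_sortable_u32(x: np.ndarray) -> np.ndarray:
    """Map float32 to uint32 so that unsigned order matches float order."""
    bits = x.astype(np.float32).view(np.uint32)
    # Negative floats: flip all bits; non-negative floats: flip just the sign bit.
    mask = np.where(bits >> 31 == 1, np.uint32(0xFFFFFFFF), np.uint32(0x80000000))
    return bits ^ mask

vals = np.array([-np.inf, -1.0, 0.0, 1.0, np.finfo(np.float32).max, np.inf],
                dtype=np.float32)
keys = float_to_sortable_u32(vals)
assert np.all(np.diff(keys.astype(np.uint64)) > 0)  # keys strictly increasing
# +inf maps to the largest key, so the digit scans handle it like any other value.
```

With this mapping there is nothing special about inf at the bit level, which is why it can legitimately appear in the top-k output.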

yong-wang commented on June 16, 2024

In my tests, setting fused_last_filter=false changed one radix_kernel that took 350 us into an 8 us radix_kernel plus an 11 us last_filter_kernel.

achirkin commented on June 16, 2024

Thanks for checking; that would be an acceptable solution.
However, I think it's still worth keeping this optimization. It is especially useful for reducing latency with small inputs. Handling infinities specially does add a few more edge cases, but with the new update to my PR it doesn't lead to a loss of generality: we filter out min_bound/max_bound values but add them to the end of the output if necessary; we know for sure nothing can be smaller/larger, so it does not disrupt the radix scanning logic.

Have a look: I fixed the issues with my workaround and got slightly better performance with infinity filtering: the overall time is 14.2 vs 19 us. This is for 200 probes, as in the other tests above; I suppose the difference is bigger for a smaller number of probes.
[image: select_k timing comparison]

yong-wang commented on June 16, 2024

How does the QPS change when averaged over all queries?
I have no doubt that the new PR improves QPS in this case. But I'm still not sure it's worth adding such special handling.

Treating inf in a special way adds overhead to the common cases. For example, if the input contains a few inf values but they don't belong to the final result, we pay the overhead of writing those inf values to the output during pass 0, only for them to be overwritten later.
So it's possible that the QPS for data type float becomes lower.

The slowdown case for data type half may not be an important one. It happens when inf is the k-th value, and when this happens, some results will have inf distances. Such results are chosen somewhat randomly by radix select, and hence the recall will decrease.
To optimize such a case, both the special treatment of inf and setting fused_last_filter=false make the normal cases, like data type float, a little slower. Setting fused_last_filter=false is slower than the special treatment of inf, but the special treatment of inf slows down more applications of radix select.

So the question is: is it worth optimizing such an edge case, which likely has low recall anyway, at the cost of slowing down the normal cases, and even radix select itself?
I prefer to regard radix select as a general method usable in many applications, not only ANN. Making it slower to accelerate an ANN edge case doesn't seem like a good trade-off.

The more I think about it the more I feel that we don't need to do anything for this case.

yong-wang commented on June 16, 2024

> However, I've just come to realize that many of the infs are by design: when the non-fused version of the kernel is used, it fills all distances between the actual data and the max_samples bound with infinities.

To my understanding, these infs are more like padding, and they usually won't make the number of legal results smaller than k. It's the half overflow that produces too many infs, hence the slowdown.
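That matches a quick sanity check (toy numpy sketch, not the kernel code): padding the tail up to max_samples with inf leaves the top-k untouched as long as at least k finite values remain.

```python
import numpy as np

rng = np.random.default_rng(42)
k, n_valid, max_samples = 129, 5000, 25800

# 5000 finite distances, padded with inf up to the max_samples bound.
valid = rng.random(n_valid).astype(np.float32)
padded = np.concatenate([valid, np.full(max_samples - n_valid, np.inf, np.float32)])

topk_valid = np.sort(np.partition(valid, k - 1)[:k])
topk_padded = np.sort(np.partition(padded, k - 1)[:k])
assert np.array_equal(topk_valid, topk_padded)  # padding infs never enter the top-k
```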

achirkin commented on June 16, 2024

I understand your concern. Adding special treatment for a single input value is a logical burden on the algorithm. However, I don't think it adds any significant overhead in the common case.

Let's have a look at the histogram+filter changes: the only thing it really adds to the loops is a single comparison, value == bound. Then there's a bit of extra work when i < k (normally just a single iteration). It does write to the output, but I think the atomicAdd operation there is more costly: it leads to contention when there are multiple bound values within the first k values of the input. Beyond the first k values it does nothing; moreover, it skips the atomicAdd on the histogram for the bound values, so it may perform even faster when there are many bound values at the end of the input (the IVF use case). Hence I don't think this should slow down the algorithm in any case.
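If I read the PR correctly, the pass-0 change amounts to something like the following scalar sketch (hypothetical Python, not the CUDA code; `bound` stands for the filtered min_bound/max_bound value, and the list emulates the atomicAdd'd output slots):

```python
import numpy as np

def pass0_histogram_with_bound_filter(values, k, bound=np.float32(np.inf)):
    """Sketch: build the pass-0 digit histogram while filtering out `bound`.

    Up to k bound values are written straight to the output tail (nothing can
    compare larger, so they cannot displace real results); the rest are
    skipped entirely and never touch the histogram.
    """
    hist = np.zeros(256, dtype=np.int64)  # histogram over the top 8-bit digit
    out_tail = []                         # emulates the atomicAdd'd output slots
    for v in np.asarray(values, dtype=np.float32):
        if v == bound:
            if len(out_tail) < k:         # the "i < k" branch: reserve a slot
                out_tail.append(v)
            continue                      # bound values skip the histogram entirely
        bits = int(v.view(np.uint32))     # raw bits order correctly for dists >= 0
        hist[bits >> 24] += 1
    return hist, out_tail

hist, tail = pass0_histogram_with_bound_filter(
    [0.5, np.inf, 1.5, np.inf, 2.5], k=4)
# hist counts only the 3 finite values; the 2 infs sit in the reserved tail
```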

That said, I ran more benchmarks to verify whether there are any slowdowns. Apparently, the extra code and variables increase the register pressure, especially for the single-block version with the double+uint64_t types. I was able to fight this back by adding __noinline__ to the single-block filter+histogram function. Still, the whole change inevitably leads to slight differences in the kernels and their performance.

Here are the results:
https://docs.google.com/spreadsheets/d/12RAA3IYUB3ou8DusCvpCOTHPooMNWTHmrzInGcjN1XI/edit?usp=sharing

- The maximum slowdown is 12.8%.
- The maximum speedup is x200.
- The average time change across all tests: -17.9%.
- The average time change across tests without edge cases (> -50%): -0.04%.

So the conclusion is that the time with/without the PR is generally the same, but the edge cases are much faster (i.e. no longer abnormally slow).

achirkin commented on June 16, 2024

Regarding the infs due to padding:

In the near future we're going to have many downstream projects using ANN pre-filtering (e.g. for "deleting" values without re-training). This will lead to more infinities within the valid range of the input. So I believe this is a very important use case (and apparently the overflow of half is not so rare either).

yong-wang commented on June 16, 2024

I ran some end-to-end benchmarks with the float/half data types. The results are:
[image: end-to-end QPS vs. recall curves]

The "baseline" is the latest 23.10, and "filter-out-inf" is PR1742. "refine4" means the refine factor is 4.

The two curves in the top left corner are for half. Yes, the QPS improves quite a lot with the new PR.
However, the recall is quite low (~0.6) when the data type is half. Even with refinement (bottom left curves), the recall is still low. That's because ANN could not find k vectors with non-inf distances.
(The shape of these half curves, like the blue one, is strange: a larger nprobe results in lower recall. Such weird behavior also implies that weird results are being returned.)

In contrast, the recall is much higher for data type float (top right curves), and the baseline is slightly faster (the difference is small but seems reproducible). Also, refinement improves the recall to nearly 1.0 (bottom right curves).

So, is the slowdown for data type half an important case we need to worry about? I don't think so, because the recall is quite low and type float should be used instead.
The reason for the low recall with half is that there are so many infs that ANN cannot return k items with non-inf distances; the same reason radix select becomes slow.

I would like to emphasize that radix select does not become slow merely because there are many infs, but because there are so many infs that the number of non-inf values is smaller than k, or equivalently, the k-th value is inf.
By the time radix select becomes slow, some other serious problem has already occurred and the recall will be low. So this is not a case we should worry about for radix select.

The same reasoning applies to ANN pre-filtering. If so many values are deleted that ANN cannot return k items with valid distances, then too many values have been deleted, which is a more serious problem.

So, I still think we don't need to fix this case.

The logical burden of treating inf specially is indeed what concerns me most. Changes such as

```diff
-    if (prev < k && cur >= k) {
+    if (prev < k && (cur >= k || i + 1 == num_buckets)) {
```

```diff
-    if (counter.len == counter.k || pass == num_passes - 1) {
+    if (counter.len <= counter.k || pass == num_passes - 1) {
```

make me nervous. There are already too many if branches in the code (my bad). These changes break some nice post-conditions that used to hold. If we extend the code in the future, there will be more edge cases to consider.

achirkin commented on June 16, 2024

Do I understand correctly that your main concern is the more complicated logic due to the extra if branches (and the two update branches)? I believe that, thanks to your thorough review, we have sorted out the possible new bugs, and I did account for the broken post-conditions.

I think the proposed fix adds value in that it removes the ~10x slowdown in some edge cases with little to no cost for the other cases. Aside from the zero-th pass, it doesn't really complicate the logic that much.

We could argue about whether IVF-PQ should be fixed to not produce as many infinities, but I don't think that's relevant here. If the number of infinities in the input of radix select is larger than n_inputs - k, the current implementation degrades significantly in performance with the last-scan optimization (which otherwise brings a lot of improvement); the PR fixes that. Whether such a large number of infinities hurts the recall of IVF-PQ depends on how many non-infinity values are left, and should be addressed in IVF-PQ.

I see your point about adding too many workarounds to radix select, but I think the improvement is still worth it. I'd suggest we ask for a third opinion on this. CC @tfeher @cjnolet ?

yong-wang commented on June 16, 2024

For the code logic, I think the code change for the zero-th pass itself is fine because the intention is clear. I'm more worried about its effect. With the current code, counter.len will eventually equal counter.k, but after the change it's possible that counter.len < counter.k. So there are more cases to consider if we extend the code in the future.

For radix select performance, the current code handles inf reasonably well. From the Google spreadsheet, even when 90% of the data is inf there is no improvement from PR1742 for float-uint32 and double-uint32_t (for double-uint64_t, the current code is much slower; that needs a closer look). Admittedly, when the inf ratio is 99.9-100%, the PR is much faster.

From the figure I posted earlier, there are two problems with the "baseline.half" curve: both its QPS and its recall are low. These two problems are closely related; in fact, they are the result of the same root cause: too many overflowed half distances.

The PR fixes the low QPS, as shown by the pink curve, but the low recall stays the same. Without an improvement in recall, the improvement in QPS doesn't matter much in this case, because the float curve should be favored.
On the other hand, if we fix the low recall (e.g. so that there are more than k non-infs), the low QPS is fixed automatically.

Overall, the PR improves the low QPS of half quite a lot, while the float case becomes slightly slower. (I don't worry much about the slowdown because it's quite small, but it does exist and is not measurement noise in my tests.) Considering the code complexity it adds, I don't think it's a good trade-off.
