csarofeen / pytorch
This project forked from pytorch/pytorch
Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home Page: http://pytorch.org
License: Other
I found a situation where a schedule was allowed to generate a kernel where TIDx was bound to two different values. An error should have been generated; the question is how to check for this.
The scenario where I made this mistake was when I did an rFactor on a split reduction dimension and didn't apply the corresponding split to an operation found after the reduction.
IR View of Operations
T5[ iblockIdx.x{gridDim.x}, rS{( ceilDiv(i3, 32) )}rf, ithreadIdx.x{32}rf ] = reduction( T0[ iS{i1}, iS{i3} ], op = add, initial value = float(0) )
T2[ iblockIdx.x{gridDim.x}, rthreadIdx.x{32} ] = reduction( T5[ iblockIdx.x{gridDim.x}, rS{( ceilDiv(i3, 32) )}rf, ithreadIdx.x{32}rf ], op = add, initial value = float(0) )
T3[ iblockIdx.x{gridDim.x}, bthreadIdx.x{1} ]
= T2[ iblockIdx.x{gridDim.x}, rthreadIdx.x{32} ];
T4[ iblockIdx.x{gridDim.x}, ithreadIdx.x{blockDim.x} ]
= T3[ iblockIdx.x{gridDim.x}, bthreadIdx.x{1} ]
+ T1[ iS{i5}, iS{i7} ];
Kernel Generated
__device__ void reduction_add_float(float& a, const float b) {
a = a + b;
}
__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T1, Tensor<float, 2> T4){
__shared__ float shared_mem[1024];
float T3[1];
float T2[1];
T2[ 0 ]
= float(0);
float T5[1];
if ( ( ( ( 0 * 32 ) + threadIdx.x ) < T0.size[1] ) ) {
T5[ 0 ]
= float(0);
}
for(size_t i30 = 0; i30 < ( ceilDiv(T0.size[1], 32) ); ++i30 ) {
if ( ( ( ( i30 * 32 ) + threadIdx.x ) < T0.size[1] ) ) {
T5[ 0 ]
= T5[ 0 ]
+ T0[ ( blockIdx.x * T0.stride[0] ) + ( ( ( i30 * 32 ) + threadIdx.x ) * T0.stride[1] ) ];
}
}
blockReduce< true, false, false > ( T2[ 0 ], T5[ 0 ], reduction_add_float, threadIdx, blockDim, reinterpret_cast<float*>(shared_mem));
if ( ( threadIdx.x == 0 ) ) {
T3[ 0 ]
= T2[ 0 ];
}
T4[ ( blockIdx.x * T4.stride[0] ) + ( threadIdx.x * T4.stride[1] ) ]
= T3[ 0 ]
+ T1[ ( blockIdx.x * T1.stride[0] ) + ( threadIdx.x * T1.stride[1] ) ];
}
Test
void testGPU_FusionThreadBindingError() {
torch::jit::fuser::cuda::CudaKernel prog;
Fusion& fusion = *prog.fusion_;
FusionGuard fg(&fusion);
TensorView* tv0 = makeDummyTensor(2);
TensorView* tv1 = makeDummyTensor(2);
fusion.addInput(tv0);
fusion.addInput(tv1);
// TV0 original [ 32 , 128 ]
// TV5 [ 32, 32 ] <- rFactor TV2[ 32, 128] -> [32, 4, 32] (bound to tidx(-1))
// TV2 [ 32 ] Final Reduce (bound to tidx(-1))
// TV3 [ 32, 128 ] Broadcast (bound to tidx(-1))
// TV4 [ 32, 128 ] Add (bound to tidx(-1))
TensorView* tv2 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
TensorView* tv3 = broadcast(tv2, {false, true});
TensorView* tv4 = add(tv3, tv1);
tv2->split(-1, 32);
TensorView* tv5 = tv2->rFactor({-2});
tv2->axis(0)->parallelize(ParallelType::BIDx);
tv3->axis(0)->parallelize(ParallelType::BIDx);
tv4->axis(0)->parallelize(ParallelType::BIDx);
tv5->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(-1)->parallelize(ParallelType::TIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);
tv4->axis(-1)->parallelize(ParallelType::TIDx);
tv5->axis(-1)->parallelize(ParallelType::TIDx);
fusion.addOutput(tv4);
fusion.printMath();
GPULower gpulw(&fusion);
gpulw.printKernel(std::cout);
prog.device_ = 0;
prog.grid(32);
prog.block(32);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({32, 128}, options);
at::Tensor t1 = at::randn({32, 128}, options);
at::Tensor cg_output = at::empty({32, 128}, options);
torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runTestKernel(&prog, {t0,t1}, {cg_output});
}
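One possible check, sketched under simplifying assumptions: walk every tensor, record the extent each parallel type is first bound to, and report an error when a later tensor binds the same parallel type to a different extent. The types below (`ParallelType`, `AxisInfo`, `TensorInfo`) are hypothetical stand-ins, not the fuser's classes. Extents are compared as the strings printed in the IR, although a real check would have to compare symbolic extents. Skipping broadcast axes such as `bthreadIdx.x{1}` is also an assumption.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-ins for the fuser IR (not the real classes).
enum class ParallelType { Serial, BIDx, TIDx };

struct AxisInfo {
  ParallelType ptype;
  std::string extent;  // extent as printed in the IR, e.g. "32" or "blockDim.x"
  bool is_broadcast;   // broadcast axes (extent 1) are skipped by the check
};

struct TensorInfo {
  std::string name;
  std::vector<AxisInfo> axes;
};

// Returns a message for the first parallel type bound to two different extents,
// or std::nullopt if every binding is consistent.
std::optional<std::string> findParallelExtentConflict(const std::vector<TensorInfo>& tensors) {
  std::map<ParallelType, std::pair<std::string, std::string>> bound;  // ptype -> (extent, tensor)
  for (const auto& tv : tensors) {
    for (const auto& axis : tv.axes) {
      if (axis.ptype == ParallelType::Serial || axis.is_broadcast) continue;
      auto [it, inserted] = bound.emplace(axis.ptype, std::make_pair(axis.extent, tv.name));
      if (!inserted && it->second.first != axis.extent) {
        return tv.name + " binds " + axis.extent + ", but " + it->second.second + " binds " +
               it->second.first;
      }
    }
  }
  return std::nullopt;
}

// The four tensors from the IR view above, in definition order.
std::vector<TensorInfo> issueExample() {
  using P = ParallelType;
  return {
      {"T5", {{P::BIDx, "gridDim.x", false}, {P::Serial, "ceilDiv(i3, 32)", false}, {P::TIDx, "32", false}}},
      {"T2", {{P::BIDx, "gridDim.x", false}, {P::TIDx, "32", false}}},
      {"T3", {{P::BIDx, "gridDim.x", false}, {P::TIDx, "1", true}}},
      {"T4", {{P::BIDx, "gridDim.x", false}, {P::TIDx, "blockDim.x", false}}},
  };
}
```

Applied to the IR above, the check flags T4, whose TIDx extent `blockDim.x` disagrees with the `32` used by T5 and T2.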
pytorch/torch/csrc/jit/codegen/cuda/fusion.cpp
Lines 194 to 199 in b330974
The ranges used as inputs to std::set_union() must be sorted:
https://en.cppreference.com/w/cpp/algorithm/set_union
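The referenced fusion.cpp lines are not reproduced here. As a minimal sketch with ints standing in for the IR values, the precondition can be met by sorting copies of both inputs before the call. Calling std::set_union on unsorted ranges is undefined behavior, and in practice it usually yields a wrong result without raising any error.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

// std::set_union requires both input ranges to be sorted with the comparator
// used by the call (operator< here). This helper sorts private copies first,
// so callers may pass unordered inputs.
std::vector<int> sortedUnion(std::vector<int> a, std::vector<int> b) {
  std::sort(a.begin(), a.end());  // establish the precondition on the copies
  std::sort(b.begin(), b.end());
  std::vector<int> out;
  std::set_union(a.begin(), a.end(), b.begin(), b.end(), std::back_inserter(out));
  return out;
}
```

Inputs already held in a `std::set` are sorted by construction, so they can be passed to std::set_union directly.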
Currently, values are automatically "named" using a simple auto-incremented numeric id.
This means that the C++ API allows meaningful names for C++ variables, but they will not match the IR names. At best, code ends up looking like this:
auto input_tv0 = makeDummyTensor(1);
auto exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0);
auto sum_exp_tv2 = sum(exp_tv1, {-1});
auto bcast_sum_tv3 = broadcast(sum_exp_tv2, {true});
One solution would be to allow setting explicit names for values. This would be optional (if not set explicitly we'd default to an auto-naming scheme like we do today).
A hypothetical solution that is backward compatible with the current API usage could look like this:
auto input = makeDummyTensor(1)->named("tv0");
auto exp = unaryOp(UnaryOpType::Exp, input)->named("tv1");
auto sum_exp = sum(exp, {-1})->named("tv2");
auto bcast_sum = broadcast(sum_exp, {true})->named("tv3");
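A minimal sketch of the optional-name idea, using a hypothetical `NamedVal` class rather than the real fuser `Val`. `named()` stores an explicit name and returns `this` so that it can be chained. When no name was set, `name()` falls back to an auto-incremented scheme. The "T" prefix is only an assumption modeled on IR names such as T5.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>

// Hypothetical stand-in for an IR value; not the real fuser Val class.
class NamedVal {
 public:
  NamedVal() : id_(next_id_++) {}

  // Sets an explicit name and returns `this`, so it can be chained onto a
  // factory call in the style of makeDummyTensor(1)->named("tv0").
  NamedVal* named(std::string name) {
    name_ = std::move(name);
    return this;
  }

  // Falls back to an auto-incremented name (as today) when none was set.
  std::string name() const { return name_ ? *name_ : "T" + std::to_string(id_); }

 private:
  static inline int next_id_ = 0;  // C++17 inline static member
  int id_;
  std::optional<std::string> name_;
};
```

Existing code that never calls `named()` keeps the current numbering, which is what makes the proposal backward compatible.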
Remove all the references to TensorContiguity
I tried to modify the C++ test FusionExprEvalBasic to run code lowering before using expression evaluation, but I get an assertion error when doing so. Judging from the comments in the example, this is probably a case that we do not yet support.
The goal is to use ExpressionEvaluator to infer the launch configuration.
This modified example should work:
// Evaluate expressions in a simple IR
void testGPU_FusionExprEvalBasic() {
Fusion fusion;
FusionGuard fg(&fusion);
// Create a non-trivial IR
TensorView* tv0 = makeDummyTensor(2);
TensorView* tv1 = makeDummyTensor(2);
fusion.addInput(tv0);
fusion.addInput(tv1);
TensorView* tv2 = add(tv1, new Float(2.0));
TensorView* tv3 = add(tv0, tv2);
fusion.addOutput(tv3);
tv3->split(0, 4);
tv0->computeAt(tv3, 1);
tv1->computeAt(tv3, 1);
tv3->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(1)->parallelize(ParallelType::Unroll);
tv3->axis(1)->parallelize(ParallelType::Unroll);
tv2->axis(-1)->parallelize(ParallelType::TIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);
auto* bid_x = add(tv3->axis(0)->rawExtent(), new Int(0));
auto* tid_x = add(tv3->axis(-1)->rawExtent(), new Int(0));
  // This appears to be causing the issue:
GPULower gpulw(&fusion);
std::stringstream cdg;
gpulw.printKernel(cdg);
std::cout << cdg.str() << std::endl;
// 1. Create an evaluation context
EvaluationContext eval_context(&fusion);
// 2. Bind values
//
// IMPORTANT:
// a. The bindings are only as stable as the Vals are in the fusion graph
// b. You must use the original (rootDomain) extents
// (ex. `tv0->getRootDomain()[0]->extent()`
// instead of `tv0->axis(0)->extent()`)
eval_context.bind(tv0->getRootDomain()[0]->extent(), 6);
eval_context.bind(tv0->getRootDomain()[1]->extent(), 128);
eval_context.bind(tv1->getRootDomain()[0]->extent(), 6);
eval_context.bind(tv1->getRootDomain()[1]->extent(), 128);
// 3. Evaluate and check result values
TORCH_CHECK(tv2->domain()->nDims() == 3);
checkIntValue(&eval_context, tv2->axis(0)->rawExtent(), 2);
checkIntValue(&eval_context, tv2->axis(1)->rawExtent(), 4);
checkIntValue(&eval_context, tv2->axis(2)->rawExtent(), 128);
TORCH_CHECK(tv3->domain()->nDims() == 3);
checkIntValue(&eval_context, tv3->axis(0)->rawExtent(), 2);
checkIntValue(&eval_context, tv3->axis(1)->rawExtent(), 4);
checkIntValue(&eval_context, tv3->axis(2)->rawExtent(), 128);
const auto bid_x_val = ExpressionEvaluator::evaluate(bid_x, &eval_context);
std::cout << "bid x value " << bid_x_val.value() << std::endl;
  const auto tid_x_val = ExpressionEvaluator::evaluate(tid_x, &eval_context);
std::cout << "tid x value " << tid_x_val.value() << std::endl;
}
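For reference, here is what the evaluator should produce in this test. The root extents are bound to 6 and 128, and after `tv3->split(0, 4)` the test expects the BIDx axis to evaluate to 2 and the TIDx axis to 128. The hand-written C++ below computes the same launch configuration directly for this one schedule. It is a worked example, not the ExpressionEvaluator API, and `inferLaunchConfig` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstdint>

// Launch configuration for this specific schedule only: split(0, 4) turns axis 0
// into [ceilDiv(dim0, 4), 4]; BIDx is bound to the outer split axis and TIDx to
// the innermost (unsplit) axis of extent dim1.
struct LaunchConfig {
  std::int64_t grid_x;
  std::int64_t block_x;
};

constexpr std::int64_t ceilDiv(std::int64_t a, std::int64_t b) { return (a + b - 1) / b; }

constexpr LaunchConfig inferLaunchConfig(std::int64_t dim0, std::int64_t dim1,
                                         std::int64_t split_factor) {
  return {ceilDiv(dim0, split_factor), dim1};
}
```

With the bindings used in the test, `inferLaunchConfig(6, 128, 4)` gives a grid of 2 and a block of 128, matching the `checkIntValue` expectations above.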
pytorch/torch/csrc/jit/codegen/cuda/transform_replay.cpp
Lines 113 to 116 in 172bcdb
If I add a check there I see JitTest.GPU_FusionSimplePWise_CUDA failing:
[ RUN ] JitTest.GPU_FusionSimplePWise_CUDA
unknown file: error: C++ exception with description "axis + 1 < axis_map.size() INTERNAL ASSERT FAILED at ..\torch\csrc\jit\codegen\cuda\transform_replay.cpp:114, please report a bug to PyTorch. (replay at ..\torch\csrc\jit\codegen\cuda\transform_replay.cpp:114)
00007FF95523E77200007FF95523E730 c10.dll!std::_Invoker_functor::_Call<<lambda_caff7cead0c0e49cc41a504cb708e59a> &> [C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.24.28314\include\type_traits @ 1579]
00007FF95523EBC200007FF95523EB80 c10.dll!std::invoke<<lambda_caff7cead0c0e49cc41a504cb708e59a> &> [C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.24.28314\include\type_traits @ 1579]
00007FF95523E7D200007FF95523E790 c10.dll!std::_Invoker_ret<std::basic_string<char,std::char_traits,std::allocator >,0>::_Call<<lambda_caff7cead0c0e49cc41a504cb708e59a> &> [C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.24.
00007FF95523F47100007FF95523F430 c10.dll!std::_Func_impl_no_alloc<<lambda_caff7cead0c0e49cc41a504cb708e59a>,std::basic_string<char,std::char_traits,std::allocator > >::_Do_call [C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC
00007FF95523F24500007FF95523F1E0 c10.dll!std::_Func_class<std::basic_string<char,std::char_traits,std::allocator > >::operator() [C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.24.28314\include\functional @ 969]
00007FF95523D49200007FF95523D440 c10.dll!c10::Error::Error [C:\Users\lemo\work\external\pytorch\pytorch\c10\util\Logging.cpp @ 67]
00007FF909AAC9CC00007FF909AAC280 torch_cuda.dll!torch::jit::fuser::TransformReplay::replay [C:\Users\lemo\work\external\pytorch\pytorch\torch\csrc\jit\codegen\cuda\transform_replay.cpp @ 114]
00007FF909AAA6FA00007FF909AAA4E0 torch_cuda.dll!torch::jit::fuser::TransformIter::replay [C:\Users\lemo\work\external\pytorch\pytorch\torch\csrc\jit\codegen\cuda\transform_iter.cpp @ 105]
00007FF909AAA94C00007FF909AAA890 torch_cuda.dll!torch::jit::fuser::TransformIter::runReplay [C:\Users\lemo\work\external\pytorch\pytorch\torch\csrc\jit\codegen\cuda\transform_iter.cpp @ 115]
00007FF909AADA0000007FF909AAD4C0 torch_cuda.dll!torch::jit::fuser::TransformReplay::runReplay [C:\Users\lemo\work\external\pytorch\pytorch\torch\csrc\jit\codegen\cuda\transform_replay.cpp @ 308]
00007FF909AADC7000007FF909AADC10 torch_cuda.dll!torch::jit::fuser::TransformReplay::runReplay [C:\Users\lemo\work\external\pytorch\pytorch\torch\csrc\jit\codegen\cuda\transform_replay.cpp @ 323]
00007FF909AADCFB00007FF909AADCA0 torch_cuda.dll!torch::jit::fuser::TransformReplay::replay [C:\Users\lemo\work\external\pytorch\pytorch\torch\csrc\jit\codegen\cuda\transform_replay.cpp @ 335]
00007FF909AA765700007FF909AA6E80 torch_cuda.dll!torch::jit::fuser::TensorView::computeAt [C:\Users\lemo\work\external\pytorch\pytorch\torch\csrc\jit\codegen\cuda\tensor_view.cpp @ 178]
00007FF909AA73C900007FF909AA6E80 torch_cuda.dll!torch::jit::fuser::TensorView::computeAt [C:\Users\lemo\work\external\pytorch\pytorch\torch\csrc\jit\codegen\cuda\tensor_view.cpp @ 139]
00007FF7339436A100007FF733943090 test_jit.exe!torch::jit::testGPU_FusionSimplePWise [C:\Users\lemo\work\external\pytorch\pytorch\test\cpp\jit\test_gpu.cpp @ 979]
00007FF7337F89D300007FF7337F89B0 test_jit.exe!torch::jit::JitTest_GPU_FusionSimplePWise_CUDA_Test::TestBody [C:\Users\lemo\work\external\pytorch\pytorch\test\cpp\jit\gtest.cpp @ 19]
00007FF733ADE94100007FF733ADE910 test_jit.exe!testing::internal::HandleSehExceptionsInMethodIfSupportedtesting::Test,void [C:\Users\lemo\work\external\pytorch\pytorch\third_party\googletest\googletest\src\gtest.cc @ 2428]
00007FF733ADE64100007FF733ADE5D0 test_jit.exe!testing::internal::HandleExceptionsInMethodIfSupportedtesting::Test,void [C:\Users\lemo\work\external\pytorch\pytorch\third_party\googletest\googletest\src\gtest.cc @ 2479]
00007FF733AC2F8300007FF733AC2ED0 test_jit.exe!testing::test::Run [C:\Users\lemo\work\external\pytorch\pytorch\third_party\googletest\googletest\src\gtest.cc @ 2524]
00007FF733AC3B1E00007FF733AC3A40 test_jit.exe!testing::TestInfo::Run [C:\Users\lemo\work\external\pytorch\pytorch\third_party\googletest\googletest\src\gtest.cc @ 2697]
00007FF733AC42A600007FF733AC41B0 test_jit.exe!testing::TestCase::Run [C:\Users\lemo\work\external\pytorch\pytorch\third_party\googletest\googletest\src\gtest.cc @ 2812]
00007FF733ACB51E00007FF733ACB1C0 test_jit.exe!testing::internal::UnitTestImpl::RunAllTests [C:\Users\lemo\work\external\pytorch\pytorch\third_party\googletest\googletest\src\gtest.cc @ 5178]
00007FF733ADEA8100007FF733ADEA50 test_jit.exe!testing::internal::HandleSehExceptionsInMethodIfSupportedtesting::internal::UnitTestImpl,bool [C:\Users\lemo\work\external\pytorch\pytorch\third_party\googletest\googletest\src\gtest.cc @ 2428]
00007FF733ADE8A100007FF733ADE830 test_jit.exe!testing::internal::HandleExceptionsInMethodIfSupportedtesting::internal::UnitTestImpl,bool [C:\Users\lemo\work\external\pytorch\pytorch\third_party\googletest\googletest\src\gtest.cc @ 2479]
00007FF733AC499300007FF733AC4870 test_jit.exe!testing::UnitTest::Run [C:\Users\lemo\work\external\pytorch\pytorch\third_party\googletest\googletest\src\gtest.cc @ 4786]
This useful debugging tool was lost in a merge. This issue is to make sure we don't forget to add it back.
pytorch/torch/csrc/jit/codegen/cuda/lower_utils.cpp
Lines 210 to 241 in b330974
For example:
void handle(ForLoop* fl) final {
for (Expr* expr : fl->body().exprs()) {
auto it = replacement_map_.find(expr);
if (it == replacement_map_.end()) {
handle(expr);
continue;
}
// ----------------------- this mutates fl->body() while iterating on it -------------------
fl->body().insert_before(expr, replacement_map_[expr]);
fl->body().erase(expr);
}
}
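One way to avoid mutating the body while iterating over it is to iterate over a snapshot and apply each insert-before/erase pair to the live container. This is a minimal sketch assuming a hypothetical body of `int` ids in a `std::vector`, not the real scope and `Expr*` types.

```cpp
#include <algorithm>
#include <cassert>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in: `body` holds expression ids in order. Each id found in
// `replacements` is swapped for its replacement via insert-before + erase,
// mirroring insert_before()/erase() in the snippet above.
void replaceExprs(std::vector<int>& body, const std::unordered_map<int, int>& replacements) {
  const std::vector<int> snapshot = body;  // iterate over a copy, never over `body` itself
  for (int expr : snapshot) {
    auto it = replacements.find(expr);
    if (it == replacements.end()) {
      continue;  // the real pass would recurse here, e.g. handle(expr)
    }
    auto pos = std::find(body.begin(), body.end(), expr);
    pos = body.insert(pos, it->second);  // insert before: pos now points at the new expr
    body.erase(pos + 1);                 // erase the original expr that followed it
  }
}
```

Collecting the replacements first and applying them in a second pass would work just as well. The snapshot keeps the loop structure closest to the original.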
Terminology
Dependency chain:
A dependency chain from TV0->TV5 is a path through the arithmetic operations from TV0 to TV5. For example, if we had:
TV1 = TV0 * 2.0
TV3 = TV1 + TV2
TV4 = TV3 * TV0
TV5 = TV3 - TV4
then both
TV0 -> TV1 -> TV3 -> TV4 -> TV5
and
TV0 -> TV4 -> TV5
would be dependency chains of TV0->TV5.
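To make the definition concrete, the sketch below enumerates every dependency chain with a depth-first search over a producer-to-consumer adjacency map. The graph encoding is a hypothetical illustration, not fuser code. On the example above it finds the two listed chains plus a third, TV0 -> TV1 -> TV3 -> TV5, which follows from TV5 = TV3 - TV4.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

using Graph = std::map<std::string, std::vector<std::string>>;  // producer -> consumers

// Depth-first search that records every path from `cur` to `target`.
void collectChains(const Graph& g, const std::string& cur, const std::string& target,
                   std::vector<std::string>& path, std::vector<std::vector<std::string>>& out) {
  path.push_back(cur);
  if (cur == target) {
    out.push_back(path);
  } else if (auto it = g.find(cur); it != g.end()) {
    for (const auto& next : it->second) collectChains(g, next, target, path, out);
  }
  path.pop_back();
}

std::vector<std::vector<std::string>> dependencyChains(const Graph& g, const std::string& from,
                                                       const std::string& to) {
  std::vector<std::string> path;
  std::vector<std::vector<std::string>> out;
  collectChains(g, from, to, path, out);
  return out;
}

// The example above; TV2 is an input that only feeds TV3.
Graph exampleGraph() {
  return {{"TV0", {"TV1", "TV4"}}, {"TV1", {"TV3"}}, {"TV2", {"TV3"}},
          {"TV3", {"TV4", "TV5"}}, {"TV4", {"TV5"}}};
}
```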
Use chains: The use chains of a tensor view are all paths from it to all outputs. The chains are connected through dependencies and consist of all tensors that depend in some way on the given tensor. Referring to the Dependency chain example, all use chains of TV3 are:
TV3 -> TV4 -> TV5
TV3 -> TV5
Producer/Consumer:
A producer-consumer relationship is one where there is a valid, non-empty dependency chain from a TensorView (producer) to another TensorView (consumer). A direct producer-consumer relationship is one where the consumer has the producer as an input to its origin expression (the expression that generates the consumer).
Common consumer:
A common consumer of a producer is a TensorView that exists in the intersection of all use chains of that producer. In the Use chains example, TV5 is a common consumer of TV3, but TV4 is not.
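The common-consumer definition can be sketched as "collect all use chains, intersect their members, drop the producer itself". The sketch assumes that outputs are exactly the tensors with no consumers. That is a simplification, since a real fusion output can also feed other tensors. The graph encoding is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

using Graph = std::map<std::string, std::vector<std::string>>;  // producer -> consumers

// Collects every path from `cur` to a terminal tensor (one with no consumers).
void useChains(const Graph& g, const std::string& cur, std::vector<std::string>& path,
               std::vector<std::vector<std::string>>& out) {
  path.push_back(cur);
  auto it = g.find(cur);
  if (it == g.end() || it->second.empty()) {
    out.push_back(path);
  } else {
    for (const auto& next : it->second) useChains(g, next, path, out);
  }
  path.pop_back();
}

// Common consumers: tensors (other than the producer) present on every use chain.
std::set<std::string> commonConsumers(const Graph& g, const std::string& producer) {
  std::vector<std::string> path;
  std::vector<std::vector<std::string>> chains;
  useChains(g, producer, path, chains);  // always yields at least one chain
  std::set<std::string> common(chains.front().begin(), chains.front().end());
  for (const auto& chain : chains) {
    std::set<std::string> members(chain.begin(), chain.end());
    std::set<std::string> kept;
    std::set_intersection(common.begin(), common.end(), members.begin(), members.end(),
                          std::inserter(kept, kept.begin()));
    common = std::move(kept);
  }
  common.erase(producer);
  return common;
}
```

For TV3, the chains TV3 -> TV4 -> TV5 and TV3 -> TV5 intersect to {TV3, TV5}. Dropping TV3 leaves TV5 as the only common consumer.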
replayPasC and replayCasP:
These are the two major replay functions used in computeAt transform replay. replayPasC stands for replay producer as consumer, and replayCasP stands for replay consumer as producer. Both functions take a position argument, which means we need to replay so that axes < pos match between the two TensorViews. There are two functions because the replay needs a map between the root domains of the two TensorViews, and how that map is built depends on which TensorView is the consumer and which is the producer, and on what we are trying to replay. As we run these two replays, we also set the computeAt of the producer. In the code base, replayPasC and replayCasP are wrapped in functions called computeAt_impl and forwardComputeAt_impl, respectively.
Compute at current implementation and challenges
The computeAt pass is having some trouble producing the correct structure. Of note is: #110
The syntax of computeAt is producer->computeAt(consumer, pos). Here, producer and consumer do not need to have a direct producer-consumer relationship.
This means that we should generate producer at pos within the loop nest structure of consumer. In other words, we'd like a structure like:
for consumer->axis(0:pos)
  for producer->axis(pos:-1)
    producer = ...
  for consumer->axis(pos:-1)
    consumer = ... producer ...
Here, (pos:-1) indicates iterating over the axes/domains starting from position pos and going to the last axis of the tensor. The general process is to transform producer so that its axes (0:pos) are equivalent to consumer's axes (0:pos). This is done with the minimal required transformations. For example, any axes that don't need to be transformed to generate producer axes (0:pos) are left unchanged from producer's previous state.
The general challenge here is the multiple consumer issue. If we have a second consumer of producer that is "unrelated" to the first consumer, that second consumer needs to follow the same pattern as the first, because we would need to generate a structure that looks like:
for consumer1->axis(0:pos)
  for producer->axis(pos:-1)
    producer = ...
  for consumer1->axis(pos:-1)
    consumer1 = ... producer ...
  for consumer2->axis(pos:-1)
    consumer2 = ... producer ...
Any other placement of consumer2 would not be valid due to the transformation on producer, unless we generated producer multiple times. That is not currently supported, although it could become an optional flag on computeAt or a transformation on producer. As a result, we may need to modify tensors that lie outside the dependency chains from producer->consumer, which means we must be able to propagate the transformation dictated by the computeAt call.
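To make the loop nest above concrete, here is a hand-written C++ sketch for a tiny hypothetical example: producer = in + 1, consumer1 = producer * 2, consumer2 = producer - 3, with sizes fixed at 4 x 8. None of this is generated code. Because producer is computed at position 1, only a row-sized producer buffer exists, and it lives only inside the outer loop. For that reason consumer2 has to be placed inside the same loop.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

constexpr std::size_t kRows = 4;
constexpr std::size_t kCols = 8;
using Matrix = std::array<std::array<float, kCols>, kRows>;

// producer is computed at position 1 of consumer1: its buffer covers only
// producer->axis(pos:-1) and is rebuilt for every iteration of axis 0.
void computeAtPos1(const Matrix& in, Matrix& consumer1, Matrix& consumer2) {
  for (std::size_t i = 0; i < kRows; ++i) {    // for consumer1->axis(0:pos)
    std::array<float, kCols> producer{};       // holds producer->axis(pos:-1) only
    for (std::size_t j = 0; j < kCols; ++j) {  // for producer->axis(pos:-1)
      producer[j] = in[i][j] + 1.0f;
    }
    for (std::size_t j = 0; j < kCols; ++j) {  // for consumer1->axis(pos:-1)
      consumer1[i][j] = producer[j] * 2.0f;
    }
    for (std::size_t j = 0; j < kCols; ++j) {  // for consumer2->axis(pos:-1), same outer loop
      consumer2[i][j] = producer[j] - 3.0f;
    }
  }
}
```

Moving the consumer2 loop outside the outer `i` loop would leave it reading a producer buffer that no longer exists, which is the placement constraint described above.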
Current implementation:
consumer to common consumer. As we walk this dependency chain, run replayCasP so that common consumer will be transformed under pos like consumer.
Once common consumer matches the computeAt on consumer, follow all dependency chains backwards from producer to common consumer, and run replayPasC. Now all TensorViews from producer to common consumer should match.
replayPasC
replayCasP
What's wrong:
As we iterate in steps 2.a., 2.b., 3.a., and 3.b., the computeAt position will change depending on whether we're running replayCasP or replayPasC. The position used in replayPasC is relative to the consumer, and the position in replayCasP is relative to the producer, which is correct.
Consider:
T2[i0, r1, i2] = T1[i0, i1, i2] ...
T3[i0, i2] = T2[i0, r1, i2] ...
and we start at T2, position 2, and iterate forward with replayCasP. We will end up with the computeAt settings:
T1->computeAt(T2, 2)
T2->computeAt(T3, 1)
The second position drops to 1 because there is a reduction axis within the computeAt range, so we effectively lose an axis as we go over it.
Also consider:
T2[i0, b1, i2] = T1[i0, i2] ...
T3[i0, i1, i2] = T2[i0, b1, i2] ...
and we start at T3, position 2, and iterate backward with replayPasC. We will end up with the computeAt settings:
T1->computeAt(T2, 1)
T2->computeAt(T3, 2)
Here the position drops to 1 for T1 because the broadcast axis b1 does not exist in T1.
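Both examples follow the same arithmetic. To translate a position across one edge, count how many of the first pos axes survive on the other side. A producer's reduction axes do not appear in its consumer, and a consumer's broadcast axes do not appear in its producer. The sketch below is a simplified model, not the replay code. It assumes that the remaining axes correspond one-to-one and that no splits or merges are involved. Under that assumption it reproduces the positions listed above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

enum class AxisKind { Iteration, Reduction, Broadcast };

// Counts how many of the first `pos` axes are not of kind `dropped`, i.e. how
// many of them also exist on the other side of the producer/consumer edge.
std::size_t translatePosition(const std::vector<AxisKind>& axes, std::size_t pos,
                              AxisKind dropped) {
  std::size_t kept = 0;
  for (std::size_t i = 0; i < pos && i < axes.size(); ++i) {
    if (axes[i] != dropped) ++kept;
  }
  return kept;
}
```

Forward from T2[i0, r1, i2] at position 2, dropping reductions, gives 1 (T2->computeAt(T3, 1)). Backward from T2[i0, b1, i2] at position 2, dropping broadcasts, also gives 1 (T1->computeAt(T2, 1)).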
If this doesn't seem like a problem with our current approach yet, let's consider #110.
We effectively have:
T1[i0, i1] = T0[i0, i1]
T2[i0, r1] = T1[i0, i1]
T3[i0, r1] = T1[i0, i1]
T4[i0] = T2[i0, r1], T3[i0, r1]
and we call:
t1->computeAt(t2, 2);
Following our current procedure, we see that T1 has multiple uses, so we look for its common consumer, which is T4. We forward propagate from T2 to T4 along a dependency chain and get:
T2->computeAt(T4, 1)
Then we go backwards through all dep chains from producer->common consumer:
T3->computeAt(T4, 1)
T2->computeAt(T4, 1)
T1->computeAt(T2, 1)
T1->computeAt(T3, 1)
Therefore there is no way to inline the consumption of T1 into both T2 and T3, even though, in theory, T2 and T3 could be inlined with each other.
This procedure would work if, instead of going backward through all dependency chains from producer to common consumer, we went forward calling replayCasP. However, that would not work if we had broadcasts instead of reductions. In the current approach only one of the two cases works: right now broadcast works and reduction breaks, as shown here.
common_consumer -> producer (or consumer -> producer if common_consumer does not exist) is safe.
Transform replay is giving me an error:
unknown file: Failure
C++ exception with description "it != replay_CasP.getReplay().end() INTERNAL ASSERT FAILED at "../torch/csrc/jit/codegen/cuda/transform_replay.cpp":471, please report a bug to PyTorch. Could not find axis, iS{128}rf, requested in replay.
Exception raised from replayCasP at ../torch/csrc/jit/codegen/cuda/transform_replay.cpp:471
Repro cpp test:
void testGPU_FusionReductionComputeAtRepro() {
torch::jit::fuser::cuda::CudaKernel prog;
Fusion& fusion = *prog.fusion_;
FusionGuard fg(&fusion);
// Set up your input tensor views
TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
fusion.addOutput(tv1);
TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
tv1->split(1, 128);
auto tv2 = tv1->rFactor({1});
tv0->computeAt(tv1, 2); // this is giving me error!
tv0->axis(0)->parallelize(ParallelType::BIDx);
tv1->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(0)->parallelize(ParallelType::BIDx);
tv0->axis(-1)->parallelize(ParallelType::TIDx);
tv1->axis(-1)->parallelize(ParallelType::TIDx);
tv2->axis(-1)->parallelize(ParallelType::TIDx);
GPULower gpulw(&fusion);
std::stringstream cdg;
gpulw.printKernel(cdg);
std::cout << cdg.str() << std::endl;
}
torch::jit::fuser::cuda::CudaKernel prog;
Fusion& fusion = *prog.fusion_;
FusionGuard fg(&fusion);
// Set up your input tensor views
TensorView* tv0 = makeDummyTensor(2);
auto tv1 = broadcast(tv0, {true, false, false});
auto tv2 = sum(tv1, {1});
std::cout << tv1 << std::endl;
std::cout << tv2 << std::endl;
produces:
T1[ bS{1}, iS{i1}, iS{i3} ]
T2[ iS{1}, rS{i1}, iS{i3} ]
but should produce:
T1[ bS{1}, iS{i1}, iS{i3} ]
T2[ bS{1}, rS{i1}, iS{i3} ]
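The expected output follows a simple rule: a reduced axis becomes a reduction axis, and every other axis keeps its input kind. Under that rule the non-reduced broadcast axis bS{1} stays a broadcast axis instead of turning into iS{1}. Below is a minimal sketch of the rule using a hypothetical `Kind` enum rather than the real IterDomain API.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

enum class Kind { Iteration, Broadcast, Reduction };

// Output axis kinds for a reduction over the axes in `reduced`: reduced axes
// become Reduction, all others keep their input kind (Broadcast stays Broadcast).
std::vector<Kind> reductionOutputKinds(const std::vector<Kind>& input,
                                       const std::set<std::size_t>& reduced) {
  std::vector<Kind> out;
  out.reserve(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    out.push_back(reduced.count(i) ? Kind::Reduction : input[i]);
  }
  return out;
}
```

For T1[ bS{1}, iS{i1}, iS{i3} ] summed over axis 1, this yields [Broadcast, Reduction, Iteration], matching the expected T2.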
Today we don't support a reduction that results in a single scalar value, as the current infrastructure would see this as resulting in a zero-dimensional tensor.
We should first support zero-dimensional tensors: being able to create them and pipe them through the code generator so that they are implicitly recognized as single scalar values and we can generate code for them.
This may include:
It was noted that combining the Fusion IR creation and the CudaKernel object creation is odd, as they are two separate events and a CudaKernel doesn't necessarily have to be generated.
A common pattern in tests is:
torch::jit::fuser::cuda::CudaKernel prog;
prog.setFusionPtr(std::make_unique<Fusion>());
Fusion* fusion = prog.fusion();
FusionGuard fg(fusion);
It would be better to have:
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
<sometime later>
//At kernel Generation time.
torch::jit::fuser::cuda::CudaKernel prog(std::move(fusion));
This requires modifying the CudaKernel object declaration in kernel_cache.h to include an explicit constructor that takes a unique_ptr.
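A minimal sketch of such a constructor, using a stub `Fusion` and a hypothetical `CudaKernelSketch` class rather than the real declaration in kernel_cache.h. Taking the `std::unique_ptr` by value and moving from it makes the transfer of ownership explicit at the call site.

```cpp
#include <cassert>
#include <memory>
#include <utility>

struct Fusion {};  // stub standing in for the real Fusion IR container

// Hypothetical sketch: the kernel object takes ownership of an already-built Fusion.
class CudaKernelSketch {
 public:
  explicit CudaKernelSketch(std::unique_ptr<Fusion> fusion) : fusion_(std::move(fusion)) {}

  Fusion* fusion() const { return fusion_.get(); }

 private:
  std::unique_ptr<Fusion> fusion_;
};
```

Marking the constructor `explicit` prevents a `unique_ptr<Fusion>` from silently converting into a kernel object.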
What is expected behavior when a single Val is computed at multiple sites? Here's a simple example:
TensorView* t0 = makeDummyTensor(2);
fusion.addInput(t0);
auto t1 = unaryOp(UnaryOpType::Exp, t0);
auto t2 = reductionOp(BinaryOpType::Max, {-1}, new Float(0), t1);
auto t3 = reductionOp(BinaryOpType::Min, {-1}, new Float(0), t1);
auto t4 = add(t2, t3);
fusion.addOutput(t4);
t1->computeAt(t2, -1);
At the computeAt call, this is what I got:
C++ exception with description "Tried to access position 1 in domain: [ iS{i1} ]
Exception raised from axis at ../torch/csrc/jit/codegen/cuda/tensor_view.cpp:73 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x68 (0x7f3bb891ed78 in /home/nmaruyama/pytorch/src/csarofeen.2/build/lib/libc10.so)
Is this a user error? If not, what is the expected result?
Some of the functions developed for gridReduce can also be used in blockReduce, which would make it a little clearer.
Un-prioritized list of things that generally should be done:
See testGPU_FusionReduction5 in https://github.com/naoyam/pytorch/tree/blockReduce_fail. It fails at the final validation.
unknown file: Failure
C++ exception with description "Expected aten_output.allclose(cg_output) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
Exception raised from testGPU_FusionReduction5 at ../test/cpp/jit/test_gpu.cpp:2645 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x68 (0x7fd99738ec08 in /home/nmaruyama/pytorch/src/csarofeen/build/lib/libc10.so)
frame #1: torch::jit::testGPU_FusionReduction5() + 0x97a (0x55bb04c93e0a in ./build/bin/test_jit)
frame #2: void testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) + 0x4a (0x55bb04cc277a in ./build/bin/test_jit)
frame #3: <unknown function> + 0x21f4de (0x55bb04cb84de in ./build/bin/test_jit)
frame #4: <unknown function> + 0x21f99d (0x55bb04cb899d in ./build/bin/test_jit)
frame #5: <unknown function> + 0x21fbbd (0x55bb04cb8bbd in ./build/bin/test_jit)
frame #6: testing::internal::UnitTestImpl::RunAllTests() + 0xc59 (0x55bb04cb9a99 in ./build/bin/test_jit)
frame #7: testing::UnitTest::Run() + 0x98 (0x55bb04cb9d98 in ./build/bin/test_jit)
frame #8: main + 0xc8 (0x55bb04b43148 in ./build/bin/test_jit)
frame #9: __libc_start_main + 0xe7 (0x7fd996bd5b97 in /lib/x86_64-linux-gnu/libc.so.6)
frame #10: _start + 0x2a (0x55bb04b4b83a in ./build/bin/test_jit)
" thrown in the test body.
[ FAILED ] JitTest.GPU_FusionReduction5_CUDA (1424 ms)
[----------] 1 test from JitTest (1424 ms total)
The same test does pass when numel_z is larger, as defined in https://github.com/naoyam/pytorch/tree/reduction3d.
Here's the generated code:
__global__ void CUDAGeneratedKernel(Tensor<float, 3> T0, Tensor<float, 1> T1){
T1[ ( blockIdx.x * T1.stride[0] ) ]
= float(0);
float T3[1];
if ( ( ( ( 0 * 8 ) + threadIdx.y ) < T0.size[1] ) ) {
T3[ 0 ]
= float(0);
}
for(size_t i38 = 0; i38 < ( ceilDiv(T0.size[1], 8) ); ++i38 ) {
float T2[1];
if ( ( ( ( ( i38 * 8 ) + threadIdx.y ) < T0.size[1] ) && ( ( ( 0 * 128 ) + threadIdx.x ) < T0.size[2] ) ) ) {
T2[ 0 ]
= float(0);
}
for(size_t i40 = 0; i40 < ( ceilDiv(T0.size[2], 128) ); ++i40 ) {
if ( ( ( ( ( i38 * 8 ) + threadIdx.y ) < T0.size[1] ) && ( ( ( i40 * 128 ) + threadIdx.x ) < T0.size[2] ) ) ) {
T2[ 0 ]
= T2[ 0 ]
+ T0[ ( blockIdx.x * T0.stride[0] ) + ( ( ( i38 * 8 ) + threadIdx.y ) * T0.stride[1] ) + ( ( ( i40 * 128 ) + threadIdx.x ) * T0.stride[2] ) ];
}
}
if ( ( ( ( i38 * 8 ) + threadIdx.y ) < T0.size[1] ) ) {
T3[ 0 ]
= T3[ 0 ]
+ T2[ 0 ];
}
}
blockReduce< true, true, false > ( T1[ ( blockIdx.x * T1.stride[0] ) ], T3[ 0 ], reduction_add_float);
}
While I have not confirmed it, I suspect the cause is that T3 is not initialized for threads whose threadIdx.y is greater than or equal to T0.size[1].
The same initialization guard is generated for simpler tests using 2D tensors, but I don't see the validation error there. Since T3 is just a local variable, it is probably zero most of the time, but it can sometimes hold a garbage value.
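The suspicion can be illustrated with a CPU emulation of the T3 pattern in one block. This is a sketch under stated assumptions, not the real kernel or a confirmed fix. `values` stands for the per-row partial sums T2 that thread y folds into T3, and `block_dim_y` plays the role of blockDim.y (8 in the kernel). The constant 999.0f stands in for whatever an uninitialized register happens to hold, and the final loop-carried sum stands in for blockReduce.

```cpp
#include <cassert>
#include <vector>

// Emulates T3 for every simulated threadIdx.y and sums the results (standing in
// for blockReduce). With guard_init = true, T3 is only zeroed when
// (0 * block_dim_y + ty) < n, as in the generated kernel.
float emulateBlockSum(const std::vector<float>& values, int block_dim_y, bool guard_init) {
  const int n = static_cast<int>(values.size());          // plays T0.size[1]
  const int iters = (n + block_dim_y - 1) / block_dim_y;  // ceilDiv(T0.size[1], 8)
  float block_total = 0.0f;
  for (int ty = 0; ty < block_dim_y; ++ty) {
    float t3 = 999.0f;                     // "uninitialized" accumulator
    if (!guard_init || ty < n) t3 = 0.0f;  // guarded vs. unconditional initialization
    for (int i = 0; i < iters; ++i) {
      if (i * block_dim_y + ty < n) t3 += values[i * block_dim_y + ty];
    }
    block_total += t3;  // every thread participates in the reduction
  }
  return block_total;
}
```

With three rows and eight threads in y, the guarded version folds the five never-initialized accumulators into the result, while initializing T3 unconditionally gives the correct sum.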
The scenario is that the reduction is on the fastest-changing axis: [X, >>Y<<]. Then the following binding is attempted: [bidx{X}, tidx{>>Y<<}]. An assert is thrown that disallows this binding for a non-constant extent. In my case, I am using the ExpressionEvaluator to determine the reduction axis, and based on its value I have determined that it is safe to bind.
Error:
C++ exception with description "Reductions can only be parallelized across dimensions of compile-time known constants.
Exception raised from parallelize at ../torch/csrc/jit/codegen/cuda/ir_internal_nodes.h:295 (most recent call first):
If I comment out the assert at ir_internal_nodes.h:295, the kernel is generated correctly. The lingering question is safety.
void parallelize(ParallelType t) {
parallel_method_ = t;
// Currently a limitation as we allocate shared memory as static (not based
// off a dynamic size.)
/****
if (isReduction())
if (isThreadDim())
TORCH_CHECK(
extent()->isConstScalar(),
"Reductions can only be parallelized across dimensions of compile-time known constants.");
****/
This was done on the 20_6_11_devel branch.
Here is the test code:
void testGPU_FusionRedFailsOnTidxBind() {
int bid_x = 3;
int tid_x = 2;
int red_dim = 1;
torch::jit::fuser::cuda::CudaKernel prog;
Fusion& fusion = *prog.fusion_;
FusionGuard fg(&fusion);
// Set up your input tensor views
TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
TensorView* tv1 = reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), tv0);
fusion.addOutput(tv1);
tv1->axis(-2)->parallelize(ParallelType::BIDx);
tv1->axis(-1)->parallelize(ParallelType::TIDx);
prog.device_ = 0;
prog.grid(bid_x);
prog.block(tid_x);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input = at::rand({bid_x, tid_x}, options);
at::Tensor cg_output = at::empty({bid_x}, options);
torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
GPULower gpulw(&fusion);
gpulw.printKernel(std::cout);
auto aten_output = input.sum({red_dim});
std::cout << aten_output << std::endl;
std::cout << cg_output << std::endl;
TORCH_CHECK(aten_output.allclose(cg_output),
"Error of: ",
aten_output.sub(cg_output).abs().max());
}
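On the safety question, one possible direction is sketched below under an explicit assumption. The only hazard named in the commented-out check is the statically allocated shared memory, and this sketch assumes the block reduction needs one scratch element per thread in that buffer, like the `__shared__ float shared_mem[1024]` seen in the first kernel above. Under that assumption, a thread-bound reduction with a runtime extent could be accepted at compile time and validated at launch, once the block size is known. `checkRuntimeReductionBinding` is hypothetical and not part of the fuser.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical launch-time guard: throws if the launched block needs more
// shared-memory scratch than the statically sized buffer provides.
void checkRuntimeReductionBinding(std::int64_t bdimx, std::int64_t bdimy, std::int64_t bdimz,
                                  std::size_t elem_bytes, std::size_t static_smem_bytes) {
  const std::int64_t threads = bdimx * bdimy * bdimz;
  const std::size_t needed = static_cast<std::size_t>(threads) * elem_bytes;
  if (needed > static_smem_bytes) {
    throw std::runtime_error("block reduction needs " + std::to_string(needed) +
                             " bytes of shared memory, but only " +
                             std::to_string(static_smem_bytes) + " are allocated statically");
  }
}
```

Keying the check on bytes rather than on thread count matters when the reduced type is wider than the element type the buffer was sized for.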
This might not be a bug at all, just my understanding being totally off (this has happened too many times already). But the discussion might still be educational (to me, at least).
I found that computeAt propagates bindings from layers past the nested source back to the input.
This is a simplified description to illustrate the idea:
T0 -> T2 -> T1, where T2 is the intermediate generated from T1->rFactor.
T0->computeAt(T2) and T1->axis(0)->parallelize(ParallelType::BIDx);
The strange behavior here is:
Codegen seems to be able to implicitly propagate the BIDx binding on axis(0) all the way back to T0, even though it is not explicitly specified via computeAt.
However, the BIDx binding on axis(0) is not considered for T2 during allocation, which fails with a complaint about a non-constant allocation.
T2->computeAt(T1), if added, would explicitly propagate the thread binding from T1 back to T2 and T0, and would resolve the issue.
From the observed behavior, I think we probably have an assert that could be relaxed under certain conditions.
void testGPU_FusionReductionJ() {
torch::jit::fuser::cuda::CudaKernel prog;
Fusion& fusion = *prog.fusion_;
FusionGuard fg(&fusion);
// Set up your input tensor views
TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
fusion.addOutput(tv1);
TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
tv1->split(1, 128);
auto tv2 = tv1->rFactor({1});
std::cout << "fusion 1: \n" << fusion << std::endl;
tv0->computeAt(tv2, 2);
std::cout << "fusion 2: \n" << fusion << std::endl;
tv1->axis(0)->parallelize(ParallelType::BIDx);
tv1->axis(-1)->parallelize(ParallelType::TIDx);
tv0->axis(-1)->parallelize(ParallelType::TIDx);
tv2->axis(-1)->parallelize(ParallelType::TIDx);
// If we can figure out the thread binding for tv0 in the generated kernel without specifying it here, we should be able to do the same thing for the allocation of tv2 as well.
//tv0->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(0)->parallelize(ParallelType::BIDx);
std::cout << "fusion 3: \n" << fusion << std::endl;
GPULower gpulw(&fusion);
std::stringstream cdg;
gpulw.printKernel(cdg);
std::cout << cdg.str() << std::endl;
}
I am doing a reduction on an axis that is not the fastest-changing one: [>>X<<, Y]. I am trying to split Y into two parts, [>>X<<, Ya, Yb], binding blocks to Ya and threads to Yb. When I do this, the first output is calculated correctly but all the rest are zero. The kernel looks incorrect in that it has a thread predicate that uses T1.size[1] even though T1 is only 1D:
Tensor<float, 1> T1
if ( ( ( ( blockIdx.x * 2 ) + threadIdx.x ) < T1.size[1] ) ) {
__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 1> T1){
if ( ( ( ( blockIdx.x * 2 ) + threadIdx.x ) < T1.size[1] ) ) {
T1[ ( ( ( blockIdx.x * 2 ) + threadIdx.x ) * T1.stride[0] ) ]
= float(0);
}
for(size_t i48 = 0; i48 < T0.size[0]; ++i48 ) {
if ( ( ( ( blockIdx.x * 2 ) + threadIdx.x ) < T1.size[1] ) ) {
T1[ ( ( ( blockIdx.x * 2 ) + threadIdx.x ) * T1.stride[0] ) ]
= T1[ ( ( ( blockIdx.x * 2 ) + threadIdx.x ) * T1.stride[0] ) ]
+ T0[ ( i48 * T0.stride[0] ) + ( ( ( blockIdx.x * 2 ) + threadIdx.x ) * T0.stride[1] ) ];
}
}
}
For an input of size [16, 6], i.e. [16, 3, 2] after the split, this yields 6 outputs.
ATEN Output:
5.5575
9.6090
7.6456
9.6341
9.0592
7.0243
Codegen Output:
5.5575
0.0000
0.0000
0.0000
0.0000
0.0000
I am using the 20_6_11_devel branch.
Here is the test code:
void testGPU_FusionNonRedAxisBind() {
int bid_x = 3;
int tid_x = 2;
int red_dim = 0;
torch::jit::fuser::cuda::CudaKernel prog;
Fusion& fusion = *prog.fusion_;
FusionGuard fg(&fusion);
// Set up your input tensor views
TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
TensorView* tv1 = reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), tv0);
fusion.addOutput(tv1);
tv1->split(-1, tid_x);
tv1->axis(-2)->parallelize(ParallelType::BIDx);
tv1->axis(-1)->parallelize(ParallelType::TIDx);
prog.device_ = 0;
prog.grid(bid_x);
prog.block(tid_x);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input = at::rand({16, bid_x * tid_x}, options);
at::Tensor cg_output = at::empty({bid_x * tid_x}, options);
torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
GPULower gpulw(&fusion);
gpulw.printKernel(std::cout);
auto aten_output = input.sum({red_dim});
std::cout << aten_output << std::endl;
std::cout << cg_output << std::endl;
TORCH_CHECK(aten_output.allclose(cg_output),
"Error of: ",
aten_output.sub(cg_output).abs().max());
}
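For reference, here is a CPU emulation of what the NonRedAxisBind kernel is meant to compute. It is a sketch, not codegen output: columnSum is a hypothetical helper that walks the same grid (bid_x blocks of tid_x threads), flattens blockIdx.x * tid_x + threadIdx.x into an output column, and guards it with a predicate on the output's only dimension (what T1.size[0] would be) instead of T1.size[1].

```cpp
#include <cstddef>
#include <vector>

// CPU emulation of the intended kernel: out[j] = sum over i of in[i][j], where
// j = blockIdx.x * tid_x + threadIdx.x. The predicate compares j against the
// size of the 1D output's only dimension (what T1.size[0] would be).
std::vector<float> columnSum(const std::vector<std::vector<float>>& in,
                             int bid_x, int tid_x) {
  std::vector<float> out(in.empty() ? 0 : in[0].size(), 0.0f);
  for (int block = 0; block < bid_x; ++block) {       // blockIdx.x
    for (int thread = 0; thread < tid_x; ++thread) {  // threadIdx.x
      std::size_t j = static_cast<std::size_t>(block) * tid_x + thread;
      if (j < out.size()) {                           // predicate on T1.size[0]
        for (const auto& row : in) out[j] += row[j];  // serial reduction over X
      }
    }
  }
  return out;
}
```

With a [16, 6] input of ones and a 3 x 2 grid, every one of the 6 outputs is 16; an oversized grid (for example 4 x 2) is trimmed by the predicate rather than writing out of bounds.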
On my Ubuntu 18.04 / CUDA 10.2 / GTX 1070 setup, I see the following test failure:
[ RUN ] JitTest.GPU_FusionReduction3_CUDA
unknown file: Failure
C++ exception with description "Expected t5.allclose(cg_output) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
Exception raised from testGPU_FusionReduction3 at ../test/cpp/jit/test_gpu.cpp:2524 (most recent call first):
frame #0: <unknown function> + 0xe99a7 (0x7f19417529a7 in /home/lemo/work/pytorch/build/lib/libc10.so)
frame #1: std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>::operator()() const + 0x4c (0x55cf8b604b78 in bin/test_jit)
frame #2: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x40 (0x7f1941751c3c in /home/lemo/work/pytorch/build/lib/libc10.so)
frame #3: torch::jit::testGPU_FusionReduction3() + 0xb57 (0x55cf8b5e7062 in bin/test_jit)
frame #4: torch::jit::JitTest_GPU_FusionReduction3_CUDA_Test::TestBody() + 0x11 (0x55cf8b42d10f in bin/test_jit)
frame #5: void testing::internal::HandleSehExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) + 0x65 (0x55cf8b63f175 in bin/test_jit)
frame #6: void testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) + 0x5a (0x55cf8b639945 in bin/test_jit)
frame #7: testing::Test::Run() + 0xd2 (0x55cf8b618b0c in bin/test_jit)
frame #8: testing::TestInfo::Run() + 0xf3 (0x55cf8b619463 in bin/test_jit)
frame #9: testing::TestCase::Run() + 0x104 (0x55cf8b619ae4 in bin/test_jit)
frame #10: testing::internal::UnitTestImpl::RunAllTests() + 0x2a6 (0x55cf8b6248ce in bin/test_jit)
frame #11: bool testing::internal::HandleSehExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) + 0x65 (0x55cf8b64023c in bin/test_jit)
frame #12: bool testing::internal::HandleExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) + 0x5a (0x55cf8b63a6a5 in bin/test_jit)
frame #13: testing::UnitTest::Run() + 0xc0 (0x55cf8b623352 in bin/test_jit)
frame #14: <unknown function> + 0x1be738 (0x55cf8b42c738 in bin/test_jit)
frame #15: main + 0x18a (0x55cf8b42c655 in bin/test_jit)
frame #16: __libc_start_main + 0xe7 (0x7f1940efbb97 in /lib/x86_64-linux-gnu/libc.so.6)
frame #17: _start + 0x2a (0x55cf8b42c34a in bin/test_jit)
" thrown in the test body.
[ FAILED ] JitTest.GPU_FusionReduction3_CUDA (303 ms)
void testGPU_FusionReductionSchedulerUnrollRepro() {
torch::jit::fuser::cuda::CudaKernel prog;
prog.setFusionPtr(std::make_unique<Fusion>());
Fusion* fusion = prog.fusion();
FusionGuard fg(fusion);
bool fp16 = true;
int grid_dim_x_ = 40;
int grid_dim_y_ = 1;
int block_dim_x_ = 32;
int block_dim_y_ = 16;
int red_dim = 1;
int dim0 = 640;
int dim1 = 1024;
// Set up your input tensor views
TensorView* tv0 = makeDummyTensor(2, (fp16 ? DataType::Half : DataType::Float));
fusion->addInput(tv0);
torch::jit::fuser::Val* tv0_cast = nullptr;
if (fp16) {
tv0_cast = castOp(DataType::Float, tv0);
}
TensorView* tv1 = reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), (fp16 ? tv0_cast->as<TensorView>() : tv0));
TensorView* tv1_cast = nullptr;
if (fp16) {
tv1_cast = castOp(DataType::Half, tv1);
}
fusion->addOutput((fp16 ? tv1_cast : tv1));
// --------- Magic Scheduler Start --------
tv1->split(1, block_dim_x_);
// Unroll a certain number of rFactored elements
tv1->split(1, 4);
// Split Grid dimension to get multiple reds per block
tv1->split(0, block_dim_y_);
auto tv1_rf = tv1->rFactor({-3,-2});
tv1_rf->computeAt(tv1, 1);
tv1->axis(0)->parallelize(ParallelType::BIDx);
tv1->axis(1)->parallelize(ParallelType::TIDy);
tv1->axis(-1)->parallelize(ParallelType::TIDx);
tv1_rf->axis(1)->parallelize(ParallelType::TIDy);
tv1_rf->axis(-1)->parallelize(ParallelType::TIDx);
tv1_rf->axis(-2)->parallelize(ParallelType::Unroll);
// --------- Magic Scheduler End --------
if(fp16) {
// Find the rFactor generated Tensor
for(auto &use : fusion->unordered_uses(tv0_cast)) {
tv0_cast->as<TensorView>()->computeAt(use->output(0)->as<TensorView>(), -3);
}
// Unrolling binding for cast
tv0_cast->as<TensorView>()->axis(-2)->parallelize(ParallelType::Unroll);
tv0_cast->as<TensorView>()->axis(-1)->parallelize(ParallelType::TIDx);
// Multiple reductions per block are performed
tv1_cast->split(0, block_dim_y_);
tv1_cast->axis(-1)->parallelize(ParallelType::TIDy);
tv1_cast->axis(0)->parallelize(ParallelType::BIDx);
}
fusion->setLaunchConfig(
LaunchConfigType::TIDx, new Int(block_dim_x_));
fusion->setLaunchConfig(
LaunchConfigType::TIDy, new Int(block_dim_y_));
fusion->setLaunchConfig(LaunchConfigType::TIDz, new Int(1));
fusion->setLaunchConfig(LaunchConfigType::BIDx, new Int(grid_dim_x_));
fusion->setLaunchConfig(LaunchConfigType::BIDy, new Int(grid_dim_y_));
fusion->setLaunchConfig(LaunchConfigType::BIDz, new Int(1));
fusion->setLaunchConfig(LaunchConfigType::SharedMemory, new Int(0));
fusion->setLaunchConfig(LaunchConfigType::Compatible, new Int(1));
fusion->printMath();
GPULower gpulw(fusion);
gpulw.printKernel(std::cout);
std::cout << std::flush << std::endl;
prog.setDevice(0);
torch::jit::fuser::cuda::compileKernel(&prog);
auto options = at::TensorOptions().dtype((fp16 ? at::kHalf : at::kFloat)).device(at::kCUDA, 0);
at::Tensor input = at::rand({dim0, dim1}, options);
at::Tensor cg_output = at::empty({(red_dim == 0 ? dim1 : dim0) }, options);
torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}, c10::nullopt);
at::Tensor aten_output = input.sum({red_dim});
TORCH_CHECK(
aten_output.allclose(cg_output),
"Error of: ",
aten_output.sub(cg_output).abs().max());
}
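The extents produced by the three splits in this schedule can be checked by hand. The sketch assumes split(axis, factor) replaces an axis of extent n with [ceilDiv(n, factor), factor], which is consistent with the IR printed for T1 in the generated output; scheduledExtents is a hypothetical helper. For dim0 = 640 and dim1 = 1024 it gives [40, 16, 8, 4, 32], so the outer axis bound to BIDx has exactly grid_dim_x_ = 40 iterations.

```cpp
#include <cstdint>
#include <vector>

constexpr std::int64_t ceilDiv(std::int64_t a, std::int64_t b) {
  return (a + b - 1) / b;
}

// Extents of tv1's domain after the three splits in the repro, assuming
// split(axis, factor) turns extent n into [ceilDiv(n, factor), factor].
std::vector<std::int64_t> scheduledExtents(std::int64_t dim0, std::int64_t dim1,
                                           std::int64_t block_x,
                                           std::int64_t unroll,
                                           std::int64_t block_y) {
  std::vector<std::int64_t> d{dim0, dim1};
  // tv1->split(1, block_dim_x_): [dim0, ceilDiv(dim1, 32), 32]
  d = {d[0], ceilDiv(d[1], block_x), block_x};
  // tv1->split(1, 4): [dim0, ceilDiv(ceilDiv(dim1, 32), 4), 4, 32]
  d = {d[0], ceilDiv(d[1], unroll), unroll, d[2]};
  // tv1->split(0, block_dim_y_): [ceilDiv(dim0, 16), 16, ..., 4, 32]
  d = {ceilDiv(d[0], block_y), block_y, d[1], d[2], d[3]};
  return d;
}
```

The product 40 * 16 * 8 * 4 * 32 equals 640 * 1024, so for these sizes the splits divide evenly and no iteration is wasted on padding.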
Generated Output:
T1[ iS{( ceilDiv(i1, 16) )}, iS{16}, iS{( ceilDiv(( ceilDiv(i3, 32) ), 4) )}, iU{4}, ithreadIdx.x{32} ] compute_at( T4, 3 )
= __half2float(T0[ iS{i1}, iS{i3} ]);
T4[ iS{( ceilDiv(i1, 16) )}, ithreadIdx.y{16}, rS{( ceilDiv(( ceilDiv(i3, 32) ), 4) )}rf, rU{4}rf, ithreadIdx.x{32}rf ] compute_at( T2, 1 ) = reduction( T1[ iS{( ceilDiv(i1, 16) )}, iS{16}, iS{( ceilDiv(( ceilDiv(i3, 32) ), 4) )}, iU{4}, ithreadIdx.x{32} ] compute_at( T4, 3 ), op = add, initial value = float(0) )
T2[ iblockIdx.x{gridDim.x}, ithreadIdx.y{16}, rthreadIdx.x{32} ] = reduction( T4[ iS{( ceilDiv(i1, 16) )}, ithreadIdx.y{16}, rS{( ceilDiv(( ceilDiv(i3, 32) ), 4) )}rf, rU{4}rf, ithreadIdx.x{32}rf ] compute_at( T2, 1 ), op = add, initial value = float(0) )
T3[ iblockIdx.x{gridDim.x}, ithreadIdx.y{16} ]
= __float2half(T2[ iblockIdx.x{gridDim.x}, ithreadIdx.y{16}, rthreadIdx.x{32} ]);
__device__ void reduction_add_float(float& a, const float b) {
a = a + b;
}
__global__ void CUDAGeneratedKernel(Tensor<__half, 2> T0, Tensor<__half, 1> T3){
__shared__ float shared_mem[1024];
float T2[1];
if ( ( ( ( blockIdx.x * 16 ) + threadIdx.y ) < T3.size[0] ) ) {
T2[ 0 ]
= float(0);
}
float T4[1];
if ( ( ( ( ( blockIdx.x * 16 ) + threadIdx.y ) < T3.size[0] ) && ( ( ( ( ( 0 * 4 ) + 0 ) * 32 ) + threadIdx.x ) < T0.size[1] ) ) ) {
T4[ 0 ]
= float(0);
}
for(size_t i55 = 0; i55 < ( ceilDiv(( ceilDiv(T0.size[1], 32) ), 4) ); ++i55 ) {
float T1[4];
if ( ( ( ( ( blockIdx.x * 16 ) + threadIdx.y ) < T3.size[0] ) && ( ( ( ( ( i55 * 4 ) + ( 4 - 1 ) ) * 32 ) + threadIdx.x ) < T0.size[1] ) ) ) {
for(size_t i56 = 0; i56 < 4; ++i56 ) {
T1[ i56 ]
= __half2float(T0[ ( ( ( blockIdx.x * 16 ) + threadIdx.y ) * T0.stride[0] ) + ( ( ( ( ( i55 * 4 ) + i56 ) * 32 ) + threadIdx.x ) * T0.stride[1] ) ]);
}
} else {
for(size_t i56 = 0; i56 < 4; ++i56 ) {
if ( ( ( ( ( blockIdx.x * 16 ) + threadIdx.y ) < T3.size[0] ) && ( ( ( ( ( i55 * 4 ) + i56 ) * 32 ) + threadIdx.x ) < T0.size[1] ) ) ) {
T1[ i56 ]
= __half2float(T0[ ( ( ( blockIdx.x * 16 ) + threadIdx.y ) * T0.stride[0] ) + ( ( ( ( ( i55 * 4 ) + i56 ) * 32 ) + threadIdx.x ) * T0.stride[1] ) ]);
}
}
}
for(size_t i58 = 0; i58 < 4; ++i58 ) {
if ( ( ( ( ( blockIdx.x * 16 ) + threadIdx.y ) < T3.size[0] ) && ( ( ( ( ( i55 * 4 ) + i58 ) * 32 ) + threadIdx.x ) < T0.size[1] ) ) ) {
T4[ 0 ]
= T4[ 0 ]
+ T1[ i58 ];
}
}
}
if ( ( ( ( blockIdx.x * 16 ) + threadIdx.y ) < T3.size[0] ) ) {
blockReduce< true, false, false > ( T2[ 0 ], T4[ 0 ], reduction_add_float, threadIdx, blockDim, reinterpret_cast<float*>(shared_mem));
}
if ( ( ( ( ( blockIdx.x * 16 ) + threadIdx.y ) < T3.size[0] ) && ( threadIdx.x == 0 ) ) ) {
T3[ ( ( ( blockIdx.x * 16 ) + threadIdx.y ) * T3.stride[0] ) ]
= __float2half(T2[ 0 ]);
}
}
I was curious what the unrolled loop looks like in Godbolt, so I took some performance liberties (dropping the predicates and adding #pragma unroll):
for(size_t i55 = 0; i55 < ( ceilDiv(( ceilDiv(T0.size[1], 32) ), 4) ); ++i55 ) {
float T1[4];
#pragma unroll
for(size_t i56 = 0; i56 < 4; ++i56 ) {
T1[ i56 ]
= __half2float(T0[ ( ( ( blockIdx.x * 16 ) + threadIdx.y ) * T0.stride[0] ) + ( ( ( ( ( i55 * 4 ) + i56 ) * 32 ) + threadIdx.x ) * T0.stride[1] ) ]);
}
#pragma unroll
for(size_t i58 = 0; i58 < 4; ++i58 ) {
T4[ 0 ]
= T4[ 0 ]
+ T1[ i58 ];
}
}
add.s64 %rd72, %rd71, %rd14;
ld.global.u16 %rs17, [%rd72];
{ cvt.f32.f16 %f41, %rs17;}
add.s64 %rd73, %rd72, %rd14;
ld.global.u16 %rs18, [%rd73];
{ cvt.f32.f16 %f42, %rs18;}
add.s64 %rd74, %rd73, %rd14;
ld.global.u16 %rs19, [%rd74];
{ cvt.f32.f16 %f43, %rs19;}
add.s64 %rd75, %rd74, %rd14;
ld.global.u16 %rs20, [%rd75];
{ cvt.f32.f16 %f44, %rs20;}
add.f32 %f57, %f56, %f41;
add.f32 %f58, %f57, %f42;
add.f32 %f59, %f58, %f43;
add.f32 %f60, %f59, %f44;
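The predicate structure of the generated kernel (one bounds check on the last element of each unrolled group, with a per-element fallback) can be emulated on the host. This is a hedged sketch: threadPartialSum, kUnroll, and kBlockX are illustrative names, and the unrolled buffer is zero-initialized so the accumulation loop does not need its own predicate, which differs slightly from the generated code. The fast-path branch corresponds to the unpredicated loads and adds in the PTX above.

```cpp
#include <cstddef>
#include <vector>

constexpr std::size_t kUnroll = 4;   // unroll factor (rU{4})
constexpr std::size_t kBlockX = 32;  // threads in x (threadIdx.x{32})

// One thread's share of the rFactor loop: element k of unrolled group g reads
// column ((g * kUnroll) + k) * kBlockX + tid.
float threadPartialSum(const std::vector<float>& row, std::size_t tid) {
  float acc = 0.0f;  // plays the role of T4[0]
  const std::size_t n = row.size();
  const std::size_t groups =
      ((n + kBlockX - 1) / kBlockX + kUnroll - 1) / kUnroll;
  for (std::size_t g = 0; g < groups; ++g) {
    float buf[kUnroll] = {};  // plays the role of T1[4]; zeros are harmless
    if (((g * kUnroll) + (kUnroll - 1)) * kBlockX + tid < n) {
      // Fast path: the last element is in bounds, so all of them are.
      for (std::size_t k = 0; k < kUnroll; ++k)
        buf[k] = row[((g * kUnroll) + k) * kBlockX + tid];
    } else {
      // Tail: predicate each element individually.
      for (std::size_t k = 0; k < kUnroll; ++k) {
        std::size_t col = ((g * kUnroll) + k) * kBlockX + tid;
        if (col < n) buf[k] = row[col];
      }
    }
    for (std::size_t k = 0; k < kUnroll; ++k) acc += buf[k];
  }
  return acc;
}
```

Summing threadPartialSum over all 32 values of tid covers every column exactly once, which is what the subsequent blockReduce relies on.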
The logic in Expr::sameAs() is known to be incomplete ("binaryOp, unaryOp, and ternaryOp won't evaluate correctly here")
The scenario is as follows:
TV2[X, >>Y<<] = reductionOp
TV3[X, <<Y>>] = broadcastOp(TV2, {false, true})
TV4[X, Yout, Yin] = add(TV1, TV3), where Yout and Yin are generated by a split.
When I use TV3->computeAt(TV4, -1), this works! If I do TV3->split(-1, 32) and bind blocks and threads to TV3 appropriately, I get a failure because the intermediate TV3 attempts to create a dynamically sized intermediary, even though the broadcast only needs 1 element.
Bad Schedule:
void testGPU_FusionBad() {
torch::jit::fuser::cuda::CudaKernel prog;
Fusion& fusion = *prog.fusion_;
FusionGuard fg(&fusion);
// Set up your input tensor views
TensorView* input_tv0 = makeDummyTensor(3);
TensorView* input_tv1 = makeDummyTensor(3);
fusion.addInput(input_tv0);
fusion.addInput(input_tv1);
TensorView* sum_tv2 =
reductionOp(BinaryOpType::Add, {2}, new Float(0), input_tv0);
TensorView* bcast_tv3 = broadcast(sum_tv2, {false, false, true});
TensorView* output_tv4 = div(input_tv1, bcast_tv3);
sum_tv2->split(-1, 32);
TensorView* sum_rf_tv5 = sum_tv2->rFactor({-2});
bcast_tv3->split(-1, 32);
output_tv4->split(-1, 32);
sum_rf_tv5->axis(0)->parallelize(ParallelType::BIDx);
sum_tv2->axis(0)->parallelize(ParallelType::BIDx);
bcast_tv3->axis(0)->parallelize(ParallelType::BIDx);
output_tv4->axis(0)->parallelize(ParallelType::BIDx);
sum_rf_tv5->axis(1)->parallelize(ParallelType::BIDy);
sum_tv2->axis(1)->parallelize(ParallelType::BIDy);
bcast_tv3->axis(1)->parallelize(ParallelType::BIDy);
output_tv4->axis(1)->parallelize(ParallelType::BIDy);
sum_rf_tv5->axis(-1)->parallelize(ParallelType::TIDx);
sum_tv2->axis(-1)->parallelize(ParallelType::TIDx);
bcast_tv3->axis(-1)->parallelize(ParallelType::TIDx);
output_tv4->axis(-1)->parallelize(ParallelType::TIDx);
fusion.addOutput(output_tv4);
fusion.printMath();
prog.device_ = 0;
prog.grid(32, 32);
prog.block(32);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({32, 32, 128}, options);
at::Tensor t1 = at::randn({32, 32, 128}, options);
at::Tensor cg_output = at::empty({32, 32, 128}, options);
torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runTestKernel(&prog, {t0,t1}, {cg_output});
}
Error:
CUDA NVRTC compile error: default_program(505): error: function call must have a constant value in a constant expression
Algo Exprs:
T5[ iblockIdx.x{gridDim.x}, iblockIdx.y{gridDim.y}, rS{( ceilDiv(i5, 32) )}rf, ithreadIdx.x{32}rf ] = reduction( T0[ iS{i1}, iS{i3}, iS{i5} ], op = add, initial value = float(0) )
T2[ iblockIdx.x{gridDim.x}, iblockIdx.y{gridDim.y}, rthreadIdx.x{32} ] = reduction( T5[ iblockIdx.x{gridDim.x}, iblockIdx.y{gridDim.y}, rS{( ceilDiv(i5, 32) )}rf, ithreadIdx.x{32}rf ], op = add, initial value = float(0) )
T3[ iblockIdx.x{gridDim.x}, iblockIdx.y{gridDim.y}, bS{( ceilDiv(1, 32) )}, bthreadIdx.x{32} ]
= T2[ iblockIdx.x{gridDim.x}, iblockIdx.y{gridDim.y}, rthreadIdx.x{32} ];
T4[ iblockIdx.x{gridDim.x}, iblockIdx.y{gridDim.y}, iS{( ceilDiv(i11, 32) )}, ithreadIdx.x{32} ]
= T1[ iS{i7}, iS{i9}, iS{i11} ]
/ T3[ iblockIdx.x{gridDim.x}, iblockIdx.y{gridDim.y}, bS{( ceilDiv(1, 32) )}, bthreadIdx.x{32} ];
Kernel:
__global__ void kernel(Tensor<float, 3> T0, Tensor<float, 3> T1, Tensor<float, 3> T4){
__shared__ float shared_mem[1024];
float T3[( ceilDiv(1, 32) )];
float T2[1];
T2[ 0 ]
= float(0);
float T5[1];
if ( ( ( ( 0 * 32 ) + threadIdx.x ) < T0.size[2] ) ) {
T5[ 0 ]
= float(0);
}
for(size_t i48 = 0; i48 < ( ceilDiv(T0.size[2], 32) ); ++i48 ) {
if ( ( ( ( i48 * 32 ) + threadIdx.x ) < T0.size[2] ) ) {
T5[ 0 ]
= T5[ 0 ]
+ T0[ ( blockIdx.x * T0.stride[0] ) + ( blockIdx.y * T0.stride[1] ) + ( ( ( i48 * 32 ) + threadIdx.x ) * T0.stride[2] ) ];
}
}
blockReduce< true, false, false > ( T2[ 0 ], T5[ 0 ], reduction_add_float, threadIdx, blockDim, reinterpret_cast<float*>(shared_mem));
if ( ( threadIdx.x == 0 ) ) {
T3[ 0 ]
= T2[ 0 ];
}
for(size_t i51 = 0; i51 < ( ceilDiv(T4.size[2], 32) ); ++i51 ) {
if ( ( ( ( i51 * 32 ) + threadIdx.x ) < T4.size[2] ) ) {
T4[ ( blockIdx.x * T4.stride[0] ) + ( blockIdx.y * T4.stride[1] ) + ( ( ( i51 * 32 ) + threadIdx.x ) * T4.stride[2] ) ]
= T1[ ( blockIdx.x * T1.stride[0] ) + ( blockIdx.y * T1.stride[1] ) + ( ( ( i51 * 32 ) + threadIdx.x ) * T1.stride[2] ) ]
/ T3[ 0 ];
}
}
}
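The NVRTC error comes from float T3[( ceilDiv(1, 32) )]: an array bound must be a constant expression, and a call to a non-constexpr ceilDiv is not one. Two fixes seem possible: fold a split broadcast axis back to extent 1 during lowering, or make ceilDiv constexpr. A minimal plain-C++ sketch of the latter follows; it assumes ceilDiv is the usual (a + b - 1) / b, and whether NVRTC accepts it in device code may also depend on how the runtime header declares it (for example as __host__ __device__), which is an assumption here.

```cpp
#include <cstdint>

// A constexpr ceilDiv makes ceilDiv(1, 32) a constant expression, so it can
// size a local array; a non-constexpr call cannot, which is what NVRTC reports.
constexpr std::int64_t ceilDiv(std::int64_t a, std::int64_t b) {
  return (a + b - 1) / b;
}

static_assert(ceilDiv(1, 32) == 1,
              "a split broadcast axis still needs exactly one element");

float broadcastDivide(float numerator, float reduced) {
  float T3[ceilDiv(1, 32)];  // legal: the bound is a compile-time constant
  T3[0] = reduced;           // the broadcast value, written once
  return numerator / T3[0];
}
```

Either way the allocation ends up as a single element, which matches what the good schedule gets by inlining T3 with computeAt.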
Good schedule for comparison:
void testGPU_FusionGood() {
torch::jit::fuser::cuda::CudaKernel prog;
Fusion& fusion = *prog.fusion_;
FusionGuard fg(&fusion);
// Set up your input tensor views
TensorView* input_tv0 = makeDummyTensor(3);
TensorView* input_tv1 = makeDummyTensor(3);
fusion.addInput(input_tv0);
fusion.addInput(input_tv1);
TensorView* sum_tv2 =
reductionOp(BinaryOpType::Add, {2}, new Float(0), input_tv0);
TensorView* bcast_tv3 = broadcast(sum_tv2, {false, false, true});
TensorView* output_tv4 = div(input_tv1, bcast_tv3);
sum_tv2->split(-1, 32);
TensorView* sum_rf_tv5 = sum_tv2->rFactor({-2});
output_tv4->split(-1, 32);
bcast_tv3->computeAt(output_tv4, {-1});
sum_rf_tv5->axis(0)->parallelize(ParallelType::BIDx);
sum_tv2->axis(0)->parallelize(ParallelType::BIDx);
output_tv4->axis(0)->parallelize(ParallelType::BIDx);
sum_rf_tv5->axis(1)->parallelize(ParallelType::BIDy);
sum_tv2->axis(1)->parallelize(ParallelType::BIDy);
output_tv4->axis(1)->parallelize(ParallelType::BIDy);
sum_rf_tv5->axis(-1)->parallelize(ParallelType::TIDx);
sum_tv2->axis(-1)->parallelize(ParallelType::TIDx);
output_tv4->axis(-1)->parallelize(ParallelType::TIDx);
fusion.addOutput(output_tv4);
fusion.printMath();
prog.device_ = 0;
prog.grid(32, 32);
prog.block(32);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({32, 32, 128}, options);
at::Tensor t1 = at::randn({32, 32, 128}, options);
at::Tensor cg_output = at::empty({32, 32, 128}, options);
torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runTestKernel(&prog, {t0,t1}, {cg_output});
}
A type assert is mysteriously failing when I have a fusion with a castOp. I briefly scanned through the code in arith.cpp, and it looks like the type is properly set.
I have the repro here: https://github.com/csarofeen/pytorch/tree/castOp_repro
(basically just trying to print a fusion with a castOp).
Build and run with ./test_jit --gtest_filter="*GPU_FusionCast*"
C++ exception with description "dtype_ != DataType::Null INTERNAL ASSERT FAILED at "../torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp":134, please report a bug to PyTorch. Value does not have a data type.
Exception raised from getDataType at ../torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp:134 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6b (0x7fa9c29ef99b in /volume/codegen_project/pytorch_codegen/build/lib/libc10.so)
frame #1: torch::jit::fuser::Val::getDataType() const + 0x3d4 (0x7fa9c5b1de44 in /volume/codegen_project/pytorch_codegen/build/lib/libtorch_cuda.so)
frame #2: torch::jit::fuser::IRPrinter::handle(torch::jit::fuser::UnaryOp const*) + 0x1ad (0x7fa9c5b3d4dd in /volume/codegen_project/pytorch_codegen/build/lib/libtorch_cuda.so)
frame #3: void torch::jit::fuser::Expr::constDispatch<torch::jit::fuser::OptInConstDispatch*>(torch::jit::fuser::OptInConstDispatch*, torch::jit::fuser::Expr const*) + 0xb0 (0x7fa9c5ae76c0 in /volume/codegen_project/pytorch_codegen/build/lib/libtorch_cuda.so)
frame #4: torch::jit::fuser::IRPrinter::handle(torch::jit::fuser::Fusion*) + 0x55 (0x7fa9c5b3a745 in /volume/codegen_project/pytorch_codegen/build/lib/libtorch_cuda.so)
frame #5: torch::jit::fuser::operator<<(std::ostream&, torch::jit::fuser::Fusion*) + 0x6a (0x7fa9c5b3caaa in /volume/codegen_project/pytorch_codegen/build/lib/libtorch_cuda.so)
frame #6: torch::jit::testGPU_FusionCastOps() + 0x8f (0x5605261e181f in ./test_jit)
frame #7: void testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) + 0x4a (0x56052623d05a in ./test_jit)
frame #8: <unknown function> + 0x224fc5 (0x560526232fc5 in ./test_jit)
frame #9: <unknown function> + 0x2255d5 (0x5605262335d5 in ./test_jit)
frame #10: <unknown function> + 0x225885 (0x560526233885 in ./test_jit)
frame #11: testing::internal::UnitTestImpl::RunAllTests() + 0xc1c (0x5605262348dc in ./test_jit)
frame #12: testing::UnitTest::Run() + 0x98 (0x560526234b98 in ./test_jit)
frame #13: main + 0xc8 (0x560526081528 in ./test_jit)
frame #14: __libc_start_main + 0xe7 (0x7fa9c1c91b97 in /lib/x86_64-linux-gnu/libc.so.6)
frame #15: _start + 0x2a (0x56052608d30a in ./test_jit)
" thrown in the test body.
When an operation such as an Add is fused into the rFactor loop of a reduction that is also unrolled, the tensor produced by that operation is not allocated with the size of the unroll: in the kernel below, T1 is declared as float T1[1] but written as T1[i37] for i37 in [0, 4).
I suspect that when computeAt() is applied to the fused op, it is not picking up the Unroll.
I will note that I also show a similar situation where the Unroll is properly applied between two operations when the second operation is not a reduction.
This is the important part:
for(size_t i36 = 0; i36 < ( ceilDiv(( ceilDiv(T0.size[1], 32) ), 4) ); ++i36 ) {
float T1[1];
if ( ( ( ( ( ( i36 * 4 ) + ( 4 - 1 ) ) * 32 ) + threadIdx.x ) < T0.size[1] ) ) {
for(size_t i37 = 0; i37 < 4; ++i37 ) {
T1[ i37 ]
= T0[ ( blockIdx.x * T0.stride[0] ) + ( ( ( ( ( i36 * 4 ) + i37 ) * 32 ) + threadIdx.x ) * T0.stride[1] ) ]
+ float(0);
T3[ 0 ]
= T3[ 0 ]
+ T1[ i37 ];
}
Full kernel generated:
__global__ void kernel(Tensor<float, 2> T0, Tensor<float, 1> T2){
__shared__ float shared_mem[1024];
T2[ ( blockIdx.x * T2.stride[0] ) ]
= float(0);
float T3[1];
if ( ( ( ( ( ( 0 * 4 ) + 0 ) * 32 ) + threadIdx.x ) < T0.size[1] ) ) {
T3[ 0 ]
= float(0);
}
for(size_t i36 = 0; i36 < ( ceilDiv(( ceilDiv(T0.size[1], 32) ), 4) ); ++i36 ) {
float T1[1];
if ( ( ( ( ( ( i36 * 4 ) + ( 4 - 1 ) ) * 32 ) + threadIdx.x ) < T0.size[1] ) ) {
for(size_t i37 = 0; i37 < 4; ++i37 ) {
T1[ i37 ]
= T0[ ( blockIdx.x * T0.stride[0] ) + ( ( ( ( ( i36 * 4 ) + i37 ) * 32 ) + threadIdx.x ) * T0.stride[1] ) ]
+ float(0);
T3[ 0 ]
= T3[ 0 ]
+ T1[ i37 ];
}
} else {
for(size_t i37 = 0; i37 < 4; ++i37 ) {
if ( ( ( ( ( ( i36 * 4 ) + i37 ) * 32 ) + threadIdx.x ) < T0.size[1] ) ) {
T1[ i37 ]
= T0[ ( blockIdx.x * T0.stride[0] ) + ( ( ( ( ( i36 * 4 ) + i37 ) * 32 ) + threadIdx.x ) * T0.stride[1] ) ]
+ float(0);
}
if ( ( ( ( ( ( i36 * 4 ) + i37 ) * 32 ) + threadIdx.x ) < T0.size[1] ) ) {
T3[ 0 ]
= T3[ 0 ]
+ T1[ i37 ];
}
}
}
}
blockReduce< true, false, false > ( T2[ ( blockIdx.x * T2.stride[0] ) ], T3[ 0 ], reduction_add_float, threadIdx, blockDim, reinterpret_cast<float*>(shared_mem));
}
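The allocation problem can be stated as a sizing rule. The sketch below is one plausible rule, not the fuser's actual lowering code: because the unrolled loop is emitted as a unit (predicated once, as in the kernel above), T1 is allocated outside it, so the buffer must cover the Unroll axis even though computeAt places T1 at position -1. Axis, PType, and localAllocSize are hypothetical; block- and thread-bound axes contribute 1 because each thread only holds its own element.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical simplified model; not the fuser's real IR or lowering code.
enum class PType { Serial, Unroll, BIDx, TIDx };
struct Axis {
  std::int64_t extent;
  PType ptype;
};

std::int64_t localAllocSize(const std::vector<Axis>& axes,
                            std::size_t compute_at_pos) {
  // The allocation point is the compute-at position, hoisted outward to just
  // outside the outermost unrolled loop if one lies inside it.
  std::size_t start = compute_at_pos < axes.size() ? compute_at_pos : axes.size();
  for (std::size_t i = 0; i < start; ++i) {
    if (axes[i].ptype == PType::Unroll) {
      start = i;
      break;
    }
  }
  std::int64_t size = 1;
  for (std::size_t i = start; i < axes.size(); ++i) {
    // Block/thread-bound axes contribute 1: each thread holds its own element.
    if (axes[i].ptype == PType::BIDx || axes[i].ptype == PType::TIDx) continue;
    size *= axes[i].extent;
  }
  return size;
}
```

For T1's domain [BIDx, serial, Unroll{4}, TIDx{32}] at full compute-at depth this gives 4, i.e. float T1[4]; without the Unroll axis it gives 1, which is what the kernel currently emits.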
void testGPU_FusionUnrollBug2() {
const std::vector<int64_t> tensor_dims_in = {128, 128};
torch::jit::fuser::cuda::CudaKernel prog;
prog.setFusionPtr(std::make_unique<Fusion>());
Fusion* fusion = prog.fusion();
FusionGuard fg(fusion);
// Set up your input tensor views
TensorView* tv0 = makeDummyTensor(tensor_dims_in.size());
fusion->addInput(tv0);
TensorView* tv1 = add(tv0, new Float(0));
TensorView* tv2 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv1);
fusion->addOutput(tv2);
const auto options =
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input = at::rand(tensor_dims_in, options);
at::Tensor cg_output = at::empty({tensor_dims_in[0]}, options);
//const at::ArrayRef<c10::IValue> inputs({input});
// Schedule
tv2->split(1, 32);
tv2->split(1, 4); // unroll
auto tv2_rf = tv2->rFactor({-3, -2});
tv2->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(-1)->parallelize(ParallelType::TIDx);
tv2_rf->axis(0)->parallelize(ParallelType::BIDx);
tv2_rf->axis(-1)->parallelize(ParallelType::TIDx);
tv2_rf->axis(-2)->parallelize(ParallelType::Unroll);
tv1->computeAt(tv2_rf, -1);
prog.setDevice(0);
fusion->setLaunchConfig(LaunchConfigType::TIDx, new Int(tensor_dims_in[0]));
fusion->setLaunchConfig(LaunchConfigType::TIDy, new Int(1));
fusion->setLaunchConfig(LaunchConfigType::TIDz, new Int(1));
fusion->setLaunchConfig(LaunchConfigType::BIDx, new Int(32));
fusion->setLaunchConfig(LaunchConfigType::BIDy, new Int(1));
fusion->setLaunchConfig(LaunchConfigType::BIDz, new Int(1));
fusion->setLaunchConfig(LaunchConfigType::SharedMemory, new Int(0));
fusion->setLaunchConfig(LaunchConfigType::Compatible, new Int(1));
torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}, c10::nullopt);
}
void testGPU_FusionUnrollBug() {
const std::vector<int64_t> tensor_dims_in = {128, 128};
torch::jit::fuser::cuda::CudaKernel prog;
prog.setFusionPtr(std::make_unique<Fusion>());
Fusion* fusion = prog.fusion();
FusionGuard fg(fusion);
// Set up your input tensor views
TensorView* tv0 = makeDummyTensor(tensor_dims_in.size());
fusion->addInput(tv0);
TensorView* tv1 = add(tv0, new Float(0));
TensorView* tv2 = add(tv1, new Float(0));
fusion->addOutput(tv2);
const auto options =
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input = at::rand(tensor_dims_in, options);
at::Tensor cg_output = at::empty(tensor_dims_in, options);
//const at::ArrayRef<c10::IValue> inputs({input});
// Schedule
tv2->split(1, 32);
tv2->split(1, 4); // unroll
tv2->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(-1)->parallelize(ParallelType::TIDx);
tv2->axis(-2)->parallelize(ParallelType::Unroll);
tv1->computeAt(tv2, -1);
prog.setDevice(0);
fusion->setLaunchConfig(LaunchConfigType::TIDx, new Int(tensor_dims_in[0]));
fusion->setLaunchConfig(LaunchConfigType::TIDy, new Int(1));
fusion->setLaunchConfig(LaunchConfigType::TIDz, new Int(1));
fusion->setLaunchConfig(LaunchConfigType::BIDx, new Int(32));
fusion->setLaunchConfig(LaunchConfigType::BIDy, new Int(1));
fusion->setLaunchConfig(LaunchConfigType::BIDz, new Int(1));
fusion->setLaunchConfig(LaunchConfigType::SharedMemory, new Int(0));
fusion->setLaunchConfig(LaunchConfigType::Compatible, new Int(1));
torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}, c10::nullopt);
}
This test generates invalid code (https://github.com/naoyam/pytorch/blob/compute-at-bug/test/cpp/jit/test_gpu.cpp):
void testGPU_FusionComputeAtBug() {
torch::jit::fuser::cuda::CudaKernel prog;
prog.setFusionPtr(std::make_unique<Fusion>());
Fusion* fusion = prog.fusion();
FusionGuard fg(fusion);
// Set up your input tensor views
TensorView* tv0 = makeDummyTensor(2);
fusion->addInput(tv0);
TensorView* tv1 = mul(tv0, new Float(1));
TensorView* tv2 = add(tv0, new Float(2));
TensorView* tv3 = add(tv1, new Float(3));
TensorView* tv4 = add(tv1, new Float(4));
fusion->addOutput(tv2);
fusion->addOutput(tv3);
fusion->addOutput(tv4);
std::cout << "tv1->computeAt(tv3, -1)\n";
tv1->computeAt(tv3, -1);
fusion->printMath();
fusion->printKernel();
fusion->printMath();
prog.setDevice(0);
prog.grid(1);
prog.block(1);
torch::jit::fuser::cuda::compileKernel(&prog);
}
Here's the generated kernel:
__global__ void kernel(Tensor<float, 2> T0, Tensor<float, 2> T2, Tensor<float, 2> T3, Tensor<float, 2> T4){
for(size_t i36 = 0; i36 < T4.size[0]; ++i36 ) {
for(size_t i37 = 0; i37 < T4.size[1]; ++i37 ) {
float T1[1];
T1[ 0 ]
= T0[ ( i36 * T0.stride[0] ) + ( i37 * T0.stride[1] ) ]
* float(1);
T4[ ( i36 * T4.stride[0] ) + ( i37 * T4.stride[1] ) ]
= T1[ 0 ]
+ float(4);
}
}
for(size_t i39 = 0; i39 < T4.size[0]; ++i39 ) {
for(size_t i40 = 0; i40 < T4.size[1]; ++i40 ) {
T2[ ( i39 * T2.stride[0] ) + ( i40 * T2.stride[1] ) ]
= T0[ ( i39 * T0.stride[0] ) + ( i40 * T0.stride[1] ) ]
+ float(2);
}
}
for(size_t i41 = 0; i41 < T4.size[0]; ++i41 ) {
for(size_t i42 = 0; i42 < T4.size[1]; ++i42 ) {
T3[ ( i41 * T3.stride[0] ) + ( i42 * T3.stride[1] ) ]
= T1[ 0 ]
+ float(3);
}
}
}
Notice that the loop for T3 references T1, but T1 is only defined in the first loop nest. The T3 loop should actually be placed in the same loop nest as T4, but it's "blocked" by the T2 loop. We had a similar issue recently, which we fixed by sorting the inputs of expressions when traversing fusion expressions (see #112). This issue is similar, but happens with output expressions.
The goal is to make a "magic scheduler" that takes an algorithm with a reduction op and applies a TensorExpression schedule that matches the performance of PyTorch's ATen TensorIterator.
My plan and heuristic are shown in this (NVIDIA Internal) document: https://docs.google.com/document/d/15b8JSnLYu9PIGwEltPXeKOoX5XR_EE_RMjPkRtQ8EHo/
Work is happening on this branch:
https://github.com/csarofeen/pytorch/tree/20_6_11_devel_redsched
Evaluation is happening with this code base:
https://github.com/kevinstephano/codegen_perf
It seems that in some cases broadcast domains are propagated through subsequent expressions. Here's a simple example that results in an exception at the addOutput in the second block:
https://github.com/naoyam/pytorch/blob/bcast-domain/test/cpp/jit/test_gpu.cpp#L3371
Here's what I get:
[ RUN ] JitTest.GPU_FusionBCastDomain_CUDA
unknown file: Failure
C++ exception with description "T3[ iS{i1}, bS{1} ] cannot be registered as an output as it has a broadcast axis.
Exception raised from addOutput at ../torch/csrc/jit/codegen/cuda/fusion.cpp:154 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x68 (0x7f0f70bdcd78 in /home/nmaruyama/pytorch/src/csarofeen.3/build/lib/libc10.so)
frame #1: torch::jit::fuser::Fusion::addOutput(torch::jit::fuser::Val*) + 0x254 (0x7f0f745566e4 in /home/nmaruyama/pytorch/src/csarofeen.3/build/lib/libtorch_cuda.so)
frame #2: torch::jit::testGPU_FusionBCastDomain() + 0x112 (0x555f8991e732 in ./build/bin/test_jit)
frame #3: void testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) + 0x4a (0x555f8995260a in ./build/bin/test_jit)
frame #4: <unknown function> + 0x21836e (0x555f8994836e in ./build/bin/test_jit)
frame #5: <unknown function> + 0x21882d (0x555f8994882d in ./build/bin/test_jit)
frame #6: <unknown function> + 0x218a4d (0x555f89948a4d in ./build/bin/test_jit)
frame #7: testing::internal::UnitTestImpl::RunAllTests() + 0xc59 (0x555f89949929 in ./build/bin/test_jit)
frame #8: testing::UnitTest::Run() + 0x98 (0x555f89949c28 in ./build/bin/test_jit)
frame #9: main + 0xc8 (0x555f897e1c08 in ./build/bin/test_jit)
frame #10: __libc_start_main + 0xe7 (0x7f0f70423b97 in /lib/x86_64-linux-gnu/libc.so.6)
frame #11: _start + 0x2a (0x555f897eadba in ./build/bin/test_jit)
" thrown in the test body.
[ FAILED ] JitTest.GPU_FusionBCastDomain_CUDA (1 ms)
Note that although they are similar, the first block works fine.
I want to update broadcast to something like TensorView* broadcast(TensorView* inp, const std::vector<enum BroadcastingType>& is_broadcast_dim). BroadcastingType can be EXPAND, INSERT, or NONE, where:
a. INSERT and NONE are equivalent to true and false in our current API;
b. EXPAND is a new rule that marks the conversion of a size-1 dimension into a broadcasting dimension.
The updated broadcast could then convert a size-1 dimension into a broadcasting dimension, in order to support this:
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({x, y}, options);
at::Tensor t1 = at::randn({1, y}, options);
auto t2 = t0.add(t1);
Currently our broadcast wrapper takes a vector of bool indicating where a broadcasting dimension should be inserted, but it cannot turn an existing dimension into a broadcasting dimension.
In the example above, we cannot mark t1's axis(0) (a size-1 dimension) as a broadcasting dimension.
I got this hacky code to compile after removing some asserts. Not surprisingly, the index_compute logic is wrong (it picks the wrong stride).
{
torch::jit::fuser::cuda::CudaKernel prog;
Fusion& fusion = *prog.fusion_;
FusionGuard fg(&fusion);
// Set up your input tensor views
TensorView* tv0 = makeDummyTensor(2);
TensorView* tv1 = makeDummyTensor(2);
fusion.addInput(tv0);
fusion.addInput(tv1);
//TensorView* tv2 = broadcast(tv1, {BroadcastingType::EXPAND, BroadcastingType::NONE});
std::vector<IterDomain*> out_domain;
out_domain.push_back(new IterDomain(new Int(0), new Int(1), ParallelType::Serial, false, false, true));
out_domain.push_back(tv1->axis(1));
TensorView* tv2 =
new TensorView(new TensorDomain(out_domain), tv1->getDataType().value());
new BroadcastOp(tv2, tv1);
TensorView* tv3 = add(tv0, tv2);
fusion.addOutput(tv3);
tv3->axis(0)->parallelize(ParallelType::BIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);
tv0->computeAt(tv3, -1);
tv2->computeAt(tv3, -1);
constexpr int x = 63, y = 33;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({x, y}, options);
at::Tensor t1 = at::randn({1, y}, options);
at::Tensor cg_output = at::empty({x, y}, options);
prog.device_ = 0;
prog.grid(x);
prog.block(y);
torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});
auto t2 = t1.add(t0);
TORCH_CHECK(t2.allclose(cg_output));
}
Alternatively, we could fake this in the integration code, where I would register the input tensor with a reduced rank and use the existing broadcast API to fill in that rank.
In the example above, I would register at::Tensor t1 = at::randn({1, y}, options); as
TensorView* tv1_ = makeDummyTensor(1);
TensorView* tv1 = broadcast(tv1_, {true, false});
fusion.addInput(tv1_);
However, this could get messy, as we might need to maintain versions/aliases of the same tensor to accommodate different broadcasting requirements on the same input tensor. For example:
at::Tensor t0 = at::randn({ y, z}, options);
at::Tensor t1 = at::randn({ 1, z}, options);
at::Tensor t2 = at::randn({x,1, z}, options);
auto t3 = t0.add(t1);
auto t4 = t1.add(t2);
auto t5 = t3.add(t2);
fusion.addOutput(t4);
fusion.addOutput(t5);
In the example above, both t1 and t2 eventually need to be broadcast to [x,y,z] because t5 = t0 + t1 + t2. Meanwhile, the other output, t4 = t1 + t2, would end up with a broadcast domain in an output tensor. So we would need to maintain two copies of t1 and t2 (one with axis(y) broadcast and one with it kept as a constant 1).
From the computeAt work, one thing we currently assume is that every tensor is produced only once. This is typically a good default, but it does not support everything we may need. For example, a persistent softmax would be fine, but a non-persistent softmax would require recomputing some values.
We need a mechanism to specify the use of a tensor in an expression and replicate the expressions from its origin back to the inputs so that it is recomputed, with an interface that returns all recomputed tensors.
We will add this functionality to computeAt, so that if computeAt recognizes it cannot satisfy a given call, it can offer the user an option to automatically recompute all conflicting tensors.
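A toy sketch of the replication step, assuming tensors are plain integer ids and every non-input tensor has exactly one producing expression. ToyFusion, Expr, and recompute are hypothetical names; the real mechanism would operate on TensorView/Expr and be driven by computeAt. It clones the producer chain back to the fusion inputs (which are shared, not cloned) and returns a map from each original tensor to its recomputed copy.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <vector>

// Hypothetical toy IR: tensors are ints, each expr produces one output.
struct Expr {
  int output;
  std::vector<int> inputs;
};

struct ToyFusion {
  std::set<int> fusion_inputs;
  std::map<int, Expr> producer;  // tensor -> producing expression
  std::vector<Expr> exprs;       // expressions in program order
  int next_id = 100;             // ids for recomputed tensors

  // Replicates every expression needed to produce `tv` and returns the map
  // original tensor -> recomputed tensor.
  std::map<int, int> recompute(int tv) {
    std::map<int, int> clones;
    cloneRec(tv, clones);
    return clones;
  }

 private:
  int cloneRec(int tv, std::map<int, int>& clones) {
    if (fusion_inputs.count(tv)) return tv;     // inputs are reused as-is
    auto it = clones.find(tv);
    if (it != clones.end()) return it->second;  // already recomputed once
    Expr original = producer.at(tv);
    Expr copy{next_id++, {}};
    for (int in : original.inputs) copy.inputs.push_back(cloneRec(in, clones));
    clones[tv] = copy.output;
    producer[copy.output] = copy;
    exprs.push_back(copy);  // operands were appended first, so order stays valid
    return copy.output;
  }
};
```

The memoization in `clones` ensures a tensor with several uses inside the replicated subgraph is still recomputed only once per request.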
The JIT stores int64 and double-precision values for any int/float scalar type. We should do the same by modifying our Int and Float IR nodes, and make sure ir_iostream prints them at the higher precision.
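A minimal sketch of what "printed at the higher precision" could mean, assuming Int and Float simply hold int64_t and double (the structs and toString below are hypothetical stand-ins, not the ir_iostream code). The default stream precision of 6 significant digits loses information (1.0/3.0 prints as 0.333333); using std::numeric_limits<double>::max_digits10 makes the printed literal round-trip to the same double.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <sstream>
#include <string>

// Hypothetical scalar IR nodes widened to the JIT's storage types.
struct Int { int64_t value; };
struct Float { double value; };

// max_digits10 significant digits are enough to recover the exact double.
std::string toString(const Float& f) {
  std::ostringstream os;
  os.precision(std::numeric_limits<double>::max_digits10);
  os << f.value;
  return os.str();
}

std::string toString(const Int& i) {
  return std::to_string(i.value);
}
```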
This segfaults.
Fusion fusion_test;
fusion_test.exprs(true);
I ran into this when calling hasReduction while progressively parsing the graph.
IrGraphGenerator is missing handlers for:
Bool
Half
TernaryOp
BroadcastOp
ReductionOp
The issue is that back-to-back reductions write to an intermediate result that may not hold the correct value.
void reduction(int trials, int bidx, int bidy, int tidx, int tidy, int unroll, int elems) {
torch::jit::fuser::cuda::CudaKernel prog;
Fusion& fusion = *prog.fusion_;
FusionGuard fg(&fusion);
// Set up your input tensor views
TensorView* input_tv0 = makeDummyTensor(3);
fusion.addInput(input_tv0);
TensorView* sum_val_tv1 = reductionOp(BinaryOpType::Add, {2}, new Float(0), input_tv0);
sum_val_tv1->split(-1, tidy);
sum_val_tv1->split(-2, tidx);
sum_val_tv1->split(-3, unroll);
TensorView* sum_val_rf_tv2 = sum_val_tv1->rFactor({-4});
TensorView* sum_val_rf_tv3 = sum_val_tv1->rFactor({-3});
TensorView* sum_val_rf_tv4 = sum_val_tv1->rFactor({-2});
sum_val_rf_tv2->axis(0)->parallelize(ParallelType::BIDx);
sum_val_rf_tv3->axis(0)->parallelize(ParallelType::BIDx);
sum_val_rf_tv4->axis(0)->parallelize(ParallelType::BIDx);
sum_val_tv1->axis(0)->parallelize(ParallelType::BIDx);
sum_val_rf_tv2->axis(1)->parallelize(ParallelType::BIDy);
sum_val_rf_tv3->axis(1)->parallelize(ParallelType::BIDy);
sum_val_rf_tv4->axis(1)->parallelize(ParallelType::BIDy);
sum_val_tv1->axis(1)->parallelize(ParallelType::BIDy);
sum_val_rf_tv2->axis(-2)->parallelize(ParallelType::TIDx);
sum_val_rf_tv3->axis(-2)->parallelize(ParallelType::TIDx);
sum_val_rf_tv4->axis(-2)->parallelize(ParallelType::TIDx);
sum_val_tv1->axis(-1)->parallelize(ParallelType::TIDy);
sum_val_rf_tv2->axis(-1)->parallelize(ParallelType::TIDy);
sum_val_rf_tv3->axis(-1)->parallelize(ParallelType::TIDy);
sum_val_rf_tv4->axis(-1)->parallelize(ParallelType::TIDy);
fusion.addOutput(sum_val_tv1);
}
__global__ void CUDAGeneratedKernel(Tensor<float, 3> T0, Tensor<float, 2> T1){
T1[ ( blockIdx.x * T1.stride[0] ) + ( blockIdx.y * T1.stride[1] ) ]
= float(0);
float T4[1];
T4[ 0 ]
= float(0);
float T3[1];
T3[ 0 ]
= float(0);
float T2[4];
for(size_t i37 = 0; i37 < 4; ++i37 ) {
if ( ( ( ( ( ( ( ( 0 * 4 ) + i37 ) * 32 ) + threadIdx.x ) * 16 ) + threadIdx.y ) < T0.size[2] ) ) {
T2[ i37 ]
= float(0);
}
}
for(size_t i38 = 0; i38 < ( ceilDiv(( ceilDiv(( ceilDiv(T0.size[2], 16) ), 32) ), 4) ); ++i38 ) {
for(size_t i39 = 0; i39 < 4; ++i39 ) {
if ( ( ( ( ( ( ( ( i38 * 4 ) + i39 ) * 32 ) + threadIdx.x ) * 16 ) + threadIdx.y ) < T0.size[2] ) ) {
T2[ i39 ]
= T2[ i39 ]
+ T0[ ( blockIdx.x * T0.stride[0] ) + ( blockIdx.y * T0.stride[1] ) + ( ( ( ( ( ( ( i38 * 4 ) + i39 ) * 32 ) + threadIdx.x ) * 16 ) + threadIdx.y ) * T0.stride[2] ) ];
}
}
}
for(size_t i41 = 0; i41 < 4; ++i41 ) {
T3[ 0 ]
= T3[ 0 ]
+ T2[ i41 ];
}
blockReduce< true, false, false > ( T4[ 0 ], T3[ 0 ], reduction_add_float);
blockReduce< false, true, false > ( T1[ ( blockIdx.x * T1.stride[0] ) + ( blockIdx.y * T1.stride[1] ) ], T4[ 0 ], reduction_add_float);
}
Why is this assertion required?
In the softmax test I'm currently working on, it seems the assertion needs to be disabled:
https://github.com/csarofeen/pytorch/pull/100/files#diff-32ba7310cdb76c3dde2adab830e365f3R140-R149
Broadcast is not allowed on an output TensorView. There is a proper check that detects this and reports "[output_tv] cannot be registered as an output as it has a broadcast axis".
The code snippet below would cause that TORCH_CHECK to fail:
TensorView *t1 = makeDummyTensor(1);
TensorView *t2 = broadcast(t1, {false, true});
fusion.addInput(t1);
fusion.addOutput(t2);
// ...
If we explicitly mark t2 as broadcast on dimension 1, we are assuming its corresponding stride is 0 (because broadcast elements map to the same physical memory location).
However, since t2 is an output tensor, its strides are an input to the kernel (the generated kernel would look something like this):
void kernel(Tensor<float, 1> T1, Tensor<float, 2> T2) {
// ...
}
Hence, marking an I/O TensorView as broadcast violates the contract that I/O tensors, including their strides, are provided at runtime. We can't generate a safe kernel that behaves as the user of the generated code would expect.
If we later find use cases where broadcasting on an output could save memory bandwidth in the generated kernel, we can revisit this topic.
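A small illustration of the stride argument, assuming a hypothetical View2D that mirrors the (size, stride) pair carried by a Tensor<float, 2> kernel argument. Under the broadcast assumption (stride 0 on dimension 1), every column index maps to the same element. A caller-provided contiguous output has stride 1 on that dimension, so the same indices address distinct elements, and code generated under the stride-0 assumption would not fill the buffer the caller handed in.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical 2-D strided view, mirroring a Tensor<float, 2> argument.
struct View2D {
  std::array<int64_t, 2> size;
  std::array<int64_t, 2> stride;

  // Linear element offset of (i, j).
  int64_t offset(int64_t i, int64_t j) const {
    return i * stride[0] + j * stride[1];
  }
};
```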
TensorView* tv0 = makeDummyTensor(2);
fusion.addInput(tv0);
// reduce then broadcast
auto tv1 = sum(tv0, {0});
auto tv2 = broadcast(tv1, {false, true});
generates:
T1[ rS{i1}, iS{i3} ]
T2[ rS{i1}, bS{1} ]
but should be:
T1[ rS{i1}, iS{i3} ]
T2[ iS{i3}, bS{1} ]
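A minimal sketch of the expected behavior, modeling each domain as a list of axis kinds (IterType, noReductions, and broadcastAfterReduction are hypothetical names, not fuser functions). The idea is that a consumer of a reduction should only see the producer's non-reduction axes, so they are filtered out before broadcast axes are inserted. For T1[ rS{i1}, iS{i3} ] with {false, true}, this gives the [ iS, bS ] layout expected above.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

enum class IterType { Iteration, Reduction, Broadcast };

// Axes that still exist after the producer's reduction has been performed.
std::vector<IterType> noReductions(const std::vector<IterType>& producer) {
  std::vector<IterType> out;
  for (IterType t : producer)
    if (t != IterType::Reduction) out.push_back(t);
  return out;
}

// Inserts a broadcast axis wherever is_broadcast_dim is true and fills the
// other positions, in order, from the producer's non-reduction axes.
std::vector<IterType> broadcastAfterReduction(
    const std::vector<IterType>& producer,
    const std::vector<bool>& is_broadcast_dim) {
  std::vector<IterType> inp = noReductions(producer);
  std::vector<IterType> out;
  size_t pos = 0;
  for (bool b : is_broadcast_dim) {
    if (b) {
      out.push_back(IterType::Broadcast);
    } else {
      if (pos >= inp.size()) throw std::invalid_argument("rank mismatch");
      out.push_back(inp[pos++]);
    }
  }
  if (pos != inp.size()) throw std::invalid_argument("rank mismatch");
  return out;
}
```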
The IrGraphGenerator test fails on the current develop branch.
[ RUN ] JitTest.GPU_IrGraphGenerator_CUDA
unknown file: Failure
C++ exception with description "Cannot reduce an axis that is marked as broadcasted as it has an undetermined size. Tried to reduce ID = 0x5590b6d326e0 of tensor 0x5590bf5f0cb0
Exception raised from newForReduction at ../torch/csrc/jit/codegen/cuda/arith.cpp:394 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x68 (0x7f7f9c409d78 in /home/nmaruyama/pytorch/src/csarofeen.3/build/lib/libc10.so)
frame #1: torch::jit::fuser::reductionOp(torch::jit::fuser::BinaryOpType, std::vector<int, std::allocator<int> > const&, torch::jit::fuser::Val*, torch::jit::fuser::TensorView*) + 0x1705 (0x7f7f9fd70be5 in /home/nmaruyama/pytorch/src/csarofeen.3/build/lib/libtorch_cuda.so)
frame #2: torch::jit::testGPU_IrGraphGenerator() + 0x23a (0x559080494a0a in ./build/bin/test_jit)
This is likely caused by the check added in #126.
Right now, we require the allocation size to be a constant. Since we bind the input sizes, we can relax this limitation with a bit of help from the expression evaluator:
A basic solution would simply add up all the individual allocations to compute the scratch buffer size. A more sophisticated version could track the lifetimes of the scratch buffers and reuse the same space when they don't overlap (i.e., pack the allocations).
We could start with the simple solution and add packing in a follow-up iteration.
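A sketch of both strategies, assuming each allocation's size has already been evaluated from the bound input sizes and its lifetime is given as an inclusive range of expression indices (Allocation, totalSize, and packAllocations are hypothetical names). The packing shown is a simple greedy first-fit heuristic: each allocation takes the lowest offset that does not collide with an already-placed allocation whose lifetime overlaps its own. It is not guaranteed to be optimal.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical scratch allocation: evaluated size plus live range.
struct Allocation {
  int64_t bytes;
  int first_use;
  int last_use;
};

// Basic approach: every allocation gets its own space.
int64_t totalSize(const std::vector<Allocation>& allocs) {
  int64_t total = 0;
  for (const auto& a : allocs) total += a.bytes;
  return total;
}

// Greedy first-fit packing. offsets[i] receives the placement of allocs[i];
// the return value is the required buffer size.
int64_t packAllocations(const std::vector<Allocation>& allocs,
                        std::vector<int64_t>& offsets) {
  offsets.assign(allocs.size(), 0);
  int64_t total = 0;
  for (size_t i = 0; i < allocs.size(); ++i) {
    // Memory ranges of placed allocations that are live at the same time.
    std::vector<std::pair<int64_t, int64_t>> busy;
    for (size_t j = 0; j < i; ++j) {
      bool live_together = allocs[i].first_use <= allocs[j].last_use &&
                           allocs[j].first_use <= allocs[i].last_use;
      if (live_together)
        busy.push_back({offsets[j], offsets[j] + allocs[j].bytes});
    }
    std::sort(busy.begin(), busy.end());
    int64_t candidate = 0;
    for (const auto& range : busy) {
      if (candidate + allocs[i].bytes <= range.first) break;  // fits in gap
      candidate = std::max(candidate, range.second);
    }
    offsets[i] = candidate;
    total = std::max(total, candidate + allocs[i].bytes);
  }
  return total;
}
```

With three buffers of 100, 200, and 50 bytes where the first dies before the third is created, the basic sum is 350 bytes, while packing places the third buffer in the first one's space and needs 300.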
Here's a generated loop nest:
for(size_t i101 = 0; i101 < T6.size[0]; ++i101 ) {
for(size_t i102 = 0; i102 < T6.size[1]; ++i102 ) {
T2[ ( i102 * T2.stride[0] ) + ( i103 * T2.stride[1] ) ]
= T1[ ( i101 * T1.stride[0] ) + ( i102 * T1.stride[1] ) ];
}
}
Compilation fails because i103 is undefined.
The nest is generated from a modified softmax test. See naoyam@c38d89f#diff-538e381bc787011cc53c15caf54e126dR2770. It performs the same arithmetic operations as the original softmax test, but for debugging purposes all the parallelization-related tensor operations are removed.
To reproduce, run testGPU_FusionSoftmax as modified in this commit: naoyam@c38d89f
This fails:
auto one = new Int(1);
auto one3 = mul(mul(one, one), one);
TORCH_CHECK(one3->isConstScalar());
It seems class ConstCheck fails to identify constness when Val is a compound expression.
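A minimal sketch of the recursive rule the check would need, using a hypothetical toy Val (not the real ConstCheck visitor): a value is constant if it is a constant leaf, or if it is produced by an expression whose operands are all, recursively, constant. The test mirrors the failing mul(mul(one, one), one) case.

```cpp
#include <cassert>
#include <vector>

// Hypothetical toy Val: a leaf (constant or symbolic, e.g. an input size)
// or the result of an expression over other Vals.
struct Val {
  bool is_leaf_const;                // meaningful only for leaves
  std::vector<const Val*> operands;  // non-empty if produced by an expression
};

bool isConstScalar(const Val* v) {
  if (v->operands.empty()) return v->is_leaf_const;
  for (const Val* op : v->operands)
    if (!isConstScalar(op)) return false;
  return true;
}
```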
See #177 for a reproducer.
TensorView::computeAt() propagates to all the uses of the same producer. The problem is that Fusion::uses() returns a std::set<Expr*>.
A std::set normally offers a predictable (sorted) order of values, but here the keys are Expr* pointers, whose order depends on heap addresses and is therefore non-deterministic. This is an issue because the order of the computeAt() traversal changes the end results.
This issue manifested as a failure of a unit test triggered by an unrelated change in a different unit test (but which impacted the heap layout as a side effect).
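One possible fix, sketched with a hypothetical InsertionOrderedSet: deduplicate like std::set<Expr*>, but iterate in insertion order, so traversal depends on the order in which uses were registered rather than on where the allocator placed each Expr. An alternative would be a std::set with a comparator on a stable per-expression id instead of the pointer value. Neither is the fuser's actual code; Expr here is a stand-in.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <vector>

// Deduplicates like a set, iterates in insertion order.
template <typename T>
class InsertionOrderedSet {
 public:
  bool insert(T v) {
    if (!seen_.insert(v).second) return false;  // already present
    order_.push_back(v);
    return true;
  }
  bool contains(T v) const { return seen_.count(v) != 0; }
  typename std::vector<T>::const_iterator begin() const { return order_.begin(); }
  typename std::vector<T>::const_iterator end() const { return order_.end(); }
  std::size_t size() const { return order_.size(); }

 private:
  std::unordered_set<T> seen_;
  std::vector<T> order_;
};

struct Expr { int id; };  // hypothetical stand-in
```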
The testGPU_FusionSoftmax test in this branch generates invalid code:
https://github.com/naoyam/pytorch/tree/codegen-bug
Here's the generated kernel:
__global__ void CUDAGeneratedKernel(Tensor<float, 3> T0, Tensor<float, 3> T9){
__shared__ float shared_mem[1024];
float T6[1];
float T5[1];
T5[ 0 ]
= float(0);
float T11[1];
if ( ( ( ( 0 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
T11[ 0 ]
= float(0);
}
float T2[1];
float T1[1];
T1[ 0 ]
= float(0);
float T10[1];
if ( ( ( ( 0 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
T10[ 0 ]
= float(0);
}
for(size_t i84 = 0; i84 < ( ceilDiv(T9.size[2], 32) ); ++i84 ) {
if ( ( ( ( i84 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
T10[ 0 ]
= fmaxf(T10[ 0 ]
, T0[ ( blockIdx.x * T0.stride[0] ) + ( blockIdx.y * T0.stride[1] ) + ( ( ( i84 * 32 ) + threadIdx.x ) * T0.stride[2] ) ]);
}
}
blockReduce< true, false, false > ( T1[ 0 ], T10[ 0 ], reduction_fmaxf_float, threadIdx, blockDim, reinterpret_cast<float*>(shared_mem));
if ( ( threadIdx.x == 0 ) ) {
T2[ 0 ]
= T1[ 0 ];
}
for(size_t i87 = 0; i87 < ( ceilDiv(T9.size[2], 32) ); ++i87 ) {
float T7[1];
if ( ( ( ( i87 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
T7[ 0 ]
= T0[ ( blockIdx.x * T0.stride[0] ) + ( blockIdx.y * T0.stride[1] ) + ( ( ( i87 * 32 ) + threadIdx.x ) * T0.stride[2] ) ]
- T2[ 0 ];
}
float T8[1];
if ( ( ( ( i87 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
T8[ 0 ]
= expf(T7[ 0 ]);
}
float T3[1];
if ( ( ( ( i87 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
T3[ 0 ]
= T0[ ( blockIdx.x * T0.stride[0] ) + ( blockIdx.y * T0.stride[1] ) + ( ( ( i87 * 32 ) + threadIdx.x ) * T0.stride[2] ) ]
- T2[ 0 ];
}
float T4[1];
if ( ( ( ( i87 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
T4[ 0 ]
= expf(T3[ 0 ]);
}
}
for(size_t i93 = 0; i93 < ( ceilDiv(T9.size[2], 32) ); ++i93 ) {
if ( ( ( ( i93 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
T11[ 0 ]
= T11[ 0 ]
+ T4[ 0 ];
}
}
blockReduce< true, false, false > ( T5[ 0 ], T11[ 0 ], reduction_add_float, threadIdx, blockDim, reinterpret_cast<float*>(shared_mem));
if ( ( threadIdx.x == 0 ) ) {
T6[ 0 ]
= T5[ 0 ];
}
for(size_t i96 = 0; i96 < ( ceilDiv(T9.size[2], 32) ); ++i96 ) {
if ( ( ( ( i96 * 32 ) + threadIdx.x ) < T9.size[2] ) ) {
T9[ ( blockIdx.x * T9.stride[0] ) + ( blockIdx.y * T9.stride[1] ) + ( ( ( i96 * 32 ) + threadIdx.x ) * T9.stride[2] ) ]
= T8[ 0 ]
/ T6[ 0 ];
}
}
}
T8 is used in the final loop nest but is defined and computed in a different loop nest.
The final loop nest corresponds to this expression:
TensorView* output_tv6 = div(exp_tv4_2, bcast_sum_tv6);
Interestingly, swapping the two operands, i.e., div(bcast_sum_tv6, exp_tv4_2), seems to be fine; see https://github.com/naoyam/pytorch/blob/d2913c9fb2c7e94c786ce3b73a6b5ab87b4eee3b/test/cpp/jit/test_gpu.cpp#L2798. The fuser then generates code as expected, although the two are no longer the same computation.
Note that this issue was encountered in the block broadcast PR (#100) but shows up with the dev branch as well.
It's really annoying to work with Arith and have it always return a Val* when we frequently pass TensorView* as arguments. We should overload these operators to return a TensorView* when one of the inputs is a TensorView*.
We should then update our tests, which frequently use syntax like:
TensorView* tv2 = static_cast<TensorView*>(add(tv1, new Float(3.0)));
to remove the static_casts.
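A minimal sketch of the overload set, using stripped-down stand-ins for Val, Float, and TensorView (ownership by the Fusion is not modeled, and the generic add only allocates a result node instead of building a BinaryOp). The TensorView* overloads forward to the generic Val* version and encode "a TensorView input produces a TensorView" in the return type. Overload resolution picks add(TensorView*, Val*) for add(tv1, new Float(3.0)) because it is an exact match on the first argument, so call sites no longer need a static_cast.

```cpp
#include <cassert>

// Hypothetical stand-ins for the IR classes.
struct Val { virtual ~Val() = default; };
struct Float : Val {
  double value;
  explicit Float(double v) : value(v) {}
};
struct TensorView : Val {};

// Generic entry point: the result is a TensorView if any input is one.
Val* add(Val* a, Val* b) {
  if (dynamic_cast<TensorView*>(a) || dynamic_cast<TensorView*>(b))
    return new TensorView();
  return new Float(0.0);  // stand-in scalar result; no arithmetic is done
}

// Typed overloads: the cast lives here once instead of at every call site.
TensorView* add(TensorView* a, Val* b) {
  return static_cast<TensorView*>(add(static_cast<Val*>(a), b));
}
TensorView* add(Val* a, TensorView* b) {
  return static_cast<TensorView*>(add(a, static_cast<Val*>(b)));
}
TensorView* add(TensorView* a, TensorView* b) {
  return static_cast<TensorView*>(
      add(static_cast<Val*>(a), static_cast<Val*>(b)));
}
```

The same pattern would have to be repeated for each binary and unary operator, so generating the overloads with a macro or template may be worth considering.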