Writing GPU kernels in a dynamic language: how does CUDA.jl work?
Julia's CUDA.jl is famous for letting you run pure Julia on GPUs. This seems impossible, because Julia feels like a scripting language, but it works as advertised: fast and simple with no DSLs or C++ required [0].
Here's an example where we write and execute a CUDA kernel that takes the elementwise product of two vectors in pure Julia [1].
julia> using CUDA
julia> n = 1000
julia> a = CUDA.rand(n) # Float32[0.821652, 0.8169186, 0.23189153, ...]
julia> b = CUDA.rand(n) # Float32[0.086756594, 0.2620363, 0.33393615, ...]
julia> prod = CUDA.zeros(Float32, n) # Float32[0.0, 0.0, 0.0, ...]
julia> function prod_kernel(a, b, output) # define a normal Julia function
           i = threadIdx().x              # get our index
           output[i] = a[i] * b[i]
           return                         # required due to a quirk in CUDA.jl
       end
prod_kernel (generic function with 1 method)
julia> @cuda threads=n prod_kernel(a, b, prod) # launch prod_kernel on your GPU
CUDA.HostKernel{...}
julia> prod[1:3] # now contains the product:
Float32[0.07128373, 0.21406232, 0.07743697]
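To sanity-check the result (and copy it back to the CPU), we can compare against the same product computed on the host; Array(x) copies a CuArray back into an ordinary Array:
julia> Array(prod) ≈ Array(a) .* Array(b)
true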
How does this work? We can break running Julia on a GPU into three steps.
- Compilation, in which we convert our Julia function to GPU machine code. We retarget Julia's LLVM backend to emit a GPU instruction set called PTX in GPUCompiler.jl, and then use the ptxas program to convert that to GPU machine code in CUDA.jl.
- Data Conversion, in which we convert types to be GPU-friendly with a library called Adapt.jl.
- Execution, in which we execute the kernel and retrieve the result. This boils down to a few FFI calls like cuLaunchKernel, which CUDA.jl calls directly.
More concretely, the entry point of CUDA.jl is the @cuda macro, which launches the given function call on the GPU. Here's a cut-down version of what's inside.
macro cuda(f, var_exprs...)
    kernel_f = cudaconvert(f)
    kernel_args = map(cudaconvert, var_exprs)            # convert each argument to its GPU-friendly type
    kernel_tt = Tuple{map(Core.Typeof, kernel_args)...}  # build the kernel's type signature
    kernel = cufunction(kernel_f, kernel_tt; compiler_kwargs...)
    kernel(var_exprs...; call_kwargs...)
end
The three main players here are cufunction, cudaconvert, and the final kernel invocation. Each handles one of the steps we discussed above: cufunction compiles the kernel to GPU machine code, cudaconvert converts the arguments from CPU-friendly types to GPU-friendly types, and the kernel call launches the kernel.
Compilation (GPUCompiler.jl)
NVIDIA GPUs run machine code from an instruction set called SASS, which is generated from a higher-level instruction set called PTX. Julia code is usually compiled via LLVM to x86, but conveniently, LLVM can also generate PTX.
So broadly, here's the strategy: we replicate the normal Julia compilation process on our kernel function, but tell LLVM to target the PTX backend. Then we use NVIDIA's ptxas utility to convert that to SASS and generate the final machine code. Finally, we upload that machine code to the GPU. This all happens inside cufunction, which is responsible for taking the Julia code inside the kernel function and converting it to GPU machine code.
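If you want to peek at these intermediate representations yourself, CUDA.jl provides reflection macros such as @device_code_llvm, @device_code_ptx, and @device_code_sass that dump each stage of a kernel launch. A rough sketch, reusing the prod_kernel from our earlier session (output omitted):
julia> CUDA.@device_code_ptx @cuda threads=n prod_kernel(a, b, prod)   # print the generated PTX
julia> CUDA.@device_code_sass @cuda threads=n prod_kernel(a, b, prod)  # print the final SASS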
The most important parts of cufunction happen inside cufunction_compile and cufunction_link. Here's a simplified cufunction_compile.
function cufunction_compile(@nospecialize(job::CompilerJob), ctx)
    # lower to PTX
    mi, mi_meta = GPUCompiler.emit_julia(job)          # lower the Julia to a typed IR
    ir, ir_meta = GPUCompiler.emit_llvm(job, mi; ctx)  # lower that to LLVM IR
    asm, asm_meta = GPUCompiler.emit_asm(job, ir; format=LLVM.API.LLVMAssemblyFile)
    # assemble the PTX into GPU machine code with ptxas
    write("/tmp/input.ptx", asm)
    run_and_collect(`ptxas --verbose --output-file /tmp/tmp.cubin /tmp/input.ptx`)
    image = read("/tmp/tmp.cubin")
    return image
end
cufunction_compile first lowers the Julia into a typed intermediate representation with emit_julia. Then it compiles that to LLVM IR using emit_llvm, and finally it converts that LLVM IR to PTX with emit_asm. These three emit functions are the three horsemen of GPUCompiler, and we'll see them again and again.
Before cufunction_compile returns, it also runs ptxas, which converts PTX ("parallel thread execution," a virtual ISA which runs on many NVIDIA GPUs) into SASS ("streaming assembler," a hardware ISA that will run on your specific GPU). Those lowered instructions live in a .cubin file, which we return.
cufunction_link takes that .cubin and loads it onto the GPU inside a module. Modules are the GPU equivalent of DLLs, and we can upload them via an FFI call to cuModuleLoadDataEx.
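For illustration, here's roughly what loading and using such a module looks like through CUDA.jl's driver API wrappers. Treat this as a sketch: CuModule, CuFunction, and cudacall are real wrappers, but the kernel name below is made up (real kernel names are mangled).
using CUDA

image = read("/tmp/tmp.cubin")        # the machine code produced above
md = CuModule(image)                  # wraps cuModuleLoadDataEx
fun = CuFunction(md, "my_kernel")     # look up the (hypothetically named) kernel entry point
# cudacall(fun, Tuple{...}, args...; threads=..., blocks=...) would then launch it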
If you're paying close attention, you might wonder where in this process we include the functions that NVIDIA gives us (e.g. atan). The answer is that NVIDIA gives us a library called libdevice.bc, which is full of definitions for functions like atan (double @__nv_atan(double %x)). Interestingly, NVIDIA distributes these as LLVM bitcode [2], so that LLVM can optimize with knowledge of these functions when we generate our PTX.
Most of the magic happened in this section. The takeaway is that Julia exploits the fact that it already compiles through LLVM, and LLVM can already emit PTX, which can be easily converted to GPU machine code.
With compilation out of the way, let's look at cudaconvert, which handles converting the data to be GPU-friendly.
Conversion
There are two conversion steps that must happen when we invoke a kernel on an Array. In the first step, we copy the data from the CPU to the GPU. This converts the Array to a CuArray. In the second step, we convert the types into a GPU-friendly format. This converts the CuArray into a CuDeviceArray.
First, let's examine the copying conversion. We start this conversion with a manual call like CuArray(cpu_array). Here's that CuArray implementation:
@inline function CuArray{T,N,B}(xs::AbstractArray{<:Any,N}) where {T,N,B}
    A = CuArray{T,N,B}(undef, size(xs))
    copyto!(A, convert(Array{T}, xs))
    return A
end
We see that CuArray wraps a call to copyto!, which in turn checks some bounds and then calls unsafe_copyto!. That makes an FFI call to cuMemcpyHtoDAsync_v2 to actually execute the host-to-device copy.
You might wonder why CUDA.jl doesn't do this conversion automatically. The answer is that copying data to the GPU can be quite expensive, and by leaving it explicit, CUDA.jl lets the programmer manage the details of the copies more tightly.
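For example, the explicit round trip looks like this (a minimal sketch; the array contents are arbitrary):
using CUDA

cpu_array = rand(Float32, 1000)

gpu_array = CuArray(cpu_array)   # explicit host-to-device copy (cuMemcpyHtoDAsync_v2 under the hood)
gpu_array .*= 2f0                # runs on the GPU
result = Array(gpu_array)        # explicit device-to-host copy back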
Next, let's examine the device conversion, which maps CuArrays to CuDeviceArrays. Morally, CuArrays are a mechanism that the CPU uses to copy and keep track of data that lives on the GPU; in contrast, CuDeviceArrays only ever exist on the GPU. Concretely, here's the difference between the two types:
mutable struct CuArray{T,N,B} <: AbstractGPUArray{T,N}
    storage::Union{Nothing,ArrayStorage{B}}
    offset::Int
    maxsize::Int
    dims::Dims{N}
end

struct CuDeviceArray{T,N,A} <: DenseArray{T,N}
    ptr::LLVMPtr{T,A}
    len::Int
    maxsize::Int
    dims::Dims{N}
end
You can see that CuArrays hold their data storage object, whereas CuDeviceArrays hold only an LLVM pointer to GPU memory.
Anyway, when we want to run a kernel, we need to convert these CuArrays to GPU-usable CuDeviceArrays. Unlike the copying conversion, this conversion happens automatically inside cudaconvert, via the following unsafe_convert call.
function Base.unsafe_convert(::Type{CuDeviceArray{T,N,AS.Global}}, a::DenseCuArray{T,N}) where {T,N}
    CuDeviceArray{T,N,AS.Global}(reinterpret(LLVMPtr{T,AS.Global}, pointer(a)), size(a),
                                 a.maxsize - a.offset*Base.elsize(a))
end
We can see that this is mainly an accounting conversion; we get a pointer to the CuArray's data with pointer(a) and then convert it to an LLVM pointer with a global address space using reinterpret. The Global address space is distinct from the Default address space, and is basically how we communicate to LLVM that our CuDeviceArray pointer must point to global memory (accessible from both the GPU and the CPU) as opposed to the Default kind of memory, which might end up being some variant of CPU- or GPU-local.
We've now seen the two critical conversions: from Array to CuArray to CuDeviceArray. Both of these conversions are easy if you just have a few Arrays lying around, but what if you have a T{U{Array}, Array} that you want to send to the GPU as a T{U{CuArray}, CuArray}? Or similarly with CuArray and CuDeviceArray?
This nested types problem is neatly handled by a library called Adapt.jl. To use it, we write two functions: adapt_structure for the recursive case (unwrap the Ts and Us), and adapt_storage for the base case (actually convert the underlying Arrays). Here's the implementation of adapt_storage that converts Array to CuArray; it's nothing to write home about.
Adapt.adapt_storage(::Type{CuArray}, xs::AT) where {AT<:AbstractArray} =
    isbitstype(AT) ? xs : convert(CuArray, xs)
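To make the recursive case concrete, here's what adapt_structure looks like for a custom wrapper type. The Interpolator struct below is a hypothetical example of mine, not something from CUDA.jl, but the pattern mirrors the one in the Adapt.jl docs: rebuild the wrapper and adapt each field.
using Adapt

# A hypothetical container that wraps two arrays.
struct Interpolator{A}
    xs::A
    ys::A
end

# The recursive case: rebuild the wrapper, adapting each field.
Adapt.adapt_structure(to, itp::Interpolator) =
    Interpolator(adapt(to, itp.xs), adapt(to, itp.ys))
With this in place, adapt(CuArray, Interpolator(rand(Float32, 10), rand(Float32, 10))) returns an Interpolator whose fields are CuArrays, with the adapt_storage method above handling the leaves.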
Execution
Now that we've compiled the kernel to machine code and done both conversions, we need to actually launch the kernel. This step is pretty simple. The kernel(...) call eventually hits cudacall, which is given here:
function cudacall(f, types::Type, args...; kwargs...)
    convert_arguments(types, args...) do pointers...
        launch(f, pointers...; kwargs...)
    end
end
convert_arguments passes arguments through Base.cconvert and Base.unsafe_convert, and launch is a reasonably direct wrapper around an FFI call to cuLaunchKernel, which actually starts the computation.
Extensions
Now you know the basics of how CUDA.jl compiles normal Julia code onto GPUs: it's a combination of GPUCompiler.jl, which allows us to retarget LLVM to emit PTX; Adapt.jl, which allows us to deep-convert types; and a bunch of FFI calls to manage the GPU state.
So far CUDA.jl has been the star of the show, but it turns out that the interface exposed by GPUCompiler.jl is so useful that it's spawned a whole ecosystem around it. Some of this ecosystem is related to GPU programming, but other pieces of it seem completely unrelated.
Morally, I think of GPUCompiler.jl as exposing functionality that allows us to easily compile Julia to any target LLVM supports.
AMDGPU.jl, oneAPI.jl, Metal.jl, and VectorEngine.jl
Most directly, the fact that GPUCompiler is separated from CUDA.jl makes it easier for Julia to support a wide range of accelerator backends. At least AMDGPU.jl (AMD), oneAPI.jl (Intel), Metal.jl (Apple), and VectorEngine.jl (NEC) all benefit from sharing the GPUCompiler.jl core internally.
This is a pretty direct generalization from CUDA.jl, except we target amdgcn-amd-amdhsa or air64-apple-macosx or some other accelerator triple.
KernelAbstractions.jl
On top of all of these accelerator backends, there's a neat library called KernelAbstractions.jl, which provides primitives for writing platform-agnostic kernels that you can run on CUDA, ROCm, oneAPI, or just your CPU.
I think this library is easiest to understand via examples; here's a transpose function from the docs, where you can see that @index is providing a platform-agnostic replacement for calls to threadIdx() or another CUDA-specific intrinsic.
@kernel function naive_transpose_kernel!(a, b)
    i, j = @index(Global, NTuple)
    @inbounds b[i, j] = a[j, i]
end
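Launching such a kernel looks roughly like this. This is a sketch assuming KernelAbstractions' newer get_backend-based API; older releases used a slightly different, event-based launch interface.
using CUDA, KernelAbstractions

a = CUDA.rand(32, 32)
b = similar(a)

backend = get_backend(a)                    # CUDABackend() here; CPU() for plain Arrays
kernel! = naive_transpose_kernel!(backend)  # instantiate the kernel for this backend
kernel!(a, b; ndrange=size(a))              # launch one work-item per index
KernelAbstractions.synchronize(backend)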
I recommend reading the histogram_kernel! example for a more complete tour of the primitives that KernelAbstractions.jl provides, including @index, @localmem, @uniform, and @synchronize.
StaticCompiler.jl
So far we've seen pretty standard consumers of GPUCompiler; here's a more unusual one. StaticCompiler.jl lets you save and load precompiled Julia functions. From the docs:
fib(n) = n <= 1 ? n : fib(n - 1) + fib(n - 2)
fib_compiled, path = compile(fib, Tuple{Int}, "fib")
# In a new session
fib_compiled = load_function("fib.cjl")
fib_compiled(10) # -> 55
How does it work? When you call compile, StaticCompiler generates two files: an LLVM object file, and a serialized Julia object with some metadata. compile calls generate_obj, which I've summarized here.
function generate_obj(f, tt, path::String, ...)
    GPUCompiler.codegen(...)        # calls emit_julia and emit_llvm
    table = relocation_table!(mod)
    obj, _ = GPUCompiler.emit_asm(...)
    open(obj_path, "w") do io
        write(io, obj)
    end
    return table                    # ultimately written to the metadata file
end
In this snippet, we can see StaticCompiler.jl using GPUCompiler.jl to target the normal CPU LLVM triple. The trick is that managing this compilation pipeline explicitly allows StaticCompiler.jl to take the resulting assembly and write it to disk instead of running it immediately. Then it can reload that assembly later, ready to execute!
The main additional complexity comes from the table variable, which holds a mapping from any global variables that were inside the compiled function to corresponding LLVM variable slots. When StaticCompiler loads the function from disk, it'll rehydrate those globals using table.
Extended Berkeley Packet Filters (eBPF)
Here's perhaps the wildest usage of GPUCompiler that I'm aware of. A package called BPFNative.jl allows Julia to target eBPF. If you haven't heard of eBPF, you can think of eBPF programs as little scripts that run inside a virtual machine in the Linux kernel. They can modify the behavior of the kernel without the usual complexity and risks of dealing with LKMs or patching the kernel directly.
Since LLVM can emit eBPF [3], we can generate eBPF from Julia using GPUCompiler.
A central usage of BPF programs is to gather statistics. For example, a program called offwaketime collects statistics about why threads blocked / slept / yielded, and how long it took until they were reawoken.
Here's a more in-depth explanation of offwaketime, here's a standard implementation, and here's the same function in Julia, from BPFNative.jl (the interesting code starts around line 105).
Conclusion
That's all--I hope you've enjoyed this peek into the machinery behind Julia's quasimagical CUDA ecosystem. In writing this post, I gained a lot of appreciation for the combined power of an LLVM-based language (Julia) and a package that exposes a sane view into the language's compilation machinery (GPUCompiler.jl). LLVM supports an unbelievable number of backends, and if you can efficiently reuse compilation machinery across them, you can target all sorts of unusual use cases.
If you're interested in further reading, I recommend the original paper for CUDA.jl and the GPUCompiler.jl source, which is surprisingly short. I also recommend the CUDA C programming guide; several of my confusions while writing ultimately boiled down to confusion about CUDA itself.
Finally, a big thanks to Valentin for answering my many annoying questions and reading drafts of this post, and to Prof. Edelman for his flexibility in letting me explore CUDA.jl for MIT's 18.S191.
Notes
[0] If you're wondering exactly what portion of Julia works on a GPU, here's a pretty useful GitHub issue. The most important and non-obvious limitation is that there can't be type instabilities in your GPU code. In other words, functions that might return different types based on their input values are not allowed. Generic functions like f(T) -> T are fine.
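To illustrate with a pair of hypothetical functions (not taken from the issue):
# Type-unstable: the return type depends on the *value* of x
# (Int for positive inputs, Float64 otherwise), so it can't go in a kernel.
unstable(x) = x > 0 ? 1 : 1.0

# Generic but type-stable: the return type depends only on the *type* of x,
# so it's fine on the GPU.
double(x) = x + x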
[1] If you need a quick refresher, here's how we'd write a CUDA kernel the conventional way. For the sake of example, we'll say that we're implementing an extension to PyTorch, according to the guidelines given here. Loosely paraphrased:
- Write your custom CUDA kernels in .cu files.
- Declare those .cu functions in a C++ file.
- Write some C++ functions that forward to those .cu functions.
- Use pybind11 to expose those C++ functions to Python.
- Use the cpp_extension package to compile the C++ sources with gcc and the CUDA sources with NVIDIA's nvcc.
- Link them into one shared library that is available from Python code.
[2] Technically they distribute them as NVVM bitcode, but NVVM IR is a strict subset of LLVM IR. A valid NVVM program is always valid LLVM.
[3] One of my main takeaways from this post is just how many backends LLVM supports.