This vignette explains what
actually happens when you wrap a function with jit().
Understanding this is what lets you avoid common pitfalls of
tracing-based compilers. If you haven’t yet, read the Get Started vignette first.
We will use the simple linear function linear(x, w, b), which computes
x * w + b, as the running example.
jit() works
In pseudo-R, jit(f) returns roughly this closure:
jit <- function(f, static = character()) {
cache <- hashtab()
function(...) {
# abstract representations of dynamic args, values of static args
key <- input_signature(..., static = static)
if (is.null(cache[[key]])) {
# step 1: record prim_* calls into an AnvlGraph
graph <- trace_fn(f, list(...), static = static)
# step 2: lower to an XLA executable, store
cache[[key]] <- compile(graph)
}
# step 3: run the cached executable on this call's inputs
cache[[key]](...)
}
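}
Three things happen:
1. On a cache miss, f is traced: it runs with abstract placeholders for its dynamic inputs (static arguments are covered below), recording the sequence of primitive operations into an intermediate representation called an AnvlGraph.
2. The AnvlGraph is lowered to an XLA executable and stored in a cache, keyed by the abstract values of the input AnvlArrays and the actual values for static inputs.
3. The cached executable runs on the current call's inputs.
When a call’s inputs match a cached entry, both tracing and compilation are skipped. We can observe this directly by inspecting the size of the cache held inside the jitted function: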
cache_size <- function(f) environment(f)$cache$size
linear_jit <- jit(linear)
cache_size(linear_jit) # 0: nothing cached yet
#> [1] 0
linear_jit(nv_scalar(2), nv_scalar(3), nv_scalar(1))
#> AnvlArray
#> 7
#> [ CPUf32{} ]
cache_size(linear_jit) # 1: a new entry was added
#> [1] 1
# same shapes -> cache hit, size unchanged
linear_jit(nv_scalar(2), nv_scalar(3), nv_scalar(5))
#> AnvlArray
#> 11
#> [ CPUf32{} ]
cache_size(linear_jit)
#> [1] 1
# different shapes -> a second entry is added
linear_jit(
nv_array(c(1, 2)),
nv_array(c(3, 4)),
nv_array(c(1, 1))
)
#> AnvlArray
#> 4
#> 9
#> [ CPUf32{2} ]
cache_size(linear_jit)
#> [1] 2
Each input to the function contributes to the cache key differently,
depending on whether it is dynamic (an arrayish value, the
default) or static (marked via the static =
argument of jit()):
- Dynamic inputs contribute their abstract value, the
nv_aval(dtype, shape, ambiguous) triple. Two arrays with the same
abstract value but different data hit the same cache entry.
- Static inputs contribute their value, compared with identical().
They stay regular R values during compilation, but their value is
fixed for a compiled program. Two calls with flag = TRUE and
flag = FALSE therefore land on different cache entries.
In the snippet above all three inputs are dynamic, so the first two
calls share a key (three f32[] scalars): the first call compiles and
the second hits the cache. The third call presents three f32[2]
vectors and forces a retrace.
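To make the key rules concrete, here is a toy model of the cache in plain R, so it runs without anvl. It is a hypothetical sketch, not anvl's implementation: toy_jit() and its n_compiles counter are made up for illustration, dynamic inputs are summarised by typeof() and their length or dim() as a stand-in for the nv_aval() triple, static inputs are keyed by value and compared with identical(), and "compiling" simply stores f.

```r
# Hypothetical toy model of jit()'s compilation cache (not anvl's implementation).
toy_jit <- function(f, static = character()) {
  cache <- list()      # entries of the form list(key = <signature>, fn = <"executable">)
  n_compiles <- 0L     # how many times "trace + compile" has run

  # Cache key: static args by value, dynamic args by type and shape only.
  key_of <- function(args) {
    nms <- names(args)
    if (is.null(nms)) nms <- character(length(args))
    Map(function(a, nm) {
      if (nm %in% static) {
        list(name = nm, value = a)
      } else {
        list(name = nm, type = typeof(a),
             shape = if (is.null(dim(a))) length(a) else dim(a))
      }
    }, args, nms)
  }

  function(...) {
    key <- key_of(list(...))
    hits <- Filter(function(entry) identical(entry$key, key), cache)
    if (length(hits) == 0L) {
      # Cache miss: a real implementation would trace and compile here.
      n_compiles <<- n_compiles + 1L
      entry <- list(key = key, fn = f)
      cache[[length(cache) + 1L]] <<- entry
    } else {
      entry <- hits[[1L]]
    }
    entry$fn(...)
  }
}

linear <- function(x, w, b) x * w + b
linear_toy <- toy_jit(linear)
linear_toy(2, 3, 1)                      # miss: first "compile", returns 7
linear_toy(2, 3, 5)                      # hit: same types and shapes, returns 11
linear_toy(c(1, 2), c(3, 4), c(1, 1))    # miss: new shapes, returns c(4, 9)
environment(linear_toy)$n_compiles       # 2

linear_maybe <- function(x, w, b, use_bias) if (use_bias) x * w + b else x * w
maybe_toy <- toy_jit(linear_maybe, static = "use_bias")
maybe_toy(2, 3, 1, use_bias = TRUE)      # miss
maybe_toy(2, 3, 1, use_bias = FALSE)     # new static value: miss
maybe_toy(5, 3, 1, use_bias = TRUE)      # same static value and shapes: hit
environment(maybe_toy)$n_compiles        # 2
```

The counter is read through environment() in the same way cache_size() reads anvl's cache above.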
You’ll want static = when the body of the function needs
to look at a concrete R value – typically a flag, a small integer, or a
string. It’s also the only way to do R-level input validation on a
value: a dynamic input is just a shape/dtype placeholder during tracing,
so a check like stopifnot(abs(sum(p) - 1) < 1e-6) –
verifying that p is a proper probability vector – only
works if p is static.
linear_maybe <- function(x, w, b, use_bias) {
if (use_bias) linear(x, w, b) else x * w
}
linear_maybe_jit <- jit(linear_maybe, static = "use_bias")
linear_maybe_jit(2, 3, 1, use_bias = TRUE)
#> AnvlArray
#> 7
#> [ CPUf32?{} ]
linear_maybe_jit(2, 3, use_bias = FALSE)
#> AnvlArray
#> 6
#> [ CPUf32?{} ]
Each new static value forces a re-trace and re-compile, so a static argument that takes many distinct values causes many re-compiles. Only make an argument static when you really need to.
The cache key also includes information about the input structure (if
AnvlArrays are nested in lists) and the device to compile
for, but we ignore these here as they matter less for understanding
how to use jit().
Two important notes to be aware of:
Call jit() once on a function and reuse the result;
calling jit() inside a loop creates a fresh cache on every
iteration and defeats the point.
Tracing is how the R code is translated into a form that the
XLA compiler can understand. It works by replacing the dynamic inputs
with special (GraphBox) values and running the function. Instead of
performing the array computations, every primitive operation is
recorded in an AnvlGraph, which represents the
evaluation trace. For the purposes of this vignette you can
think of tracing and the subsequent XLA compilation as a single phase
that runs once per cache miss. Note that this is unrelated to R’s
trace() function, which lets you insert code into
functions. In {anvl}, this tracing machinery is available via
trace_fn(). Although you will probably never need to call
this function directly, we use it here to show what’s happening under
the hood. Below, we trace linear with a length-3 vector for
x and scalars for w and b:
f32_scalar <- nv_aval("f32", integer())
f32_scalar
#> AbstractArray(dtype=f32, shape=)
f32_vec3 <- nv_aval("f32", 3)
f32_vec3
#> AbstractArray(dtype=f32, shape=3)
trace_fn(linear, args = list(x = f32_vec3, w = f32_scalar, b = f32_scalar))
#> <AnvlGraph>
#> Inputs:
#> %x1: f32[3]
#> %x2: f32[]
#> %x3: f32[]
#> Body:
#> %1: f32[3] = broadcast_in_dim [shape = 3, broadcast_dimensions = <any>] (%x2)
#> %2: f32[3] = mul(%x1, %1)
#> %3: f32[3] = broadcast_in_dim [shape = 3, broadcast_dimensions = <any>] (%x3)
#> %4: f32[3] = add(%2, %3)
#> Outputs:
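#> %4: f32[3]
The printed AnvlGraph is like an R function: it has
inputs, a body and outputs. However, its content is more structured:
every input and intermediate value carries an explicit type such as
f32[3], each body line applies exactly one primitive, and the implicit
broadcasting of the scalars w and b appears as explicit
broadcast_in_dim operations.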
Crucially, only prim_* calls (and the nv_*
API functions or overloaded operators that delegate to them) get
recorded. Any other R code in the body can influence which operations
end up in the trace, but does not itself appear in the graph.
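The recording idea can be mimicked in a few lines of plain R with S3 operator overloading, as a rough analogue of what trace_fn() shows. This is a hypothetical sketch, far simpler than anvl's GraphBox machinery: a made-up sym class whose Ops method appends a line to a tape instead of computing anything. The R-level if runs normally and only decides which operations reach the tape.

```r
# Hypothetical mini tracer (not anvl's implementation): arithmetic on "sym"
# placeholders is recorded on a tape instead of being computed.
tape <- new.env()
tape$ops <- character()

sym <- function(name) structure(list(name = name), class = "sym")

# S3 group generic method for +, -, *, / on "sym"; handles binary operators only.
Ops.sym <- function(e1, e2) {
  label <- function(e) if (inherits(e, "sym")) e$name else format(e)
  prim <- switch(.Generic, "*" = "mul", "+" = "add", "-" = "sub", "/" = "div", .Generic)
  out <- paste0("%", length(tape$ops) + 1L)
  tape$ops <- c(tape$ops, sprintf("%s = %s(%s, %s)", out, prim, label(e1), label(e2)))
  sym(out)
}

# Run f on placeholders and return the recorded operations.
toy_trace <- function(f, ...) {
  tape$ops <- character()   # start a fresh recording
  f(...)
  tape$ops
}

linear <- function(x, w, b) x * w + b
toy_trace(linear, sym("x"), sym("w"), sym("b"))
# c("%1 = mul(x, w)", "%2 = add(%1, b)")

# The `if` is ordinary R: it picks a branch but leaves no trace of its own.
linear_maybe <- function(x, w, b, use_bias) if (use_bias) linear(x, w, b) else x * w
toy_trace(linear_maybe, sym("x"), sym("w"), sym("b"), use_bias = FALSE)
# "%1 = mul(x, w)"
```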
Tracing only produces correct results for pure functions: functions whose output depends on their arguments and nothing else, and that have no R-level side effects. This contract follows from two facts: executables are cached and reused, and the compiled program runs outside of R and communicates with the R interpreter only through its return values. Concretely, the function’s execution path (the specific sequence of primitive calls it performs) must depend only on information in the cache key, chiefly the abstract values of the dynamic inputs and the values of the static inputs.
The next subsections each show what tracing does to particular R code and where its behavior might be surprising.
Tracing runs your R code and records primitive operations in the
graph. Because for is not an anvl primitive, the loop runs as
ordinary R code, and every primitive call it makes is recorded in the
graph. Here we apply the linear function n times:
linear_repeated <- function(x, w, b, n) {
for (i in seq_len(n)) x <- linear(x, w, b)
x
}
trace_fn(linear_repeated, args = list(x = f32_scalar, w = f32_vec3, b = f32_vec3, n = 2L))
#> <AnvlGraph>
#> Inputs:
#> %x1: f32[]
#> %x2: f32[3]
#> %x3: f32[3]
#> Body:
#> %1: f32[3] = broadcast_in_dim [shape = 3, broadcast_dimensions = <any>] (%x1)
#> %2: f32[3] = mul(%1, %x2)
#> %3: f32[3] = add(%2, %x3)
#> %4: f32[3] = mul(%3, %x2)
#> %5: f32[3] = add(%4, %x3)
#> Outputs:
#> %5: f32[3]
The graph contains a single broadcast_in_dim (lifting
the scalar x to f32[3]) followed by two
mul/add pairs in sequence, not a loop
construct. The broadcast happens only once because by the second
iteration x is already an f32[3] and matches
w and b directly. The compiled executable will
contain those five operations laid out one after another. A different
value of n produces a different graph, with the loop unrolled
that many times. Because the graph grows with the number of
iterations, long loops also lead to long compile times and large
executables. You can instead use nv_while(),
which records a single higher-order primitive call in the graph
regardless of how many iterations the loop runs.
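To see the unrolling in isolation, here is a self-contained plain-R sketch (hypothetical, not anvl's tracer) in which a made-up counted class tallies every arithmetic operation applied to it. The tally grows linearly with n because the for loop itself never reaches the recording. Only the operations executed inside it do.

```r
# Hypothetical counting tracer (not anvl's): each arithmetic operation on a
# "counted" placeholder increments a shared counter instead of computing a value.
counter <- new.env()
counter$n <- 0L
counted <- function() structure(list(), class = "counted")
Ops.counted <- function(e1, e2) {
  counter$n <- counter$n + 1L   # one recorded "primitive" per operator call
  counted()
}

linear <- function(x, w, b) x * w + b
linear_repeated <- function(x, w, b, n) {
  for (i in seq_len(n)) x <- linear(x, w, b)
  x
}

# Number of operations recorded for a given loop length n.
ops_recorded <- function(n) {
  counter$n <- 0L
  linear_repeated(counted(), counted(), counted(), n)
  counter$n
}

sapply(c(1, 2, 10), ops_recorded)   # 2 4 20: one mul and one add per iteration
```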
if statements pick one branch
Tracing runs only the if-branch that the condition
selects at trace time. One common scenario is where the branch depends
on a static input flag:
linear_maybe <- function(x, w, b, use_bias) {
if (use_bias) linear(x, w, b) else x * w
}
trace_fn(
linear_maybe,
args = list(x = f32_scalar, w = f32_scalar, b = f32_scalar, use_bias = TRUE)
)
#> <AnvlGraph>
#> Inputs:
#> %x1: f32[]
#> %x2: f32[]
#> %x3: f32[]
#> Body:
#> %1: f32[] = mul(%x1, %x2)
#> %2: f32[] = add(%1, %x3)
#> Outputs:
#> %2: f32[]
The graph contains one mul and one add
operation, but no conditional. Tracing with
use_bias = FALSE would produce a graph containing only the other
branch, a single mul.
Things go wrong when the evaluation trace depends on something that is not part of the cache key, such as a value from the enclosing environment:
threshold <- 0.5
h <- function(x) {
if (threshold > 0.5) x * 2 else x + 1
}
trace_fn(h, args = list(x = f32_scalar))
#> <AnvlGraph>
#> Inputs:
#> %x1: f32[]
#> Body:
#> %1: f32[] = add(%x1, 1:f32?)
#> Outputs:
#> %1: f32[]
The trace itself runs correctly, but it becomes a problem in
combination with jit()’s caching mechanism. The closed-over
threshold is not part of the cache key, so subsequent
calls with the same dynamic input types reuse the executable
produced when threshold was 0.5, whatever the value of
threshold at call time. The fix is to keep the
graph from depending on values outside the function’s signature; in
this case, make threshold an explicit static argument.
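The failure can be reproduced in plain R without anvl. In this hypothetical sketch, toy_compile_h() stands in for tracing and compiling h: it evaluates the R-level if once and returns a function containing only the chosen branch. The toy cache key, like jit()’s, sees only the input type and not threshold.

```r
# Hypothetical sketch of a stale cache entry (not anvl's implementation).
threshold <- 0.5

# Stand-in for tracing + compiling h(): the `if` runs once, here, and only the
# selected branch survives in the returned "executable".
toy_compile_h <- function() {
  if (threshold > 0.5) function(x) x * 2 else function(x) x + 1
}

h_cache <- list()
h_toy <- function(x) {
  key <- typeof(x)   # the key ignores `threshold` entirely
  if (is.null(h_cache[[key]])) h_cache[[key]] <<- toy_compile_h()
  h_cache[[key]](x)
}

h_toy(3)           # 4: compiled while threshold was 0.5, so the x + 1 branch
threshold <- 0.9
h_toy(3)           # still 4: cache hit, the old branch is reused

h_plain <- function(x) if (threshold > 0.5) x * 2 else x + 1
h_plain(3)         # 6: plain R re-reads threshold on every call
```

Passing threshold as an argument that enters the key would give each threshold value its own cache entry, which is what making it static achieves with jit().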
We just saw an extreme version of this in the if
section: a closed-over R variable picked the branch that ended
up in the graph. The same dynamic applies to closed-over values used as
plain operands – their value at trace time is read once and baked into
the graph.
Here we close over a default bias instead of taking it as an argument:
default_b <- 5
linear_default_b <- function(x, w) linear(x, w, default_b)
trace_fn(linear_default_b, args = list(x = f32_scalar, w = f32_scalar))
#> <AnvlGraph>
#> Inputs:
#> %x1: f32[]
#> %x2: f32[]
#> Body:
#> %1: f32[] = mul(%x1, %x2)
#> %2: f32[] = add(%1, 5:f32?)
#> Outputs:
#> %2: f32[]
The graph contains add(%1, 5:f32?): the value
5 is hard-wired into the program. Changing
default_b afterwards has no effect on the graph (and, once
the graph is compiled, no effect on the executable either).
The Tracing Contract above already noted that a jitted function must be pure. This subsection makes the consequence concrete: R-level side effects only have an effect while the graph is being built, not on subsequent calls.
A common R pattern for stateful objects is to wrap them in an environment, since environments give you reference semantics:
new_model <- function(beta) {
e <- new.env()
e$beta <- beta
e$grad_step <- function(beta_grad, lr) {
e$beta <- e$beta - beta_grad * lr
e$beta
}
e
}
If this were executed in plain R, every call to
model$grad_step() would nudge model$beta one
step further along the gradient. But wrapping grad_step
with jit() breaks the function in two ways:
model <- new_model(nv_array(c(0, 0, 0), dtype = "f32"))
grad_step_jit <- jit(model$grad_step)
g <- nv_array(c(1, 1, 1), dtype = "f32")
grad_step_jit(g, 0.1) # expected c(-0.1, -0.1, -0.1)
#> AnvlArray
#> -0.1000
#> -0.1000
#> -0.1000
#> [ CPUf32{3} ]
grad_step_jit(g, 0.1) # expected c(-0.2, -0.2, -0.2) -- but identical to call 1
#> AnvlArray
#> -0.1000
#> -0.1000
#> -0.1000
#> [ CPUf32{3} ]
grad_step_jit(g, 0.1) # expected c(-0.3, -0.3, -0.3) -- but identical to call 1
#> AnvlArray
#> -0.1000
#> -0.1000
#> -0.1000
#> [ CPUf32{3} ]
class(model$beta) # not even an AnvlArray any more
#> [1] "GraphBox" "AnvlBox"
Two things went wrong, both for the same reason: the only things
jit() records into the graph are prim_*
calls.
1. e$beta <- ... is a plain R assignment,
not a prim_* call, so it is not recorded into the
AnvlGraph. It runs once, during tracing, but the value
being assigned is anvl’s internal trace placeholder (a
GraphBox), not a real array. So after the first call,
model$beta is broken: it holds a leaked tracer
instead of the array you expected.
2. Subsequent calls hit the cache, so tracing is skipped and the
e$beta <- ... line is not executed again. And because e$beta was a
closed-over R value at trace time, its initial value (c(0, 0, 0)) is
baked into the graph as a literal (per Closed-over values become
constants above), so every call returns the same
-0.1.
This is why jit()-compiled functions must be
pure: their output depends only on their arguments, and any
state updates have to happen at the call site, not inside the
function:
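A minimal sketch of that pattern, written in plain R so it runs without anvl: grad_step() returns the new parameters instead of assigning them, and the loop at the call site owns the state. Under anvl, the assumption is that you would wrap this pure function once with jit() and keep the loop unchanged. Because the function no longer closes over or mutates beta, every call sees the current value through its arguments.

```r
# Pure update: the new parameters are returned, nothing outside is mutated.
grad_step <- function(beta, beta_grad, lr) beta - beta_grad * lr
# With anvl, this is the function you would wrap once: grad_step_jit <- jit(grad_step)

beta <- c(0, 0, 0)
g <- c(1, 1, 1)
for (i in 1:3) {
  beta <- grad_step(beta, g, 0.1)   # state update happens here, at the call site
}
beta   # -0.3 -0.3 -0.3 (up to floating-point rounding)
```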
jit() arguments
By default, a compiled executable treats its inputs as read-only: the R-visible input arrays remain valid after the call, and XLA has to allocate fresh memory for any output of matching shape. For long training loops over large parameters, this means every step allocates a new parameter buffer and leaves the previous one for the garbage collector.
Via the donate argument of jit(), you can
tell XLA that an input will not be used after the call, so it is free to
reuse the input array’s memory for an output.
step <- jit(function(w, g) w - 0.1 * g, donate = "w")
w <- nv_array(c(1, 2, 3), dtype = "f32")
g <- nv_array(c(0.1, 0.1, 0.1), dtype = "f32")
w <- step(w, g)
w
#> AnvlArray
#> 0.9900
#> 1.9900
#> 2.9900
#> [ CPUf32{3} ]
The compiled executable now consumes w’s buffer as part
of the call. The caller must not reuse an array after donating it to a
function; doing so throws an error:
Every AnvlArray lives on a concrete device (CPU, a
specific GPU, etc.), and a compiled executable is itself bound to a
specific device – so all inputs to one call must live on that same
device. How the device is chosen for a jit()-compiled call
depends on how you called jit(). The two most relevant
options are:
- If you do not pass device =, anvl looks at the devices of the array
inputs at call time, requires them to agree, and compiles for that
device. If there are no array inputs, it falls back to the
default_device().
- Pinned at jit() time. Passing a concrete device, e.g.
jit(f, device = "cpu:0") or jit(f, device = nv_device("cuda:0")),
forces every call to run on that device. Inputs living on a different
device are copied over automatically.
See the documentation of the device argument in
help(jit) for more information.
add <- jit(function(x, y) x + y)
add_cpu <- jit(function(x, y) x + y, device = "cpu")
a <- nv_array(c(1, 2, 3), dtype = "f32")
b <- nv_array(c(4, 5, 6), dtype = "f32")
device(add(a, b)) # inferred from inputs
#> <CpuDevice(id=0)>
device(add_cpu(a, b)) # pinned
#> <CpuDevice(id=0)>
Because the device is part of the cache key (see The compilation
cache above), a single jit()ted function can hold
compiled executables for several devices at once, unless
device was specified explicitly in the
jit() call.