This guide explains how to implement a new primitive. It will primarily focus on how to do this. See the internals vignette for more information on how primitives work.
In general, there are two main reasons to add a new primitive:
In order to be able to add a new primitive, it needs to be expressible in the {stablehlo} package. There are various scenarios:
All required operations are already implemented in the {stablehlo} R package.
\(\rightarrow\) This is the simplest case and the one we assume in this guide.
One or more required operations are not already implemented in the {stablehlo} R package.
\(\rightarrow\) You first need to implement the missing operations in the {stablehlo} R package. See this issue for which operations are missing and the StableHLO specification for which are available.
It requires shape dynamism (output shape depends on data, such as using boolean values for subsetting):
\(\rightarrow\) This is currently not possible, but we hope we can add support for this in the future, e.g. via a second, dynamic Fortran backend.
It cannot be expressed or can only expressed inefficiently:
\(\rightarrow\) You can implement a stableHLO custom call, see the custom print operation in pjrt. The {pjrt} package has a dedicated “Adding Custom Calls” tutorial. This mechanism is used, for example, to implement some of the linear algebra primitives such as SVD: on CPU these call into LAPACK, and on GPU they use NVIDIA’s cuSOLVER library. In the future it will also be possible to call into user-provided custom CUDA kernels, but there is no infrastructure for this yet – currently we only call into predefined kernels shipped with the CUDA toolkit.
Let’s add a new primitive step by step. We’ll implement
prim_repeat_along – a primitive that repeats an array
multiple times along a specified dimension.
For example, repeating c(1, 2, 3) twice along dimension
1 gives c(1, 2, 3, 1, 2, 3).
This primitive has a dynamic input (an array) and two static parameters (how many times to repeat and which dimension).
Primitives are created with new_primitive(): it builds
the AnvlPrimitive metadata object that holds the rules,
wraps the body with jit(), attaches the metadata, and
registers the result in the internal primitive registry. The returned
callable becomes the primitive and is bound to a
prim_<name> R symbol.
library(anvl)
prim_repeat_along <- new_primitive(
"repeat_along",
function(operand, times, dim) {
# type of operand is checked by graph_desc_add()
infer_fn <- function(operand, times, dim) {
if (!checkmate::test_integerish(dim, lower = 1, upper = ndims(operand), len = 1L)) {
cli::cli_abort("{.arg dim} must be between 1 and {ndims(operand)}, but is {.val dim}")
}
if (!checkmate::test_integerish(times, lower = 1, len = 1L)) {
cli_abort("times must be a positive integer, but is {times}")
}
new_shape <- shape(operand)
new_shape[dim] <- new_shape[dim] * times
list(AbstractArray(
dtype = dtype(operand),
shape = Shape(new_shape),
ambiguous = operand$ambiguous
))
}
graph_desc_add(
self, # lexically bound to the AnvlPrimitive
list(operand = operand), # Dynamic inputs (arrays)
params = list( # Static parameters
times = times,
dim = dim
),
infer_fn = infer_fn
)[[1L]] # Extract single output from list
},
static = c("times", "dim")
)The primitive is now callable directly as
prim_repeat_along(x, times, dim).
Key points:
self (the
[AnvlPrimitive]) as the first argument to
graph_desc_add(). new_primitive() installs
self into the enclosing environment of the body, so this is
always available inside a primitive.graph_desc_add() returns a list of outputs; use
[[1L]] for single-output primitives.ambiguous flag from inputs to outputs,
see type promotion for what this
means.Every argument that is not a dynamic array must be listed in
static. new_primitive calls jit()
with backend = "auto" internally, which defers the choice
of backend to call time.
If the primitive is a wrapper around a stablehlo operation, it is
possible to use the corresponding inference function from the stablehlo
package (such as stablehlo::infer_types_concatenate). When
doing so, you need to:
ValueTypes
using at2vt().ValueTypes.ValueTypes back to abstract arrays using
vt2at().ambiguous flag of the output depending on the
inputs (ambiguity is strictly an {anvl} concept, not a
stablehlo concept).Array constructors (e.g. prim_fill,
prim_iota) have no dynamic array inputs – every argument is
static. Two things change in this case:
static.
Otherwise {anvl} would try to interpret a scalar argument like
value or shape as an AnvlArray
input.backend = "auto" cannot infer a device from
inputs, because there are no array inputs. Instead, add a
device formal to the function and pass
device = device_arg("device") to
new_primitive(). At call time, the user-supplied device is
read from that argument and used to determine both the backend (via
[backend()] dispatch) and the compilation device.For example, the real definition of prim_fill looks like
this:
prim_fill <- new_primitive(
"fill",
function(value, shape, dtype, ambiguous = FALSE, device = NULL) {
# ... graph_desc_add(self, ...) ...
},
static = 1:5,
device = device_arg("device")
)See [device_arg()] for the detailed semantics.
prim_repeat_along has its own parameters and needs a
custom body. Many primitives are simpler – elementwise unary or binary
ops, reductions, and comparisons all share a standard shape.
R/primitives.R exposes factories that generate the body for
you:
make_unary_op(stablehlo_infer) – elementwise unary
(e.g. prim_abs, prim_negate).make_binary_op(stablehlo_infer) – elementwise binary
(e.g. prim_add, prim_mul).make_reduce_op(infer_fn) – reductions with
dims / drop parameters
(e.g. prim_reduce_sum).make_compare_op(direction) – comparison ops with a
fixed direction string (e.g. prim_eq,
prim_lt).Used together with the corresponding stablehlo inference function, the primitive definition collapses to a one-liner:
prim_add <- new_primitive("add", make_binary_op(stablehlo::infer_types_add))
prim_negate <- new_primitive("negate", make_unary_op(stablehlo::infer_types_negate))
prim_reduce_sum <- new_primitive("reduce_sum", make_reduce_op(), static = 2:3)Reach for the manual graph_desc_add() form when the
primitive has extra static parameters, a custom inference function, or
multiple outputs.
The StableHLO rule defines how to lower this primitive to actual
operations, which is used in the stablehlo() lowering pass.
We implement repeat_along using concatenation:
prim_repeat_along[["stablehlo"]] <- function(operand, times, dim) {
operands <- rep(list(operand), times)
list(rlang::exec(stablehlo::hlo_concatenate, !!!operands, dimension = dim - 1L))
}The rule receives:
stablehlo::FuncValues.It must return a list of stablehlo::FuncValues, even if
there is only one output.
Important: StableHLO uses 0-based indexing, while
{anvl} uses R’s 1-based indexing. Always convert dimension indices by
subtracting 1. Also note that in graph_desc_add() we are
converting the error messages from stablehlo to our 1-based indexing, so
you do not have to worry about that here.
If the operation should support automatic differentiation, attach a
reverse rule built with rule_reverse(). The idea here is
the following, where we assume the input operand has shape
(s_1, ..., s_n), which means that the output (and therefore
it’s gradient) has shape
(s_1, ..., s_{dim-1}, s_dim * times, s_{dim+1}, ..., s_n).
(s_1, ..., s_{dim-1}, s_dim, times, s_{dim+1}, ..., s_n).times dimension and drop the
times dimension.prim_repeat_along[["reverse"]] <- rule_reverse(function(inputs, outputs, grads, params, required) {
if (!required[[1L]]) {
return(list(NULL))
}
grad <- grads[[1L]]
operand <- inputs[[1L]]
dim <- params$dim
times <- params$times
old_shape <- shape(operand)
grad_shape <- shape(grad)
new_shape <- grad_shape
new_shape[dim] <- old_shape[dim]
new_shape <- append(new_shape, times, after = dim - 1L)
grad_reshaped <- prim_reshape(grad, new_shape)
grad_summed <- prim_reduce_sum(grad_reshaped, dims = dim, drop = TRUE)
list(grad_summed)
})The wrapped backward receives:
inputs: Input GraphValues from the forward
passoutputs: Output GraphValues from the
forward passgrads: Gradients flowing back from downstream (one per
output)params: Named list of the call’s static parameters
(here: params$dim, params$times)required: Logical vector indicating which input
gradients are neededIt returns a list with one gradient per input (or NULL
if not required).
Note that the reverse rule will only be called if at least one input gradient is required, so this can be assumed.
For most primitives the backward-only form above is enough. If the
gradient computation can be made significantly more efficient by running
an alternative forward pass that exposes intermediate values for reuse,
use the alternative-forward form:
rule_reverse(forward = ...). The forward hook receives
(inputs, params), emits whatever forward primitives it
likes via the normal prim_* callables, and returns a list
with the primal outputs and a backward closure.
Intermediates flow from forward to backward via R’s lexical scoping:
prim_<name>[["reverse"]] <- rule_reverse(forward = function(inputs, params) {
x <- inputs[[1L]]
y <- prim_some_op(x) # forward emit 1
z <- prim_other_op(x, y) # forward emit 2; both `x` and `y` captured
list(
outputs = list(z), # one box per original call output
backward = function(inputs, outputs, grads, params, required) {
# `x` is also available via `inputs`; what lexical capture buys us is
# access to intermediates like `y` that the framework doesn't pass in.
...
}
)
})The backward closure shares the same signature as the
backward-only form. Intermediates that aren’t passed in (like
y above) flow in via lexical capture. The forward must
return one output box per original call output, with matching
shape/dtype.
prim_<name>[["stablehlo"]] and
prim_<name>[["reverse"]] are the two rules you almost
always want. A primitive can optionally also carry a quickr
rule, which lowers it to plain R code for the quickr backend (see
R/rules-quickr.R). The quickr rule is only required if you
want the primitive to run under local_backend("quickr"); if
you skip it, the primitive will still work on the xla backend.
new_primitive() returns the callable directly and also
registers the primitive in the internal registry used by the graph
machinery, so no separate registration step is needed:
prim_repeat_along
#> function (operand, times, dim)
#> {
#> if (currently_tracing()) {
#> cl <- match.call()
#> cl[[1L]] <- f
#> return(eval.parent(cl))
#> }
#> args <- lapply(as.list(match.call())[-1L], eval, envir = parent.frame())
#> be <- if (!is.null(device_argname) && !is.null(args[[device_argname]])) {
#> dev_val <- args[[device_argname]]
#> if (is.character(dev_val))
#> default_backend()
#> else backend(dev_val)
#> }
#> else {
#> jit_auto_detect_backend(flatten(args[!names(args) %in%
#> static]))
#> }
#> if (is.null(jit_fns[[be]])) {
#> jit_fns[[be]] <<- do.call(jit_with_backend, c(list(f = f,
#> static = static, cache_size = cache_size, backend = be),
#> if (!is.null(device_argname)) {
#> list(device = device_arg(device_argname))
#> } else if (!is.null(device)) {
#> list(device = device)
#> }, dots))
#> }
#> do.call(jit_fns[[be]], args)
#> }
#> <environment: 0x56349f3f13e8>
#> attr(,"class")
#> [1] "JitPrimitive" "JitFunction"
#> attr(,"backend")
#> [1] "auto"
#> attr(,"primitive")
#> <AnvlPrimitive:repeat_along>nv_ API FunctionIn {anvl}, we also offer convenience wrappers around the primitives.
An example is prim_add vs nv_add, where the
latter calls into the former after optionally broadcasting (scalar)
inputs:
nv_add(1L, nv_array(2:3))
#> AnvlArray
#> 3
#> 4
#> [ CPUi32{2} ]
prim_add(1L, nv_array(2:3))
#> Error in `prim_add()`:
#> ! `lhs` and `rhs` must have the same tensor type.
#> ✖ Got tensor<i32> and tensor<2xi32>.In our case, no such convenience is needed and the functionality is
not too low-level (for it to be generally useful), so we can just
reassign the prim_* function to an nv_*
function:
Note that in the nv_* wrapper function, you can only
access certain properties of the input arrayish values via:
shape_abstract()ndims_abstract()dtype_abstract()ambiguous_abstract()If you, for example, use shape() instead of
shape_abstract(), your function won’t work with R literals.
I.e., <extract>_abstract() first converts the input
to an AbstractArray (if possible) and then extracts the
property.
You can now use the primitive, both eagerly and inside a larger JIT-compiled function:
x <- nv_array(c(1, 2, 3), shape = c(3, 1))
# Eager call -- works because prim_repeat_along is itself jit-compiled.
prim_repeat_along(x, times = 2L, dim = 2L)
#> AnvlArray
#> 1 1
#> 2 2
#> 3 3
#> [ CPUf32{3,2} ]
# Traced into the outer graph when composed with another jit-compiled function.
jit(function(x) prim_repeat_along(x, times = 2L, dim = 2L))(x)
#> AnvlArray
#> 1 1
#> 2 2
#> 3 3
#> [ CPUf32{3,2} ]And compute gradients through it.
If you want to contribute your primitive to {anvl}, there are some additional things to be aware of.
R/primitives.R: Define the
prim_* primitive via new_primitive()R/rules-stablehlo.R: Add the StableHLO
lowering ruleR/rules-reverse.R: Add the reverse
rule (if differentiable)R/rules-quickr.R: Add the quickr
lowering rule (optional; only if the primitive should run on the quickr
backend)R/api.R: Add the nv_*
wrapper function (or possibly in another api
file).Tests can go in two places:
inst/extra-tests/: For tests that
compare against torch. These live in inst/ to avoid listing
torch as a dependency.tests/testthat/: For tests without a
torch counterpart.Important: the describe() /
test_that() label must contain the full primitive name,
e.g. describe("prim_repeat_along", { ... }). The meta tests
in tests/testthat/test-primitives-meta.R verify that every
primitive has corresponding stablehlo and reverse tests, and flag any
that are missing.
Since no torch counterpart exists for prim_repeat_along,
we would add manual tests in:
tests/testthat/test-primitives-stablehlo.Rtests/testthat/test-primitives-reverse.RAlso, ensure that no linter errors are present,
devtools::check() passes, and format the code using
make format.
Higher-Order Primitives are primitives that parameterized by an R
function or expression. Examples include prim_if and
prim_while. These are generally much more complex to
handle, so we don’t cover them here in detail (for now). The general
idea, however, is that the primitive prim_* function needs
to trace the provided function using trace_fn() and then
forward this graph to the stablehlo lowering rule and reverse rule.