| Title: | R Interface to PJRT |
|---|---|
| Description: | Provides an R interface to PJRT (Pluggable Jit RunTime), which allows you to run XLA or stableHLO programs on a variety of hardware backends including CPU, GPU, and TPU. |
| Authors: | Sebastian Fischer [cre, aut] (ORCID: <https://orcid.org/0000-0002-9609-3197>), Daniel Falbel [aut] (ORCID: <https://orcid.org/0009-0006-0143-2392>) |
| Maintainer: | Sebastian Fischer <[email protected]> |
| License: | MIT + file LICENSE |
| Version: | 0.4.0 |
| Built: | 2026-06-03 11:27:25 UTC |
| Source: | https://github.com/r-xla/pjrt |
Configuration options provided by XLA
XLA provides various configuration options, but their documentation is scattered across various websites. The options include:
TF_CPP_MIN_LOG_LEVEL: Logging level for PJRT C++ API:
0: shows info, warnings and errors
1: shows warnings and errors
2: shows errors
3: shows nothing
XLA_FLAGS: See the openxla website for
more information.
Configuration options provided by this package
PJRT_PLATFORM: Default platform to use, falls back to "cpu".
PJRT_PLUGIN_PATH_<PLATFORM>: Path to custom plugin library file for a specific
platform (e.g., PJRT_PLUGIN_PATH_CPU, PJRT_PLUGIN_PATH_CUDA,
PJRT_PLUGIN_PATH_METAL). If set, the package will use this path instead
of downloading the plugin.
PJRT_PLUGIN_URL_<PLATFORM>: URL to download plugin from for a specific
platform (e.g., PJRT_PLUGIN_URL_CPU, PJRT_PLUGIN_URL_CUDA,
PJRT_PLUGIN_URL_METAL). If set, overrides the default plugin download URL.
PJRT_ZML_ARTIFACT_VERSION: Version of ZML artifacts to download.
Only used when downloading plugins from zml/pjrt-artifacts.
PJRT_CPU_DEVICE_COUNT: The number of CPU devices to use. Defaults to 1.
This is primarily intended for testing purposes.
PJRT_CUDA_R_PACKAGE: Name of the R package providing CUDA libraries.
Defaults to cuda12.8.
Set this to use a different CUDA toolkit package, but note that other
versions may not work with the XLA plugin.
PJRT_DEBUG: If set (to any non-empty value), enables verbose debug output
via cli::cli_inform().
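As a minimal sketch, the options above can be set from within R using base R's `Sys.setenv()`. The source does not say when pjrt reads these variables, so this example assumes they should be set before the first plugin or client is created. The chosen values are illustrative only.

```r
# Assumption: pjrt reads these variables when a plugin/client is first created,
# so they are set at the start of the session.
Sys.setenv(
  PJRT_PLATFORM = "cpu",         # default platform
  PJRT_CPU_DEVICE_COUNT = "2",   # mainly useful for testing
  TF_CPP_MIN_LOG_LEVEL = "2"     # show errors only
)

# Inspect the values that are now visible to the R process
Sys.getenv(c("PJRT_PLATFORM", "PJRT_CPU_DEVICE_COUNT", "TF_CPP_MIN_LOG_LEVEL"))
```

The same variables can also be placed in an `.Renviron` file so that they are set before R starts.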
The pjrt package itself is MIT-licensed. The CUDA backend dynamically
loads NVIDIA software which is not bundled with pjrt, but downloaded
from NVIDIA's official redistributable channels by the CUDA toolkit R
package (e.g. cuda12.8) at install time. Its use is governed by the
NVIDIA CUDA Toolkit EULA, with the
exception of cuDNN, which is covered by the
NVIDIA cuDNN SLA,
and NCCL, which is covered by its own license.
By installing or using the CUDA backend you accept those terms.
Useful links:
Report bugs at https://github.com/r-xla/pjrt/issues
Start an asynchronous transfer of buffer data from device to host.
Returns immediately with a PJRTArrayPromise object.
Use value() to get the R array (blocks if not ready).
Use is_ready() to check if transfer has completed (non-blocking).
as_array_async(x, ...)
x |
A |
... |
Additional arguments (unused). |
A PJRTArrayPromise object. Call value() to get the R array.
as_array(), value(), is_ready(), pjrt_execute(), await()
buf <- pjrt_buffer(c(1.0, 2.0, 3.0, 4.0), shape = c(2, 2), dtype = "f32")
result <- as_array_async(buf)
is_ready(result)
value(result)
Transfer buffer data from device to host and return an R array.
## S3 method for class 'PJRTBuffer'
as_array(x, check = FALSE, ...)
x |
( |
check |
(
No-op for float, boolean, and small/unsigned-32 integer dtypes —
|
... |
Additional arguments (unused). |
An R array (or vector for shape integer()).
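A minimal round-trip sketch, assuming the pjrt package and its CPU plugin are installed: a matrix is copied to the device with `pjrt_buffer()` and brought back with `as_array()`. The `check` argument is left at its default because its description is incomplete here.

```r
library(pjrt)

# Copy an R matrix to the default device
mat <- matrix(1:6, nrow = 2)
buf <- pjrt_buffer(mat)

# Transfer the data back to the host as an R array
back <- as_array(buf)
dim(back)
back
```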
Convert a platform name to a PJRT client or verify that an object is already a client.
as_pjrt_client(x)
x |
( |
PJRTClient
# Convert from platform name
client <- as_pjrt_client("cpu")
client
Convert a platform name or device to a PJRT device object.
as_pjrt_device(x)
x |
( |
PJRTDevice
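A short sketch of both documented input kinds, assuming the CPU plugin is available: a platform name is converted to a device, and an existing device is passed through. That `platform()` also works on a device is taken from the `pjrt_device()` entry.

```r
library(pjrt)

# Convert from a platform name
dev <- as_pjrt_device("cpu")
dev

# Passing an existing device should return a device again
dev2 <- as_pjrt_device(dev)
platform(dev2)
```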
Convert a platform name to a PJRT plugin or verify that an object is already a plugin.
as_pjrt_plugin(x)
x |
(any) |
PJRTPlugin
# Convert from platform name
plugin <- as_pjrt_plugin("cpu")
plugin
PJRTElementType to string
Get a (lowercase) string representation of a PJRT element type.
## S3 method for class 'PJRTElementType'
as.character(x, ...)
x |
A PJRT element type object. |
... |
Additional arguments (unused). |
A string representation of the element type.
Block until the async operation is complete and return the object.
await(x, ...)
x |
An async value object. |
... |
Additional arguments (unused). |
The awaited object (invisibly).
Copy a PJRTBuffer to a different device.
Returns a new buffer on the target device; the original buffer is unchanged.
If the buffer already lives on the requested device, no copy is performed.
When the target device belongs to a different client (e.g. copying from CPU to CUDA), the transfer is performed via a host roundtrip.
copy_buffer(buffer, device)
buffer |
( |
device |
( |
A new PJRTBuffer on the target device.
buf <- pjrt_buffer(c(1, 2, 3), device = "cpu")
buf2 <- copy_buffer(buf, "cuda")
device(buf2)
Get the addressable devices.
devices(x = NULL, ...)
x |
An object to get devices from: a |
... |
Additional arguments (currently unused). |
list of PJRTDevice
# Create client (defaults to CPU)
client <- pjrt_client()
devices(client)
Get the element type of a buffer.
elt_type(x)
x |
( |
buf <- pjrt_buffer(c(1.0, 2.0, 3.0))
elt_type(buf)
Formats buffer data into a character vector of string representations of individual elements suitable for stableHLO.
format_buffer(buffer)
buffer |
( |
character() A character vector containing the formatted elements.
buf <- pjrt_buffer(c(1.5, 2.5, 3.5))
format_buffer(buf)
Non-blocking check to see if an async operation has completed.
is_ready(x, ...)
x |
An async value object. |
... |
Additional arguments (unused). |
TRUE if the operation has completed, FALSE otherwise.
Create a PJRT Buffer from an R object.
Any numeric PJRT buffer is an array and 0-dimensional arrays are used as scalars.
pjrt_buffer will create an array with dimensions (1) for a vector of length 1, while
pjrt_scalar will create a 0-dimensional array for an R vector of length 1.
To create an empty buffer (at least one dimension must be 0), use pjrt_empty.
Important: No checks are performed when creating the buffer, so you need to ensure that the data fits the selected element type (e.g., to prevent buffer overflow) and that no NA values are present.
pjrt_buffer(
  data,
  dtype = NULL,
  device = NULL,
  shape = NULL,
  check = FALSE,
  ...
)
pjrt_scalar(data, dtype = NULL, device = NULL, check = FALSE, ...)
pjrt_empty(dtype, shape, device = NULL)
data |
(any) |
dtype |
(
|
device |
( |
shape |
( |
check |
( |
... |
(any)
|
PJRTBuffer
platform() -> character(1): for the platform name of the buffer ("cpu", "cuda", ...).
device() -> PJRTDevice: for the device of the buffer (also includes device number)
elt_type() -> PJRTElementType: for the element type of the buffer.
shape() -> integer(): for the shape of the buffer.
as_array() -> array | vector: for converting back to R (vector is only used for shape integer()).
as_raw() -> raw for a raw vector.
safetensors::safe_save_file for writing to a safetensors file.
safetensors::safe_load_file for reading from a safetensors file.
When calling this function on a vector of length 1, the resulting shape is 1L.
To create a 0-dimensional buffer, use pjrt_scalar where the resulting shape is integer().
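The difference between the two constructors for a length-1 input can be checked directly. This is a sketch that assumes `shape()` is available once pjrt is attached, as the extractor list above suggests.

```r
library(pjrt)

# A length-1 vector becomes a 1-dimensional buffer of length 1
vec <- pjrt_buffer(1, dtype = "f32")
shape(vec)

# pjrt_scalar() creates a 0-dimensional buffer instead
sc <- pjrt_scalar(1, dtype = "f32")
shape(sc)

# Converting the scalar back to R yields a plain vector
as_array(sc)
```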
# Create a buffer from a numeric vector
buf <- pjrt_buffer(c(1, 2, 3, 4))
buf
# Create a buffer from a matrix
mat <- matrix(1:6, nrow = 2)
buf <- pjrt_buffer(mat)
buf
# Create an integer buffer from an array
arr <- array(1:8, dim = c(2, 2, 2))
buf <- pjrt_buffer(arr)
# Create a scalar (0-dimensional array)
scalar <- pjrt_scalar(42, dtype = "f32")
scalar
# Create an empty buffer
empty <- pjrt_empty(dtype = "f32", shape = c(0, 3))
empty
Create a PJRT client for a specific device.
pjrt_client(platform = NULL, ...)
platform |
( |
... |
Additional options passed to the PJRT client creation.
For CPU clients, you can pass |
PJRTClient
platform() for a character(1) representation of the platform.
devices() for a list of PJRTDevice objects.
# Create a client (defaults to CPU)
client <- pjrt_client()
client
Compile a PJRTProgram program into a PJRTExecutable.
pjrt_compile(program, compile_options = new_compile_options(), device = NULL)
program |
( |
compile_options |
( |
device |
( |
PJRTExecutable
# Create a simple program
src <- r"(
func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  return %arg0 : tensor<2xf32>
}
)"
prog <- pjrt_program(src = src)
exec <- pjrt_compile(prog)
Create a PJRT Device from an R object.
pjrt_device(device)
device |
(any) |
PJRTDevice
platform() for a character(1) representation of the platform.
# Show available devices for CPU client
devices(pjrt_client("cpu"))
# Create device 0 for CPU client
dev <- pjrt_device("cpu:0")
dev
Execute a PJRT program with the given inputs and execution options.
Returns immediately with PJRTBuffer object(s) that may not be ready yet.
Important: Arguments are passed by position and names are ignored.
Inputs can be PJRTBuffer objects, including buffers that are not yet ready.
PJRT handles the dependencies internally.
Use await() to block until the result is ready.
Use is_ready() to check if execution has completed (non-blocking).
Use as_array_async() to chain async buffer-to-host transfer.
pjrt_execute(executable, ..., execution_options = NULL, simplify = TRUE)
executable |
( |
... |
( |
execution_options |
( |
simplify |
( |
PJRTBuffer | list of PJRTBuffers
await(), is_ready(), as_array_async()
# Create and compile a simple addition program
src <- r"(
func.func @main(
  %x: tensor<2x2xf32>,
  %y: tensor<2x2xf32>
) -> tensor<2x2xf32> {
  %0 = "stablehlo.add"(%x, %y) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  "func.return"(%0): (tensor<2x2xf32>) -> ()
}
)"
prog <- pjrt_program(src = src)
exec <- pjrt_compile(prog)
# Execute with input
x <- pjrt_buffer(c(1.0, 2.0, 3.0, 4.0), shape = c(2, 2), dtype = "f32")
y <- pjrt_buffer(c(5, 6, 7, 8), shape = c(2, 2), dtype = "f32")
pjrt_execute(exec, x, y)
Create execution options for configuring how a PJRT program is executed,
including buffer donation settings.
Important:
It is not enough to mark a buffer as donatable at runtime (that is, to
leave it out of non_donatable_input_indices). The program also needs to
specify this at compile time via input-output aliasing (stableHLO
attribute tf.aliasing_output).
pjrt_execution_options(non_donatable_input_indices = integer(), launch_id = 0L)
non_donatable_input_indices |
( |
launch_id |
( |
PJRTExecuteOptions
# Create default execution options
opts <- pjrt_execution_options()
# Mark buffer 0 as non-donatable
opts <- pjrt_execution_options(non_donatable_input_indices = 0L)
Create a PJRT plugin for a specific platform.
pjrt_plugin(platform)
platform |
( |
PJRTPlugin
plugin_attributes() -> list(): for the attributes of the plugin.
plugin <- pjrt_plugin("cpu")
plugin
PJRTProgram
Create a program from a string or file path.
pjrt_program(src = NULL, path = NULL, format = c("mlir", "hlo"))
src |
( |
path |
( |
format |
( |
PJRTProgram
# Create a program from source
src <- "
func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  return %arg0 : tensor<2xf32>
}
"
prog <- pjrt_program(src = src)
prog
Register an XLA FFI handler for use with stablehlo.custom_call.
Handlers are C/C++ functions defined using the XLA FFI API
(see xla/ffi/api/ffi.h shipped in pjrt's inst/include/).
They are passed to this function as external pointers.
Registration is deferred: if the PJRT plugin for a given platform
is not yet loaded, the handler is queued and registered automatically
when pjrt_plugin() loads it.
pjrt_register_custom_call(target_name, handler, .package = NULL)
target_name |
( |
handler |
A named list of external pointers ( |
.package |
( |
NULL (invisibly).
Get the platform name of a PJRT buffer.
platform(x, ...)
x |
( |
... |
Additional arguments (unused). |
character(1)
buf <- pjrt_buffer(c(1, 2, 3))
platform(buf)
Get the attributes of a PJRT plugin. This commonly includes:
xla_version
stablehlo_current_version
stablehlo_minimum_version
But the implementation depends on the plugin.
plugin_attributes(plugin)
plugin |
( |
named list()
plugin_attributes("cpu")
Create a PJRT client for a specific plugin and platform.
plugin_client_create(plugin, platform, options = list())
plugin |
( |
platform |
( |
options |
( |
PJRTClient
Check whether one or more plugins are already downloaded.
plugins_downloaded(platforms = NULL)
platforms |
( |
logical(1)
# Check if CPU plugin is downloaded
plugins_downloaded("cpu")
Print a PJRTBuffer.
## S3 method for class 'PJRTBuffer'
print(
  x,
  max_rows = getOption("pjrt.print_max_rows", 30L),
  max_width = getOption("pjrt.print_max_width", 85L),
  max_rows_slice = getOption("pjrt.print_max_rows_slice", max_rows),
  header = TRUE,
  footer = NULL,
  ...
)
x |
( |
max_rows |
( |
max_width |
( |
max_rows_slice |
( |
header |
( |
footer |
( |
... |
Additional arguments (unused). |
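The usage above shows that the print limits default to R options, so they can be changed per call or per session. A minimal sketch, assuming the CPU plugin is available. The buffer contents are arbitrary.

```r
library(pjrt)

# A buffer with more rows than fit comfortably on screen
buf <- pjrt_buffer(matrix(seq_len(200), nrow = 50))

# Limit the output for a single call
print(buf, max_rows = 5L)

# Or change the session-wide default, then restore the previous setting
old <- options(pjrt.print_max_rows = 10L)
print(buf)
options(old)
```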
Materialize and return the result of an async operation. Blocks until the operation is complete if it hasn't finished yet.
For PJRTArrayPromise, returns the materialized R array.
For PJRTBuffer, use await() to block until ready.
value(x, ...)
x |
An async value object. |
... |
Additional arguments (unused). |
The materialized value.
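The asynchronous functions described in this manual can be chained. The following sketch, which assumes the CPU plugin is available, executes a small program that doubles its input, starts a device-to-host transfer, and then materializes the result. The program text is illustrative.

```r
library(pjrt)

# A small program that adds its input to itself
src <- r"(
func.func @main(%x: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "stablehlo.add"(%x, %x) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  "func.return"(%0) : (tensor<2xf32>) -> ()
}
)"
exec <- pjrt_compile(pjrt_program(src = src))

x <- pjrt_buffer(c(1, 2), dtype = "f32")

# pjrt_execute() returns immediately; the buffer may not be ready yet
out <- pjrt_execute(exec, x)

# Start the transfer to the host without blocking
promise <- as_array_async(out)

# Non-blocking check, then a blocking read of the result
is_ready(promise)
res <- value(promise)
res
```

Calling `await(out)` instead would block until the output buffer is ready while keeping the data on the device.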