---
title: "Type Promotion"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Type Promotion}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, echo = FALSE}
knitr::opts_chunk$set(
  eval = Sys.getenv("PJRT_PLATFORM", "cpu") != "metal"
)
```

## Type Promotion Rules

When combining arrays of different types (e.g., adding an `f32` to an `i32`), {anvl} needs to determine a common type.
For example, below we are adding an `f32` to an `f64`, where the former is promoted to the latter's type, because it's more expressive.

```{r}
library(anvl)
jit(nv_add)(
  nv_scalar(1.0, dtype = "f32"),
  nv_scalar(1.0, dtype = "f64")
)
```

The type-promotion rules are inspired by JAX, and they are designed for execution on accelerators like GPUs, where one often wants speed instead of precision.
The rules are defined by the `common_dtype()` function.
It returns a `list()` with two values: the common dtype and a flag indicating whether the result is ambiguous, which we will cover later.

```{r}
common_dtype("f64", "f32")$dtype
common_dtype("i64", "f32")$dtype
```

A table with the promotion rules is below.

```{r, echo = FALSE}
library(stablehlo)
dtypes <- c("bool", "i8", "i16", "i32", "i64", "ui8", "ui16", "ui32", "ui64", "f32", "f64")
tbl <- matrix(NA_character_, length(dtypes), length(dtypes))
for (i in seq_along(dtypes)) {
  for (j in seq_along(dtypes)) {
    tbl[i, j] <- repr(common_dtype(dtypes[i], dtypes[j])$dtype)
  }
}
rownames(tbl) <- dtypes
colnames(tbl) <- dtypes
knitr::kable(tbl, format = "html", caption = "Type promotion rules (row × column)")
```

## Literals as Ambiguous Types

Usually, the types in an {anvl} program can be deterministically inferred from the input types.
The only case where this is not possible is when you use R literals.
The default types for literals are as follows:

* `double()` -> `f32`
* `integer()` -> `i32`
* `logical()` -> `i1` (bool)

```{r}
jit(\() list(1L, 1.0, TRUE))()
```

However, because these defaults are just a guess, literals behave differently from known types during promotion.
Therefore, the `common_dtype()` function has two additional arguments indicating which of the data types are ambiguous.
Below, the first type is a known `f32` and the second is an ambiguous `f64`.
Within {anvl}, we denote the latter as `f64?`.
The result is an `f32`, although we would promote to an `f64` if both were known.
If both types are ambiguous, the result is generally the same as if both were known.

```{r}
common_dtype("f32", "f64", FALSE, TRUE)
common_dtype("f32", "f64", TRUE, TRUE)
common_dtype("f32", "f64", FALSE, FALSE)
```

The promotion rules only change when one type is ambiguous and the other is not.
There, we usually promote the ambiguous type to the known type, unless:

1. The ambiguous type is a float and the known type is not.
2. The known type is a bool but the ambiguous type is not.

In both cases, we promote the known type to the default type of the ambiguous type.

The table below shows the promotion rules, where the rows are ambiguous and the columns are known.

```{r, echo = FALSE}
dtypes <- c("bool", "i8", "i16", "i32", "i64", "ui8", "ui16", "ui32", "ui64", "f32", "f64")
tbl <- matrix(NA_character_, length(dtypes), length(dtypes))
for (i in seq_along(dtypes)) {
  for (j in seq_along(dtypes)) {
    tbl[i, j] <- repr(common_dtype(dtypes[i], dtypes[j], TRUE, FALSE)$dtype)
  }
}
rownames(tbl) <- dtypes
colnames(tbl) <- dtypes
knitr::kable(tbl, format = "html", caption = "Promotion rules: ambiguous (row) × known (column)")
```

## Creating Tensors with Different Ambiguity

Both `nv_scalar()` and `nv_array()` create **non-ambiguous** arrays by default.
You can explicitly control ambiguity using the `ambiguous` parameter:

```{r}
s1 <- nv_scalar(1.0)
ambiguous(s1)

s2 <- nv_scalar(1.0, ambiguous = TRUE)
ambiguous(s2)

t1 <- nv_array(c(1.0, 2.0, 3.0))
ambiguous(t1)

t2 <- nv_array(c(1.0, 2.0, 3.0), ambiguous = TRUE)
ambiguous(t2)
```

## Propagating Ambiguity

Ambiguity is propagated through operations.
Consider the following example:

```{r}
f <- jit(function(x, y) {
  z <- x + 1L
  z * y
})
f(nv_scalar(TRUE), nv_scalar(2L, dtype = "i16"))
```

The type of `z` is `i32?`, because `x` is promoted to an `i32`, the default type of the `1L` literal.
If `z` were not ambiguous, the output would be an `i32`, because `y` would be promoted to an `i32` in the multiplication.
Because we propagate the ambiguity, `z` is instead down-promoted to an `i16`, since `z` is ambiguous while `y` is known.
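
The two promotion steps in this example can be retraced with `common_dtype()` directly.
The chunk below is a minimal sketch: it assumes that `common_dtype()` accepts the dtype strings and the two ambiguity flags positionally, as in the earlier chunks, and the names `step1` and `step2` are purely illustrative.

```{r}
library(anvl)

# Step 1: `x + 1L` combines a known bool with an ambiguous i32 literal.
# According to the rules above, the known bool is promoted to the
# literal's default type, and the result stays ambiguous (`i32?`).
step1 <- common_dtype("bool", "i32", FALSE, TRUE)
step1$dtype

# Step 2: `z * y` combines the ambiguous i32 with a known i16.
# According to the rules above, the ambiguous type is promoted to the
# known type, so the result should be an i16.
step2 <- common_dtype("i32", "i16", TRUE, FALSE)
step2$dtype
```

Printing `step1` and `step2` in full, rather than only `$dtype`, would also show the ambiguity flag that the function returns alongside the dtype.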