18.5 Walking the AST with recursive functions

To conclude the chapter I’m going to use everything you’ve learned about ASTs to solve more complicated problems. The inspiration comes from the codetools package, which ships with R and provides two interesting functions:

  • findGlobals() locates all global variables used by a function. This can be useful if you want to check that your function doesn’t inadvertently rely on variables defined in its parent environment.

  • checkUsage() checks for a range of common problems including unused local variables, unused parameters, and the use of partial argument matching.

Getting all of the details of these functions correct is fiddly, so we won’t fully develop the ideas. Instead we’ll focus on the big underlying idea: recursion on the AST. Recursive functions are a natural fit to tree-like data structures because a recursive function is made up of two parts that correspond to the two parts of the tree:

  • The recursive case handles the nodes in the tree. Typically, you’ll do something to each child of a node, usually calling the recursive function again, and then combine the results back together again. For expressions, you’ll need to handle calls and pairlists (function arguments).

  • The base case handles the leaves of the tree. The base cases ensure that the function eventually terminates, by solving the simplest cases directly. For expressions, you need to handle symbols and constants in the base case.
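
The two cases can be seen in a minimal sketch that counts the leaves of an expression using only base R. The name count_leaves() is hypothetical and is not part of codetools or rlang, and pairlists are ignored to keep the example short:

```r
# Hypothetical helper: count the leaves (symbols and constants) of an expression.
count_leaves <- function(x) {
  if (is.call(x)) {
    # Recursive case: visit every child of the call, then combine by summing.
    sum(vapply(as.list(x), count_leaves, numeric(1)))
  } else {
    # Base case: a symbol or a constant is a single leaf.
    1
  }
}

count_leaves(quote(x))
#> [1] 1
count_leaves(quote(f(x, 1)))
#> [1] 3
count_leaves(quote(f(g(x), y + 2)))
#> [1] 6
```

Note that the function being called (f, g, and `+` above) is itself a child of the call, so it is counted as a leaf.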

To make this pattern easier to see, we’ll need two helper functions. First we define expr_type(), which returns “constant” for constants, “symbol” for symbols, “call” for calls, “pairlist” for pairlists, and the “type” of anything else:

expr_type <- function(x) {
  if (rlang::is_syntactic_literal(x)) {
    "constant"
  } else if (is.symbol(x)) {
    "symbol"
  } else if (is.call(x)) {
    "call"
  } else if (is.pairlist(x)) {
    "pairlist"
  } else {
    typeof(x)
  }
}

expr_type(expr("a"))
#> [1] "constant"
expr_type(expr(x))
#> [1] "symbol"
expr_type(expr(f(1, 2)))
#> [1] "call"

We’ll couple this with a wrapper around switch() that throws an error for any type that has no matching case:

switch_expr <- function(x, ...) {
  switch(expr_type(x),
    ...,
    stop("Don't know how to handle type ", typeof(x), call. = FALSE)
  )
}

With these two functions in hand, we can write a basic template for any function that walks the AST using switch() (Section 5.2.3):

recurse_call <- function(x) {
  switch_expr(x,
    # Base cases
    symbol = ,
    constant = ,

    # Recursive cases
    call = ,
    pairlist =
  )
}
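
The empty arguments in the template rely on the fall-through behaviour of switch(): when the matched argument is empty, R evaluates the next non-empty argument instead. A small base R sketch of this behaviour follows; the function describe() and its labels are made up for illustration:

```r
# Hypothetical example of switch() fall-through.
describe <- function(type) {
  switch(type,
    symbol = ,            # empty: falls through to the next non-empty value
    constant = "leaf",
    call = ,              # empty: falls through to pairlist
    pairlist = "node",
    stop("Unknown type: ", type, call. = FALSE)
  )
}

describe("symbol")
#> [1] "leaf"
describe("call")
#> [1] "node"
```

This is why one piece of code can serve both base cases, or both recursive cases, without being repeated.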

Typically, solving the base case is easy, so we’ll do that first, then check the results. The recursive cases are trickier, and will often require some functional programming.

18.5.1 Finding F and T

We’ll start with a function that determines whether another function uses the logical abbreviations T and F because using them is often considered to be poor coding practice. Our goal is to return TRUE if the input contains a logical abbreviation, and FALSE otherwise.

Let’s first find the type of T versus TRUE:

expr_type(expr(TRUE))
#> [1] "constant"

expr_type(expr(T))
#> [1] "symbol"

TRUE is parsed as a logical vector of length one, while T is parsed as a name. This tells us how to write our base cases for the recursive function: a constant is never a logical abbreviation, and a symbol is an abbreviation if it’s “F” or “T”:

logical_abbr_rec <- function(x) {
  switch_expr(x,
    constant = FALSE,
    symbol = as_string(x) %in% c("F", "T")
  )
}

logical_abbr_rec(expr(TRUE))
#> [1] FALSE
logical_abbr_rec(expr(T))
#> [1] TRUE

I’ve written the logical_abbr_rec() function assuming that the input will be an expression, as this makes the recursive operation simpler. However, when writing a recursive function it’s common to write a wrapper that provides defaults or makes the function a little easier to use. Here we’ll make a wrapper that quotes its input (we’ll learn more about that in the next chapter), so we don’t need to use expr() every time.

logical_abbr <- function(x) {
  logical_abbr_rec(enexpr(x))
}

logical_abbr(T)
#> [1] TRUE
logical_abbr(FALSE)
#> [1] FALSE

Next we need to implement the recursive cases. Here we want to do the same thing for calls and for pairlists: recursively apply the function to each subcomponent, and return TRUE if any subcomponent contains a logical abbreviation. This is made easy by purrr::some(), which iterates over a list and returns TRUE if the predicate function is true for any element.

logical_abbr_rec <- function(x) {
  switch_expr(x,
    # Base cases
    constant = FALSE,
    symbol = as_string(x) %in% c("F", "T"),

    # Recursive cases
    call = ,
    pairlist = purrr::some(x, logical_abbr_rec)
  )
}

logical_abbr(mean(x, na.rm = T))
#> [1] TRUE
logical_abbr(function(x, na.rm = T) FALSE)
#> [1] TRUE
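
The second example works because a quoted function definition is itself a call to `function`, and the second element of that call is a pairlist of arguments. A brief check in base R, using quote() in place of expr() (f_expr is just an illustrative name):

```r
f_expr <- quote(function(x, na.rm = T) FALSE)

is.call(f_expr)
#> [1] TRUE
is.pairlist(f_expr[[2]])
#> [1] TRUE
names(f_expr[[2]])
#> [1] "x"     "na.rm"

# The default value of na.rm is stored as the symbol T, not the constant TRUE.
is.symbol(f_expr[[2]]$na.rm)
#> [1] TRUE
```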

18.5.2 Finding all variables created by assignment

logical_abbr() is relatively simple: it only returns a single TRUE or FALSE. The next task, listing all variables created by assignment, is a little more complicated. We’ll start simply, and then make the function progressively more rigorous.

We start by looking at the AST for assignment:

ast(x <- 10)
#> █─`<-` 
#> ├─x 
#> └─10

Assignment is a call object where the first element is the symbol <-, the second is the name of the variable, and the third is the value to be assigned.
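
The same three components can be pulled out by indexing the call directly. This sketch uses base R’s quote() rather than expr(), and assign_expr is just an illustrative name:

```r
assign_expr <- quote(x <- 10)

is.call(assign_expr)
#> [1] TRUE
as.character(assign_expr[[1]])  # the function being called
#> [1] "<-"
as.character(assign_expr[[2]])  # the variable being assigned
#> [1] "x"
assign_expr[[3]]                # the value
#> [1] 10
```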

Next, we need to decide what data structure we’re going to use for the results. Here I think it will be easiest if we return a character vector. If we return symbols, we’ll need to use a list() and that makes things a little more complicated.

With that in hand we can start by implementing the base cases and providing a helpful wrapper around the recursive function. Here the base cases are straightforward because we know that neither a symbol nor a constant represents assignment.

find_assign_rec <- function(x) {
  switch_expr(x,
    constant = ,
    symbol = character()
  )
}
find_assign <- function(x) find_assign_rec(enexpr(x))

find_assign("x")
#> character(0)
find_assign(x)
#> character(0)

Next we implement the recursive cases. This is made easier by a function that should exist in purrr, but currently doesn’t. flat_map_chr() expects .f to return a character vector of arbitrary length, and flattens all results into a single character vector.

flat_map_chr <- function(.x, .f, ...) {
  purrr::flatten_chr(purrr::map(.x, .f, ...))
}

flat_map_chr(letters[1:3], ~ rep(., sample(3, 1)))
#> [1] "a" "b" "b" "b" "c" "c" "c"
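
If you would rather not depend on purrr for this helper, the same idea can be sketched in base R. The name flat_map_chr_base() is hypothetical, and unlike the purrr version it coerces results to character rather than checking that each one already is a character vector:

```r
# Hypothetical base R variant of flat_map_chr().
flat_map_chr_base <- function(.x, .f, ...) {
  out <- unlist(lapply(.x, .f, ...), use.names = FALSE)
  # unlist() returns NULL when there is nothing to flatten,
  # so return an empty character vector in that case.
  if (is.null(out)) character() else as.character(out)
}

flat_map_chr_base(c("a", "b"), function(x) rep(x, 2))
#> [1] "a" "a" "b" "b"
flat_map_chr_base(list(), identity)
#> character(0)
```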

The recursive case for pairlists is straightforward: we iterate over every element of the pairlist (i.e. each function argument) and combine the results. The case for calls is a little bit more complex: if this is a call to <- then we should return the second element of the call:

find_assign_rec <- function(x) {
  switch_expr(x,
    # Base cases
    constant = ,
    symbol = character(),

    # Recursive cases
    pairlist = flat_map_chr(as.list(x), find_assign_rec),
    call = {
      if (is_call(x, "<-")) {
        as_string(x[[2]])
      } else {
        flat_map_chr(as.list(x), find_assign_rec)
      }
    }
  )
}

find_assign(a <- 1)
#> [1] "a"
find_assign({
  a <- 1
  {
    b <- 2
  }
})
#> [1] "a" "b"

Now we need to make our function more robust by coming up with examples intended to break it. What happens when we assign to the same variable multiple times?

find_assign({
  a <- 1
  a <- 2
})
#> [1] "a" "a"

It’s easiest to fix this at the level of the wrapper function:

find_assign <- function(x) unique(find_assign_rec(enexpr(x)))

find_assign({
  a <- 1
  a <- 2
})
#> [1] "a"

What happens if we have nested calls to <-? Currently we only return the first. That’s because when <- occurs we immediately terminate recursion.

find_assign({
  a <- b <- c <- 1
})
#> [1] "a"
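
The structure of the expression shows why: <- is right-associative, so each further assignment is nested inside the third element of the one before it. A quick check in base R (chain is just an illustrative name):

```r
chain <- quote(a <- b <- c <- 1)

# The outer call assigns to a; the value it assigns is itself an assignment.
as.character(chain[[2]])
#> [1] "a"
deparse(chain[[3]])
#> [1] "b <- c <- 1"
deparse(chain[[3]][[3]])
#> [1] "c <- 1"
```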

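Instead we need to take a more rigorous approach: record the variable on the left-hand side, then keep recursing into the rest of the call. I think it’s best to keep the recursive function focused on the tree structure, so I’m going to move the handling of calls into a separate function, find_assign_call().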

find_assign_call <- function(x) {
  if (is_call(x, "<-") && is_symbol(x[[2]])) {
    lhs <- as_string(x[[2]])
    children <- as.list(x)[-1]
  } else {
    lhs <- character()
    children <- as.list(x)
  }

  c(lhs, flat_map_chr(children, find_assign_rec))
}

find_assign_rec <- function(x) {
  switch_expr(x,
    # Base cases
    constant = ,
    symbol = character(),

    # Recursive cases
    pairlist = flat_map_chr(x, find_assign_rec),
    call = find_assign_call(x)
  )
}

find_assign(a <- b <- c <- 1)
#> [1] "a" "b" "c"
find_assign(system.time(x <- print(y <- 5)))
#> [1] "x" "y"
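
The is_symbol() check in find_assign_call() matters because the left-hand side of <- is not always a plain name. In a replacement call, for example, the second element is itself a call, so converting it to a variable name would not make sense. A small base R illustration (repl is just an illustrative name):

```r
repl <- quote(names(x) <- y)

is.symbol(repl[[2]])
#> [1] FALSE
is.call(repl[[2]])
#> [1] TRUE
deparse(repl[[2]])
#> [1] "names(x)"
```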

The complete version of this function is quite complicated, but it’s important to remember that we wrote it by working our way up from simple component parts.

18.5.3 Exercises

  1. logical_abbr() returns TRUE for T(1, 2, 3). How could you modify logical_abbr_rec() so that it ignores function calls that use T or F?

  2. logical_abbr() works with expressions. It currently fails when you give it a function. Why? How could you modify logical_abbr() to make it work? What components of a function will you need to recurse over?

    logical_abbr(function(x = TRUE) {
      g(x + T)
    })

  3. Modify find_assign() to also detect assignment using replacement functions, i.e. names(x) <- y.

  4. Write a function that extracts all calls to a specified function.