VAE: variational autoencoders

referência: 8.4-generating-images-with-vaes.Rmd

VAE são uma variante mais poderosa de AE:

em vez de compactar o input em um “código” fixo no espaço latente, ela transforma o input nos parâmetros de uma distribuição estatística: uma média e uma variância.

isso significa que estamos assumindo que a imagem de entrada foi gerada por um processo estatístico e que a aleatoriedade desse processo deve ser levada em consideração durante a codificação e decodificação.

O VAE usa os parâmetros de média e variância para amostrar aleatoriamente um elemento da distribuição e decodifica esse elemento tentando obter o input original.

O caráter estocástico deste processo melhora a robustez e força o espaço latente a codificar representações significativas.

Eis um resumo de como um VAE funciona:

  1. o encoder transforma o input em dois parâmetros no espaço latente, que designaremos como “z_mean” e “z_log_variance”.
  1. amostramos um ponto “z” de uma distribuição normal latente via z = z_mean + exp(z_log_variance) * epsilon onde epsilon é um tensor aleatório com valores pequenos. Como epsilon é pequeno, todo ponto próximo no espaço latente vai ser decodificado próximo ao input
  1. o decoder produz um output que busca replicar o input.

Uma VAE usa duas funções de custo:

  1. custo de reconstrução- força o output a reproduzir o input
  1. custo de regularização: ajuda a estruturar o espaço latente

Esquematicamente, uma VAE no Keras tem essa cara:

codifica o input com parâmetros de média e variância:

c(z_mean, z_log_variance) %<% encoder(input_img)

amostra um ponto latente com epsilon pequeno:

z <- z_mean + exp(z_log_variance) * epsilon

decodifica z de volta para uma imagem:

reconstructed_img <- decoder(z)

cria um modelo:

model <- keras_model(input_img, reconstructed_img)

treina o modelo com as duas funções de custo, de reconstrução e de regularização

encoder: uma convnet que mapeia o input em 2 vetores: z_mean e z_log_variance

# inicialização
t00 = Sys.time()     # para o tempo de processamento desse script
set.seed(1234)     # para reprodutibilidade

library(keras)

img_shape <- c(28, 28, 1)
batch_size <- 16
latent_dim <- 2L  # Dimensionality of the latent space: a plane

input_img <- layer_input(shape = img_shape)

x <- input_img %>% 
  layer_conv_2d(filters = 32, kernel_size = 3, padding = "same", 
                activation = "relu") %>% 
  layer_conv_2d(filters = 64, kernel_size = 3, padding = "same", 
                activation = "relu", strides = c(2, 2)) %>%
  layer_conv_2d(filters = 64, kernel_size = 3, padding = "same", 
                activation = "relu") %>%
  layer_conv_2d(filters = 64, kernel_size = 3, padding = "same", 
                activation = "relu") 

shape_before_flattening <- k_int_shape(x)

x <- x %>% 
  layer_flatten() %>% 
  layer_dense(units = 32, activation = "relu")

z_mean <- x %>% 
  layer_dense(units = latent_dim)

z_log_var <- x %>% 
  layer_dense(units = latent_dim)

amostra um ponto do espaço latente, z:

sampling <- function(args) {
  c(z_mean, z_log_var) %<-% args
  epsilon <- k_random_normal(shape = list(k_shape(z_mean)[1], latent_dim),
                             mean = 0, stddev = 1)
  z_mean + k_exp(z_log_var) * epsilon
}

z <- list(z_mean, z_log_var) %>% 
  layer_lambda(sampling)

decoder: reformata-se o vetor z como uma imagem que entra numa convnet com saída com as mesmas dimensões da imagem de input

# This is the input where we will feed `z`.
decoder_input <- layer_input(k_int_shape(z)[-1])

x <- decoder_input %>% 
  # Upsample to the correct number of units
  layer_dense(units = prod(as.integer(shape_before_flattening[-1])),
              activation = "relu") %>% 
  # Reshapes into an image of the same shape as before the last flatten layer
  layer_reshape(target_shape = shape_before_flattening[-1]) %>% 
  # Applies and then reverses the operation to the initial stack of 
  # convolution layers
  layer_conv_2d_transpose(filters = 32, kernel_size = 3, padding = "same",
                          activation = "relu", strides = c(2, 2)) %>%  
  layer_conv_2d(filters = 1, kernel_size = 3, padding = "same",
                activation = "sigmoid")  
  # We end up with a feature map of the same size as the original input.

# This is our decoder model.
decoder <- keras_model(decoder_input, x)

# We then apply it to `z` to recover the decoded `z`.
z_decoded <- decoder(z) 

Aqui define-se uma função para implementar as duas funções de custo:

library(R6)

CustomVariationalLayer <- R6Class("CustomVariationalLayer",
                                  
  inherit = KerasLayer,
  
  public = list(
    
    vae_loss = function(x, z_decoded) {
      x <- k_flatten(x)
      z_decoded <- k_flatten(z_decoded)
      xent_loss <- metric_binary_crossentropy(x, z_decoded)
      kl_loss <- -5e-4 * k_mean(
        1 + z_log_var - k_square(z_mean) - k_exp(z_log_var), 
        axis = -1L
      )
      k_mean(xent_loss + kl_loss)
    },
    
    call = function(inputs, mask = NULL) {
      x <- inputs[[1]]
      z_decoded <- inputs[[2]]
      loss <- self$vae_loss(x, z_decoded)
      self$add_loss(loss, inputs = inputs)
      x
    }
  )
)

layer_variational <- function(object) { 
  create_layer(CustomVariationalLayer, object, list())
} 

# Call the custom layer on the input and the decoded output to obtain
# the final model output
y <- list(input_img, z_decoded) %>% 
  layer_variational() 

inicializa e treina

não se passa ‘target’, só x_train (que é o target)

vae <- keras_model(input_img, y)

vae %>% compile(
  optimizer = "rmsprop",
  loss = NULL
)

# Trains the VAE on MNIST digits
mnist <- dataset_mnist() 
c(c(x_train, y_train), c(x_test, y_test)) %<-% mnist

x_train <- x_train / 255
x_train <- array_reshape(x_train, dim =c(dim(x_train), 1))

x_test <- x_test / 255
x_test <- array_reshape(x_test, dim =c(dim(x_test), 1))

vae %>% fit(
  x = x_train, y = NULL,
  epochs = 10,
  batch_size = batch_size,
  validation_data = list(x_test, NULL)
)

agora que o modelo está treinado, pode-se usar o decoder para transformar vetores arbitrários do espaço latente em imagens:

n <- 15            # Number of rows / columns of digits
digit_size <- 28   # Height / width of digits in pixels

# Transforms linearly spaced coordinates on the unit square through the inverse
# CDF (ppf) of the Gaussian to produce values of the latent variables z,
# because the prior of the latent space is Gaussian
grid_x <- qnorm(seq(0.05, 0.95, length.out = n))
grid_y <- qnorm(seq(0.05, 0.95, length.out = n))

op <- par(mfrow = c(n, n), mar = c(0,0,0,0), bg = "black")
for (i in 1:length(grid_x)) {
  yi <- grid_x[[i]]
  for (j in 1:length(grid_y)) {
    xi <- grid_y[[j]]
    z_sample <- matrix(c(xi, yi), nrow = 1, ncol = 2)
    z_sample <- t(replicate(batch_size, z_sample, simplify = "matrix"))
    x_decoded <- decoder %>% predict(z_sample, batch_size = batch_size)
    digit <- array_reshape(x_decoded[1,,,], dim = c(digit_size, digit_size))
    plot(as.raster(digit))
  }
}

par(op)

Sys.time()-t00
## Time difference of 27.05058 mins