Do Neural Networks Dream of Vectorized Sheep?

Google's Project Magenta is an endeavor to explore and push the boundaries of machine-generated art and music. It encompasses the use of machine learning to create new works, but also seeks to develop new assistive tools for human creators, and provides a place for creators to learn about and collaborate with these tools.

This article is a port of a tutorial, originally available here, on using the Sketch-RNN model provided by Magenta. Sketch-RNN is a generative model for vector drawings, which essentially learns how to reconstruct figures using a series of simulated 'motor actions'. For an overview of the model, see this blog post from David Ha. For further technical details, refer to this paper by David Ha and Douglas Eck.

Package installation and helper-function definitions are in Section 6, so here we can jump right in with our first dataset.

1. Training models on the Aaron Koblin Sheep Dataset

For more information on the dataset visit http://www.aaronkoblin.com/project/the-sheep-market.

Define the path of the model you want to load, and also the path of the dataset.

# Remote location of the Aaron Koblin sheep dataset; passed to load_env() below.
data_dir = 'http://github.com/hardmaru/sketch-rnn-datasets/raw/master/aaron_sheep/'
# Root directory of the pretrained Sketch-RNN models (downloaded in Section 6.1).
models_root_dir = '/magenta-models/sketch_rnn/'
# Checkpoint directory of the sheep model variant used in this section.
model_dir = models_root_dir + 'aaron_sheep/layer_norm'

# The `sheep` helper module (Section 6.2) re-exports numpy, random, tensorflow,
# the Magenta Sketch-RNN API and the drawing/encoding helpers used below.
from sheep import *
from svgwrite.drawing import Drawing

# Define how to interpret svgwrite Drawing objects
# (the notebook front end looks for `_repr_svg_` to display an object as SVG).
Drawing._repr_svg_ = lambda self: self.tostring()

np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)

# Fixing random seeds to get consistent results for display
random.seed(3)
np.random.seed(3)

# load_env() is provided by the Magenta modules star-imported through `sheep`;
# it yields the three dataset splits and three hyper-parameter sets unpacked here.
[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env(data_dir, model_dir)

Construct the Sketch-RNN model here, loading its weights from a downloaded checkpoint.

# Start from a clean graph, then build the training, evaluation and sampling models.
# NOTE(review): reuse=True presumably makes the last two share the first
# model's variables - confirm against Model.
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

# Load the downloaded checkpoint found in model_dir into the session.
load_checkpoint(sess, model_dir)

Get a sample drawing from the test set and render it to SVG.

# Pick a random drawing from the test split and render it as SVG.
stroke = test_set.random_sample()
draw_strokes(stroke) 

Let's try to encode the sample stroke into a latent vector $z$...

# encode() returns the latent vector together with a rendering of the encoded input.
z, drawing = encode(sess, eval_model, stroke)
# Bare expression so the notebook displays the drawing.
drawing

...and convert back to a drawing, at a temperature of 0.65.

# Sample a drawing conditioned on z; with the default draw_mode=True, decode() returns an SVG drawing.
decode(sess, eval_model, sample_model, z, temperature=0.65)

Create generated grid at various temperatures from 0.1 to 1.0.

# One decoded sheep per temperature (0.1, 0.2, ..., 1.0), placed at row 0, column i.
stroke_list = [
  [decode(sess, eval_model, sample_model, z, draw_mode=False, temperature=0.1*i+0.1), [0, i]]
  for i in range(10)
]
stroke_grid = make_grid_svg(stroke_list)
# save as a file for display up top
draw_strokes(stroke_grid).saveas("/results/title.svg")

Now for a latent space interpolation example between two sheep, $z_0$ and $z_1$. We'll just use our prior sheep for $z_0$, and a new random one for $z_1$.

# z0: reuse the latent vector of the sheep encoded earlier.
z0 = z
decode(sess, eval_model, sample_model, z0)
# z1: encode a second, freshly sampled test drawing.
stroke = test_set.random_sample()
z1, drawing = encode(sess, eval_model, stroke)
drawing
decode(sess, eval_model, sample_model, z1)

Now we interpolate between sheep $z_0$ and sheep $z_1$.

# Walk from z0 to z1 in N spherically interpolated steps.
N = 10
z_list = [slerp(z0, z1, t) for t in np.linspace(0, 1, N)]

# for every latent vector in z_list, sample a vector image
reconstructions = [
  [decode(sess, eval_model, sample_model, z_list[i], draw_mode=False), [0, i]]
  for i in range(N)
]

stroke_grid = make_grid_svg(reconstructions)
draw_strokes(stroke_grid)

2. Flamingo Model

Let's load the Flamingo Model, and try Unconditional (Decoder-Only) Generation.

# Switch to the unconditional flamingo model.
model_dir = models_root_dir + 'flamingo/lstm_uncond'
# load_model() yields the three hyper-parameter sets unpacked here.
[hps_model, eval_hps_model, sample_hps_model] = load_model(model_dir)

Construct & load the new model.

# Rebuild the three models on a fresh graph for the new hyper-parameters.
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

# NOTE(review): `sess` is rebound to a new InteractiveSession without an
# explicit close() of the previous one - confirm whether it should be closed first.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

# Load the downloaded checkpoint found in model_dir into the session.
load_checkpoint(sess, model_dir)

Randomly unconditionally generate 10 examples.

# No latent vector is passed to decode(), so each of the N sketches is
# sampled unconditionally; sketch i goes to row 0, column i of the grid.
N = 10
reconstructions = [
  [decode(sess, eval_model, sample_model, temperature=0.5, draw_mode=False), [0, i]]
  for i in range(N)
]
stroke_grid = make_grid_svg(reconstructions)
draw_strokes(stroke_grid)

3. Owl Model

Let's load the owl model, and generate two sketches using two random IID Gaussian latent vectors, $z_0$ and $z_1$.

# Switch to the owl model.
model_dir = models_root_dir + 'owl/lstm'
# load_model() yields the three hyper-parameter sets unpacked here.
[hps_model, eval_hps_model, sample_hps_model] = load_model(model_dir)

Construct & load.

# Re-seed both RNGs so the latent vectors drawn below are reproducible, then start a fresh graph.
random.seed(15); np.random.seed(15); reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

# NOTE(review): the previous section's session is not closed before `sess` is rebound - confirm.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

load_checkpoint(sess, model_dir)
# Draw two standard-normal latent vectors of length z_size and decode each one.
z_0 = np.random.randn(eval_model.hps.z_size)
decode(sess, eval_model, sample_model, z_0)
z_1 = np.random.randn(eval_model.hps.z_size)
decode(sess, eval_model, sample_model, z_1)

Let's interpolate between the two owls $z_0$ and $z_1$.

# interpolate spherically between z_0 and z_1 in N steps
N = 10
z_list = [slerp(z_0, z_1, t) for t in np.linspace(0, 1, N)]
# for every latent vector in z_list, sample a vector image
reconstructions = [
  [decode(sess, eval_model, sample_model, z_list[i], draw_mode=False, temperature=0.1), [0, i]]
  for i in range(N)
]

stroke_grid = make_grid_svg(reconstructions)
draw_strokes(stroke_grid)

4. Catbus!

Let's load the model trained on both cats and buses! Catbus!

# Switch to the model trained on both cats and buses.
model_dir = models_root_dir + 'catbus/lstm'
# load_model() yields the three hyper-parameter sets unpacked here.
[hps_model, eval_hps_model, sample_hps_model] = load_model(model_dir)

Construct & load.

# Re-seed both RNGs so the latent vectors drawn below are reproducible, then start a fresh graph.
random.seed(1); np.random.seed(1); reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

# NOTE(review): the previous section's session is not closed before `sess` is rebound - confirm.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

load_checkpoint(sess, model_dir)
# In this section z_1 is drawn from the RNG before z_0.
z_1 = np.random.randn(eval_model.hps.z_size)
decode(sess, eval_model, sample_model, z_1)
z_0 = np.random.randn(eval_model.hps.z_size)
decode(sess, eval_model, sample_model, z_0)

Let's interpolate between a cat and a bus!!!

# interpolate spherically between z_1 and z_0 in N steps
N = 10
z_list = [slerp(z_1, z_0, t) for t in np.linspace(0, 1, N)]
# for every latent vector in z_list, sample a vector image
reconstructions = [
  [decode(sess, eval_model, sample_model, z_list[i], draw_mode=False, temperature=0.15), [0, i]]
  for i in range(N)
]

stroke_grid = make_grid_svg(reconstructions)
draw_strokes(stroke_grid)

5. Mad Science!

Why stop here? Let's load the model trained on both elephants and pigs!!!

# Switch to the model trained on both elephants and pigs.
model_dir = models_root_dir + 'elephantpig/lstm'
[hps_model, eval_hps_model, sample_hps_model] = load_model(model_dir)
# Reset random seeds to get good examples
random.seed(12); np.random.seed(12); reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

# NOTE(review): the previous section's session is not closed before `sess` is rebound - confirm.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

load_checkpoint(sess, model_dir)
# Draw two standard-normal latent vectors of length z_size and decode each one.
z_0 = np.random.randn(eval_model.hps.z_size)
decode(sess, eval_model, sample_model, z_0)
z_1 = np.random.randn(eval_model.hps.z_size)
decode(sess, eval_model, sample_model, z_1)

Tribute to an episode of South Park: the interpolation between an Elephant and a Pig.

# interpolate spherically from z_0 to z_1 in N steps
N = 10
z_list = [slerp(z_0, z_1, t) for t in np.linspace(0, 1, N)]
# for every latent vector in z_list, sample a vector image
reconstructions = [
  [decode(sess, eval_model, sample_model, z_list[i], draw_mode=False, temperature=0.15), [0, i]]
  for i in range(N)
]

# Wider horizontal cell spacing and a larger scaling factor than the earlier grids.
stroke_grid = make_grid_svg(reconstructions, grid_space_x=25.0)
draw_strokes(stroke_grid, factor=0.3)

6. Setup

6.1. Installation

Install as many deps as we can via conda (versions pulled from magenta setup.py). We'll use the dev-version of magenta off Github for Python 3 compatibility.

# Dependencies installed from the default conda channels.
conda install absl-py backports.tempfile click flask future gevent greenlet gunicorn itsdangerous mpmath networkx pyasn1 pyasn1-modules sympy wheel wrapt \
  'bokeh>=0.12.0' 'joblib>=0.12' 'matplotlib>=1.5.3' 'numpy>=1.14.6,<=1.16.2' 'pandas>=0.18.1' 'Pillow>=3.4.2' 'scipy>=0.18.1,<=1.2.0'

# Dependencies installed from the conda-forge channel.
conda install -c conda-forge \
  cachetools google-api-python-client googleapis-common-protos google-auth google-auth-httplib2 httplib2 oauth2client promise pyglet rsa sk-video svgwrite uritemplate \
  'intervaltree>=2.1.0' 'librosa>=0.6.2' 'protobuf>=3.7,<4' 'sox>=1.3.7' 'tensorflow-probability>=0.5.0' 

# Drop conda's caches (quiet, no confirmation prompt).
conda clean -qtipy

# Refresh the shared-library cache, then install the development version of
# Magenta straight from GitHub (needed for Python 3 compatibility).
ldconfig
pip install git+https://github.com/tensorflow/magenta@master

Download models.

# From the Magenta Dockerfile
# github.com/tensorflow/magenta/blob/master/magenta/tools/docker/Dockerfile
# Not sure if really needed for this, but shows how to do it anyway.

# Create the data and model directories, then fetch the model files listed
# below into /magenta-models (-q: no progress output).
mkdir -p /magenta-data
mkdir -p /magenta-models
cd /magenta-models
wget -q \
  http://download.magenta.tensorflow.org/models/attention_rnn.mag \
  http://download.magenta.tensorflow.org/models/basic_rnn.mag \
  http://download.magenta.tensorflow.org/models/chord_pitches_improv.mag \
  http://download.magenta.tensorflow.org/models/drum_kit_rnn.mag \
  http://download.magenta.tensorflow.org/models/lookback_rnn.mag \
  http://download.magenta.tensorflow.org/models/polyphony_rnn.mag \
  http://download.magenta.tensorflow.org/models/rl_rnn.mag \
  http://download.magenta.tensorflow.org/models/rl_tuner_note_rnn.ckpt \
  http://download.magenta.tensorflow.org/models/multistyle-pastiche-generator-monet.ckpt \
  http://download.magenta.tensorflow.org/models/multistyle-pastiche-generator-varied.ckpt

# Magenta wants the py2.7 urllib
# Use Magenta's own downloader to place the pretrained Sketch-RNN models
# under /magenta-models/sketch_rnn.
python -c 'from magenta.models.sketch_rnn.sketch_rnn_train import *
download_pretrained_models(models_root_dir="/magenta-models/sketch_rnn")'

Replace 1/0 with true/false in model config files (required for Tensorflow 1.5+).

cd /magenta-models/sketch_rnn
# Rewrite every model_config.json in place: for keys starting with "use_" or
# "is_", and for "conditional", turn the value 0 into false and 1 into true.
for i in */*/model_config.json; do
  # Quote the path so a directory name containing whitespace or glob
  # characters is passed to perl as a single argument.
  perl -pi -e 's/((?:use|is)_.*|conditional)": 0/$1": false/; s/((?:use|is)_.*|conditional)": 1/$1": true/' "$i"
done

6.2. Functions

Create a small module to define functions and make package loading easy. In the runtime config we mount this named Code Listing somewhere, and set PYTHONPATH to point at the same place.

import sys, os
sys.version_info
import numpy as np
import time
import random
# import cPickle
import codecs
import collections
import os
import math
import json
import tensorflow as tf
tf.__version__
from six.moves import xrange

# MAGENTA STUFF
import magenta
magenta.__version__

from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *

import svgwrite

# A function to generate and return SVG images, 
# and another to put together a grid of many such images.

def draw_strokes(data, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'):
  """Build an svgwrite Drawing from a stroke sequence and return it.

  Args:
    data: stroke rows indexed as data[i, 0], data[i, 1] (x/y pen offsets) and
      data[i, 2] (pen-lift flag; a value of 1 makes the next row a move).
    factor: every offset is divided by this value; also forwarded to get_bounds().
    svg_filename: file name handed to the Drawing; its parent directory is created.

  Returns:
    The svgwrite Drawing; this function does not save it.
  """
  tf.gfile.MakeDirs(os.path.dirname(svg_filename))
  min_x, max_x, min_y, max_y = get_bounds(data, factor)
  canvas = (50 + max_x - min_x, 50 + max_y - min_y)
  dwg = svgwrite.Drawing(svg_filename, size=canvas)
  dwg.add(dwg.rect(insert=(0, 0), size=canvas, fill='white'))

  # Absolute move to the start point, leaving a 25-unit margin on each side.
  pieces = ["M%s,%s " % (25 - min_x, 25 - min_y)]
  pen_up = 1
  cmd = "m"
  for idx in range(len(data)):
    if pen_up == 1:
      cmd = "m"
    else:
      # An empty command lets SVG repeat the previous relative line-to.
      cmd = "" if cmd == "l" else "l"
    dx = float(data[idx, 0]) / factor
    dy = float(data[idx, 1]) / factor
    pen_up = data[idx, 2]
    pieces.append(cmd + str(dx) + "," + str(dy) + " ")

  outline = dwg.path("".join(pieces)).stroke("black", 1).fill("none")
  dwg.add(outline)
  return dwg

def make_grid_svg(s_list, grid_space=10.0, grid_space_x=16.0):
  """Concatenate several sketches into one stroke array laid out on a grid.

  Args:
    s_list: sequence of samples; sample[0] is a numpy stroke array (columns
      0 and 1 are x/y offsets, column 2 the pen flag) and sample[1] is the
      grid location [row, column].
    grid_space: vertical size of one grid cell.
    grid_space_x: horizontal size of one grid cell.

  Returns:
    A numpy array starting with the row [0, 0, 1]; for every sample it holds
    one pen-down positioning row followed by the sample's own rows, the last
    of which has its pen flag forced to 1.
  """
  def get_start_and_end(x):
    # First offset relative to the midpoint of the path's bounding box, and
    # the total displacement of the whole sketch.
    offsets = np.array(x)[:, 0:2]
    path = offsets.cumsum(axis=0)
    midpoint = (path.max(axis=0) + path.min(axis=0)) * 0.5
    return offsets[0] - midpoint, offsets.sum(axis=0)

  cursor_x = 0.0
  cursor_y = 0.0
  rows = [[cursor_x, cursor_y, 1]]
  for sample in s_list:
    sketch, cell = sample[0], sample[1]
    cell_y = cell[0]*grid_space+grid_space*0.5
    cell_x = cell[1]*grid_space_x+grid_space_x*0.5
    start_loc, travelled = get_start_and_end(sketch)

    target_x = cell_x+start_loc[0]
    target_y = cell_y+start_loc[1]
    # Relative jump from where the previous sketch ended to this sketch's start.
    rows.append([target_x-cursor_x, target_y-cursor_y, 0])

    rows.extend(sketch.tolist())
    rows[-1][2] = 1
    cursor_x = target_x+travelled[0]
    cursor_y = target_y+travelled[1]
  return np.array(rows)

# Two more convenience functions, to encode a stroke into a latent vector, 
# and decode from latent vector to stroke.

def encode(sess, eval_model, input_strokes):
  """Encode a stroke sequence into a latent vector.

  Args:
    sess: session used to evaluate the model.
    eval_model: model whose `batch_z` tensor is evaluated.
    input_strokes: stroke sequence accepted by to_big_strokes().

  Returns:
    A tuple (z, drawing): the first element of the evaluated `batch_z` and
    the drawing produced by draw_strokes() for the converted input.
  """
  padded = to_big_strokes(input_strokes).tolist()
  # Prepend the row [0, 0, 1, 0, 0] before feeding the sequence to the model.
  padded.insert(0, [0, 0, 1, 0, 0])
  drawing = draw_strokes(to_normal_strokes(np.array(padded)))
  feed = {
      eval_model.input_data: [padded],
      eval_model.sequence_lengths: [len(input_strokes)],
  }
  latent = sess.run(eval_model.batch_z, feed_dict=feed)[0]
  return (latent, drawing)

def decode(sess, eval_model, sample_model, z_input=None, draw_mode=True, temperature=0.1, factor=0.2):
  """Sample a sketch from the model, optionally conditioned on a latent vector.

  Args:
    sess: session passed through to sample().
    eval_model: supplies `hps.max_seq_len`, used as the sampling length.
    sample_model: model passed through to sample().
    z_input: latent vector to condition on; None passes z=None to sample().
    draw_mode: if true, return a drawing instead of the raw strokes.
    temperature: sampling temperature forwarded to sample().
    factor: scaling factor forwarded to draw_strokes().

  Returns:
    The draw_strokes() drawing when draw_mode is true, otherwise the strokes
    returned by to_normal_strokes().
  """
  # sample() receives the latent vector wrapped in a one-element list.
  latent_batch = [z_input] if z_input is not None else None
  sampled, _ = sample(
      sess,
      sample_model,
      seq_len=eval_model.hps.max_seq_len,
      temperature=temperature,
      z=latent_batch)
  strokes = to_normal_strokes(sampled)
  return draw_strokes(strokes, factor) if draw_mode else strokes
sheep.py
Python