Do Neural Networks Dream of Vectorized Sheep?
Google's Project Magenta is an endeavor to explore and push the boundaries of machine-generated art and music. It not only uses machine learning to create new works, but also seeks to develop new assistive tools for human creators, and it provides a place for creators to learn about and collaborate with these tools.
This article is a port of a tutorial, originally available here, on using the Sketch-RNN model provided by Magenta. Sketch-RNN is a generative model for vector drawings, which essentially learns how to reconstruct figures using a series of simulated 'motor actions'. For an overview of the model, see this blog post from David Ha. For further technical details, refer to this paper by David Ha and Douglas Eck.
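To make the 'motor actions' idea concrete, here is a toy example (my own, not from the original tutorial) of the stroke-3 data format the model operates on: each row holds a pen offset (dx, dy) plus a flag that is 1 when the pen lifts after reaching that point.

import numpy as np

# A hypothetical triangle in stroke-3 format: each row is (dx, dy, pen_lifted).
toy_strokes = np.array([
    [ 20,   0, 0],   # draw 20 units to the right, pen down
    [-10,  15, 0],   # draw up and to the left, pen down
    [-10, -15, 1],   # close the triangle, then lift the pen
])

The draw_strokes helper defined in Section 6.2 renders arrays of exactly this shape to SVG.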
Package installation and helper-function definitions are in Section 6, so here we can jump right in with our first dataset.
1. Training models on the Aaron Koblin Sheep Dataset
For more information on the dataset visit http://www.aaronkoblin.com/project/the-sheep-market.
Define the path of the model you want to load, and the path of the dataset.
data_dir = 'http://github.com/hardmaru/sketch-rnn-datasets/raw/master/aaron_sheep/'
models_root_dir = '/magenta-models/sketch_rnn/'
model_dir = models_root_dir + 'aaron_sheep/layer_norm'

from sheep import *
from svgwrite.drawing import Drawing

# Define how to interpret svgwrite Drawing objects
Drawing._repr_svg_ = lambda self: self.tostring()

np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)

# Fix random seeds to get consistent results for display
random.seed(3)
np.random.seed(3)

[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env(data_dir, model_dir)
Construct the Sketch-RNN model here, loading its weights from a downloaded checkpoint.
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
load_checkpoint(sess, model_dir)
Get a sample drawing from the test set and render it to SVG.
stroke = test_set.random_sample()
draw_strokes(stroke)
Let's try to encode the sample stroke into a latent vector...
z, drawing = encode(sess, eval_model, stroke)
drawing
...and decode it back into a drawing, sampling at a temperature of 0.65.
decode(sess, eval_model, sample_model, z, temperature=0.65)
Create a grid of sketches decoded from the same latent vector at temperatures from 0.1 to 1.0 (higher temperatures give more varied but less faithful samples).
stroke_list = []
for i in range(10):
    stroke_list.append([decode(sess, eval_model, sample_model, z, draw_mode=False, temperature=0.1*i+0.1), [0, i]])
stroke_grid = make_grid_svg(stroke_list)

# Save as a file for display up top
draw_strokes(stroke_grid).saveas("/results/title.svg")
Now for a latent-space interpolation example between two sheep.
z0 = z
decode(sess, eval_model, sample_model, z0)
stroke = test_set.random_sample()
z1, drawing = encode(sess, eval_model, stroke)
drawing
decode(sess, eval_model, sample_model, z1)
Now we interpolate between the two sheep.
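The slerp helper used below presumably arrives via the wildcard imports from magenta.models.sketch_rnn.utils in Section 6.2. Conceptually it performs spherical linear interpolation; a minimal sketch of the standard formulation (not necessarily Magenta's exact code) looks like this:

def slerp_sketch(p0, p1, t):
    # Walk along the great-circle arc between vectors p0 and p1, t in [0, 1].
    omega = np.arccos(np.dot(p0 / np.linalg.norm(p0), p1 / np.linalg.norm(p1)))
    so = np.sin(omega)
    return np.sin((1.0 - t) * omega) / so * p0 + np.sin(t * omega) / so * p1

Unlike straight linear interpolation, slerp avoids passing through the low-probability region near the origin of a Gaussian latent space, which tends to decode to cleaner intermediate sketches.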
z_list = []
N = 10
for t in np.linspace(0, 1, N):
    z_list.append(slerp(z0, z1, t))

# For every latent vector in z_list, sample a vector image
reconstructions = []
for i in range(N):
    reconstructions.append([decode(sess, eval_model, sample_model, z_list[i], draw_mode=False), [0, i]])

stroke_grid = make_grid_svg(reconstructions)
draw_strokes(stroke_grid)
2. Flamingo Model
Let's load the flamingo model and try unconditional (decoder-only) generation.
model_dir = models_root_dir + 'flamingo/lstm_uncond'
[hps_model, eval_hps_model, sample_hps_model] = load_model(model_dir)
Construct & load the new model.
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
load_checkpoint(sess, model_dir)
Unconditionally generate 10 random examples.
N = 10
reconstructions = []
for i in range(N):
    reconstructions.append([decode(sess, eval_model, sample_model, temperature=0.5, draw_mode=False), [0, i]])

stroke_grid = make_grid_svg(reconstructions)
draw_strokes(stroke_grid)
3. Owl Model
Let's load the owl model and generate two sketches from two random IID Gaussian latent vectors. Because the model's latent space is trained toward a standard Gaussian prior, random Gaussian vectors decode to plausible sketches.
model_dir = models_root_dir + 'owl/lstm'
[hps_model, eval_hps_model, sample_hps_model] = load_model(model_dir)
Construct & load.
random.seed(15)
np.random.seed(15)

reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
load_checkpoint(sess, model_dir)
z_0 = np.random.randn(eval_model.hps.z_size)
decode(sess, eval_model, sample_model, z_0)
z_1 = np.random.randn(eval_model.hps.z_size)
decode(sess, eval_model, sample_model, z_1)
Let's interpolate between the two owls.
z_list = []

# Interpolate spherically between z_0 and z_1
N = 10
for t in np.linspace(0, 1, N):
    z_list.append(slerp(z_0, z_1, t))

# For every latent vector in z_list, sample a vector image
reconstructions = []
for i in range(N):
    reconstructions.append([decode(sess, eval_model, sample_model, z_list[i], draw_mode=False, temperature=0.1), [0, i]])

stroke_grid = make_grid_svg(reconstructions)
draw_strokes(stroke_grid)
4. Catbus!
Let's load the model trained on both cats and buses! Catbus!
model_dir = models_root_dir + 'catbus/lstm'
[hps_model, eval_hps_model, sample_hps_model] = load_model(model_dir)
Construct & load.
random.seed(1)
np.random.seed(1)

reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
load_checkpoint(sess, model_dir)
z_1 = np.random.randn(eval_model.hps.z_size)
decode(sess, eval_model, sample_model, z_1)
z_0 = np.random.randn(eval_model.hps.z_size)
decode(sess, eval_model, sample_model, z_0)
Let's interpolate between a cat and a bus!!!
z_list = []

# Interpolate spherically between z_1 and z_0
N = 10
for t in np.linspace(0, 1, N):
    z_list.append(slerp(z_1, z_0, t))

# For every latent vector in z_list, sample a vector image
reconstructions = []
for i in range(N):
    reconstructions.append([decode(sess, eval_model, sample_model, z_list[i], draw_mode=False, temperature=0.15), [0, i]])

stroke_grid = make_grid_svg(reconstructions)
draw_strokes(stroke_grid)
5. Mad Science!
Why stop here? Let's load the model trained on both elephants and pigs!!!
model_dir = models_root_dir + 'elephantpig/lstm'
[hps_model, eval_hps_model, sample_hps_model] = load_model(model_dir)
# Reset random seeds to get good examples
random.seed(12)
np.random.seed(12)

reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
load_checkpoint(sess, model_dir)
z_0 = np.random.randn(eval_model.hps.z_size)
decode(sess, eval_model, sample_model, z_0)
z_1 = np.random.randn(eval_model.hps.z_size)
decode(sess, eval_model, sample_model, z_1)
A tribute to an episode of South Park: the interpolation between an elephant and a pig.
z_list = []

# Interpolate spherically between z_0 and z_1
N = 10
for t in np.linspace(0, 1, N):
    z_list.append(slerp(z_0, z_1, t))

# For every latent vector in z_list, sample a vector image
reconstructions = []
for i in range(N):
    reconstructions.append([decode(sess, eval_model, sample_model, z_list[i], draw_mode=False, temperature=0.15), [0, i]])

stroke_grid = make_grid_svg(reconstructions, grid_space_x=25.0)
draw_strokes(stroke_grid, factor=0.3)
6. Setup
6.1. Installation
Install as many dependencies as we can via conda (versions pulled from Magenta's setup.py). We'll use the development version of Magenta from GitHub for Python 3 compatibility.
conda install absl-py backports.tempfile click flask future gevent greenlet gunicorn itsdangerous mpmath networkx pyasn1 pyasn1-modules sympy wheel wrapt \
    'bokeh>=0.12.0' 'joblib>=0.12' 'matplotlib>=1.5.3' 'numpy>=1.14.6,<=1.16.2' 'pandas>=0.18.1' 'Pillow>=3.4.2' 'scipy>=0.18.1,<=1.2.0'

conda install -c conda-forge \
    cachetools google-api-python-client googleapis-common-protos google-auth google-auth-httplib2 httplib2 oauth2client promise pyglet rsa sk-video svgwrite uritemplate \
    'intervaltree>=2.1.0' 'librosa>=0.6.2' 'protobuf>=3.7,<4' 'sox>=1.3.7' 'tensorflow-probability>=0.5.0'

conda clean -qtipy
ldconfig
pip install git+https://github.com/tensorflow/magenta@master
Download models.
# From the Magenta Dockerfile
# github.com/tensorflow/magenta/blob/master/magenta/tools/docker/Dockerfile
# Not sure if really needed for this, but shows how to do it anyway.
mkdir -p /magenta-data
mkdir -p /magenta-models
cd /magenta-models
wget -q \
    http://download.magenta.tensorflow.org/models/attention_rnn.mag \
    http://download.magenta.tensorflow.org/models/basic_rnn.mag \
    http://download.magenta.tensorflow.org/models/chord_pitches_improv.mag \
    http://download.magenta.tensorflow.org/models/drum_kit_rnn.mag \
    http://download.magenta.tensorflow.org/models/lookback_rnn.mag \
    http://download.magenta.tensorflow.org/models/polyphony_rnn.mag \
    http://download.magenta.tensorflow.org/models/rl_rnn.mag \
    http://download.magenta.tensorflow.org/models/rl_tuner_note_rnn.ckpt \
    http://download.magenta.tensorflow.org/models/multistyle-pastiche-generator-monet.ckpt \
    http://download.magenta.tensorflow.org/models/multistyle-pastiche-generator-varied.ckpt

# Magenta wants the py2.7 urllib
python -c 'from magenta.models.sketch_rnn.sketch_rnn_train import *
download_pretrained_models(models_root_dir="/magenta-models/sketch_rnn")'
Replace 1/0 with true/false in the model config files (required for TensorFlow 1.5+).
cd /magenta-models/sketch_rnn
for i in */*/model_config.json; do
    perl -pi -e 's/((?:use|is)_.*|conditional)": 0/$1": false/; s/((?:use|is)_.*|conditional)": 1/$1": true/' $i
done
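As a hypothetical illustration of what the substitution does, config entries such as

    "conditional": 0,
    "use_recurrent_dropout": 1,

are rewritten to

    "conditional": false,
    "use_recurrent_dropout": true,

so the values parse as proper JSON booleans under newer TensorFlow.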
6.2. Functions
Create a small module to define functions and make package loading easy. In the runtime config we mount this named Code Listing somewhere, and set PYTHONPATH to point at the same place.
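For example, if the module below is saved as sheep.py (the name assumed by the "from sheep import *" statement in Section 1) in a directory mounted at a hypothetical /magenta-code, the runtime config would only need something like:

export PYTHONPATH="/magenta-code:$PYTHONPATH"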
import sys, os
sys.version_info

import numpy as np
import time
import random
# import cPickle
import codecs
import collections
import math
import json

import tensorflow as tf
tf.__version__
from six.moves import xrange

# MAGENTA STUFF
import magenta
magenta.__version__
from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *

import svgwrite

# A function to generate and return SVG images,
# and another to put together a grid of many such images.

def draw_strokes(data, factor=0.2, svg_filename='/tmp/sketch_rnn/svg/sample.svg'):
    tf.gfile.MakeDirs(os.path.dirname(svg_filename))
    min_x, max_x, min_y, max_y = get_bounds(data, factor)
    dims = (50 + max_x - min_x, 50 + max_y - min_y)
    dwg = svgwrite.Drawing(svg_filename, size=dims)
    dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))
    lift_pen = 1
    abs_x = 25 - min_x
    abs_y = 25 - min_y
    p = "M%s,%s " % (abs_x, abs_y)
    command = "m"
    for i in xrange(len(data)):
        if lift_pen == 1:
            command = "m"
        elif command != "l":
            command = "l"
        else:
            command = ""
        x = float(data[i, 0]) / factor
        y = float(data[i, 1]) / factor
        lift_pen = data[i, 2]
        p += command + str(x) + "," + str(y) + " "
    the_color = "black"
    stroke_width = 1
    dwg.add(dwg.path(p).stroke(the_color, stroke_width).fill("none"))
    return dwg

def make_grid_svg(s_list, grid_space=10.0, grid_space_x=16.0):
    def get_start_and_end(x):
        x = np.array(x)
        x = x[:, 0:2]
        x_start = x[0]
        x_end = x.sum(axis=0)
        x = x.cumsum(axis=0)
        x_max = x.max(axis=0)
        x_min = x.min(axis=0)
        center_loc = (x_max + x_min) * 0.5
        return x_start - center_loc, x_end
    x_pos = 0.0
    y_pos = 0.0
    result = [[x_pos, y_pos, 1]]
    for sample in s_list:
        s = sample[0]
        grid_loc = sample[1]
        grid_y = grid_loc[0] * grid_space + grid_space * 0.5
        grid_x = grid_loc[1] * grid_space_x + grid_space_x * 0.5
        start_loc, delta_pos = get_start_and_end(s)
        loc_x = start_loc[0]
        loc_y = start_loc[1]
        new_x_pos = grid_x + loc_x
        new_y_pos = grid_y + loc_y
        result.append([new_x_pos - x_pos, new_y_pos - y_pos, 0])
        result += s.tolist()
        result[-1][2] = 1
        x_pos = new_x_pos + delta_pos[0]
        y_pos = new_y_pos + delta_pos[1]
    return np.array(result)

# Two more convenience functions, to encode a stroke into a latent vector,
# and decode from latent vector to stroke.

def encode(sess, eval_model, input_strokes):
    strokes = to_big_strokes(input_strokes).tolist()
    strokes.insert(0, [0, 0, 1, 0, 0])
    seq_len = [len(input_strokes)]
    drawing = draw_strokes(to_normal_strokes(np.array(strokes)))
    stroke = sess.run(eval_model.batch_z,
                      feed_dict={eval_model.input_data: [strokes],
                                 eval_model.sequence_lengths: seq_len})[0]
    return (stroke, drawing)

def decode(sess, eval_model, sample_model, z_input=None, draw_mode=True, temperature=0.1, factor=0.2):
    z = None
    if z_input is not None:
        z = [z_input]
    sample_strokes, m = sample(sess, sample_model,
                               seq_len=eval_model.hps.max_seq_len,
                               temperature=temperature, z=z)
    strokes = to_normal_strokes(sample_strokes)
    if draw_mode:
        return draw_strokes(strokes, factor)
    return strokes