Nextjournal / Sep 16 2019
MXNet GANs with Flans
1. Setup
1.1. Functions
We'll put functions here for later reuse via code imports. The cells can also be run on their own to check that the code is valid.
From viz.clj:
(ns mxnet-gan-flan.viz
  (:require [org.apache.clojure-mxnet.ndarray :as ndarray]
            [org.apache.clojure-mxnet.shape :as mx-shape]
            [org.apache.clojure-mxnet.io :as mx-io])
  (:import (nu.pattern OpenCV)
           (org.opencv.core Core CvType Mat Size)
           (org.opencv.imgproc Imgproc)
           (org.opencv.highgui Highgui)))

;;; Viz stuff
(OpenCV/loadShared)

;; Scale values by 255, clamp to [0, 255] and convert to bytes.
(defn clip [x]
  (->> x
       (mapv #(* 255 %))
       (mapv #(cond
                (< % 0) 0
                (> % 255) 255
                :else (int %)))
       (mapv #(.byteValue %))))

;; Build an OpenCV Mat from the raw bytes of one image.
(defn get-img [raw-data channels height width flip]
  (let [totals (* height width)
        img (if (> channels 1)
              ;; rgb image
              (let [[ra ga ba] (doall (partition totals raw-data))
                    rr (new Mat height width (CvType/CV_8U))
                    gg (new Mat height width (CvType/CV_8U))
                    bb (new Mat height width (CvType/CV_8U))
                    result (new Mat height width (CvType/CV_8U))]
                (do (.put rr 0 0 (byte-array ra))
                    (.put gg 0 0 (byte-array ga))
                    (.put bb 0 0 (byte-array ba)))
                (Core/merge (java.util.ArrayList. [bb gg rr]) result)
                result)
              ;; gray image
              (let [result (new Mat height width (CvType/CV_8U))
                    _ (.put result (int 0) (int 0) (byte-array raw-data))]
                result))]
    (do
      (if flip
        (let [result (new Mat)
              _ (Core/flip img result (int 0))]
          result)
        img))))

;; Tile a batch of images into a square grid and write it as a JPEG.
(defn im-sav [{:keys [title output-path x flip]
               :or {flip false} :as g-mod}]
  (let [shape (mx-shape/->vec (ndarray/shape x))
        _ (assert (== 4 (count shape)))
        [n c h w] shape
        totals (* h w)
        raw-data (byte-array (clip (ndarray/to-array x)))
        row (.intValue (Math/sqrt n))
        col row
        line-arrs (into [] (partition (* col c totals) raw-data))
        line-mats (mapv (fn [line]
                          (let [img-arr (into [] (partition (* c totals) line))
                                col-mats (new Mat)
                                src (mapv (fn [arr] (get-img (into [] arr) c h w flip)) img-arr)
                                _ (Core/hconcat (java.util.ArrayList. src) col-mats)]
                            col-mats))
                        line-arrs)
        result (new Mat)
        resized-img (new Mat)
        _ (Core/vconcat (java.util.ArrayList. line-mats) result)]
    (do
      (Imgproc/resize result resized-img (new Size (* (.width result) 1.5) (* (.height result) 1.5)))
      (Highgui/imwrite (str output-path title ".jpg") resized-img)
      (Thread/sleep 1000))))

(println "")
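im-sav tiles a batch of n images into a grid whose width is the integer part of sqrt(n), so a batch of 100 becomes a 10x10 sheet and a batch of 9 a 3x3 sheet. The following is a minimal sketch of that arithmetic only, with no OpenCV or MXNet dependency; grid-shape and sheet-size are hypothetical helper names and are not part of viz.clj.

```clojure
;; Sketch of the tiling arithmetic used by im-sav (hypothetical helpers).

(defn grid-shape
  "Columns and rows of the image sheet for a batch of n images.
  Mirrors im-sav: cols = floor(sqrt n); each row holds cols images, and
  partition drops a trailing incomplete row."
  [n]
  (let [cols (int (Math/sqrt n))]
    {:cols cols :rows (quot n cols)}))

(defn sheet-size
  "Pixel size of the saved sheet for n images of h x w pixels,
  including the 1.5x resize that im-sav applies before writing."
  [n h w]
  (let [{:keys [cols rows]} (grid-shape n)]
    {:width (* 1.5 cols w) :height (* 1.5 rows h)}))

(println (grid-shape 100))        ; batch size used for training
(println (grid-shape 9))          ; batch size used in the Explore section
(println (sheet-size 100 28 28))  ; sheet of one hundred 28x28 images
```

A batch size that is not a perfect square still works, but any images beyond the last complete row are left off the sheet.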
Defining the next functions requires an input file, so copy in the RecordIO input file created in the Appendix.
cp flan-128.rec /flan-28.rec
cp flan-128.rec /flan-128.rec
From gan.clj, up to the main function. The network design for 28x28 or 128x128 images is selected automatically based on img-size. The 128x128 network currently dies with cudaMalloc out-of-memory errors. Tried so far, without success:
- ndarray/waitall after each iteration
- mxnet pool env var set to max
- using P100 to avoid CUDA 9.0/K80 issue
(ns mxnet-gan-flan.gan
  (:require [clojure.java.io :as io]
            [clojure.java.shell :refer [sh]]
            [clojure.string]
            [org.apache.clojure-mxnet.executor :as executor]
            [org.apache.clojure-mxnet.eval-metric :as eval-metric]
            [org.apache.clojure-mxnet.io :as mx-io]
            [org.apache.clojure-mxnet.initializer :as init]
            [org.apache.clojure-mxnet.module :as m]
            [org.apache.clojure-mxnet.ndarray :as ndarray]
            [org.apache.clojure-mxnet.optimizer :as opt]
            [org.apache.clojure-mxnet.symbol :as sym]
            [org.apache.clojure-mxnet.shape :as mx-shape]
            [org.apache.clojure-mxnet.util :as util]
            [mxnet-gan-flan.viz :as viz]
            [org.apache.clojure-mxnet.context :as context]
            [think.image.pixel :as pixel]
            [mikera.image.core :as img])
  (:gen-class))

;; based off of https://medium.com/@julsimon/generative-adversarial-networks-on-apache-mxnet-part-1-b6d39e6b5df1

;; Use defonce so these can be overridden.
(defonce data-dir "/images/")
(defonce output-path "/out/")
(defonce model-path "/model/")
(defonce batch-size 100)
(defonce num-epoch 200)
(defonce img-size 128)

(io/make-parents (str output-path "gout"))
(io/make-parents (str model-path "test"))

;; Highest epoch number among the saved discriminator checkpoints, or nil.
(defn last-saved-model-number []
  (some->> "model/"
           clojure.java.io/file
           file-seq
           (filter #(.isFile %))
           (map #(.getName %))
           (filter #(clojure.string/includes? % "model-d"))
           (map #(re-seq #"\d{4}" %))
           (map first)
           (map #(when % (Integer/parseInt %)))
           (sort)
           (last)))

(def flan-iter
  (mx-io/image-record-iter {:path-imgrec (str "flan-" img-size ".rec")
                            :data-shape [3 img-size img-size]
                            :batch-size batch-size
                            :shuffle true}))

;; Map pixel values from [0, 256) to [-1, 1).
(defn normalize-rgb [x]
  (/ (- x 128.0) 128.0))

(defn normalize-rgb-ndarray [nda]
  (let [nda-shape (ndarray/shape-vec nda)
        new-values (mapv #(normalize-rgb %) (ndarray/->vec nda))]
    (ndarray/array new-values nda-shape)))

(defn denormalize-rgb [x]
  (+ (* x 128.0) 128.0))

(defn clip [x]
  (cond (< x 0) 0
        (> x 255) 255
        :else (int x)))

(defn postprocess-image [img]
  (let [datas (ndarray/->vec img)
        image-shape (mx-shape/->vec (ndarray/shape img))
        spatial-size (* (get image-shape 2) (get image-shape 3))
        pics (doall (partition (* 3 spatial-size) datas))
        pixels (mapv
                (fn [pic]
                  (let [[rs gs bs] (doall (partition spatial-size pic))
                        this-pixels (mapv (fn [r g b]
                                            (pixel/pack-pixel
                                             (int (clip (denormalize-rgb r)))
                                             (int (clip (denormalize-rgb g)))
                                             (int (clip (denormalize-rgb b)))
                                             (int 255)))
                                          rs gs bs)]
                    this-pixels))
                pics)
        new-pixels (into [] (flatten pixels))
        new-image (img/new-image (* 1 (get image-shape 3)) (* batch-size (get image-shape 2)))
        _ (img/set-pixels new-image (int-array new-pixels))]
    new-image))

(defn postprocess-write-img [img filename]
  (img/write (-> (postprocess-image img)
                 (img/zoom 1.5))
             filename "png"))

(def rand-noise-iter (mx-io/rand-iter [batch-size 100 1 1]))

;; -- Start Layers --
;; currently set for 128x128 images, from
;; https://github.com/apache/incubator-mxnet/blob/master/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj

(def ndf 28) ;; image height /width
(def nc 3) ;; number of channels
(def eps (float (+ 1e-5 1e-12)))
(def lr 0.0005) ;; learning rate
(def beta1 0.5)

(def label (sym/variable "label"))

(case img-size
  28
  (do
    (defn discriminator []
      (as-> (sym/variable "data") data
        (sym/convolution "d1" {:data data :kernel [4 4] :pad [3 3] :stride [2 2] :num-filter ndf :no-bias true})
        (sym/batch-norm "dbn1" {:data data :fix-gamma true :eps eps})
        (sym/leaky-re-lu "dact1" {:data data :act-type "leaky" :slope 0.2})
        (sym/convolution "d2" {:data data :kernel [4 4] :pad [1 1] :stride [2 2] :num-filter (* 2 ndf) :no-bias true})
        (sym/batch-norm "dbn2" {:data data :fix-gamma true :eps eps})
        (sym/leaky-re-lu "dact1" {:data data :act_type "leaky" :slope 0.2})
        (sym/convolution "d3" {:data data :kernel [4 4] :pad [1 1] :stride [2 2] :num-filter (* 3 ndf) :no-bias true})
        (sym/batch-norm "dbn3" {:data data :fix-gamma true :eps eps})
        (sym/leaky-re-lu "dact3" {:data data :act_type "leaky" :slope 0.2})
        (sym/convolution "d4" {:data data :kernel [4 4] :pad [0 0] :stride [1 1] :num-filter (* 4 ndf) :no-bias true})
        (sym/flatten "flt" {:data data})
        (sym/fully-connected "fc" {:data data :num-hidden 1 :no-bias false})
        (sym/logistic-regression-output "dloss" {:data data :label label})))

    (defn generator []
      (as-> (sym/variable "rand") data
        (sym/deconvolution "g1" {:data data :kernel [4 4] :pad [0 0] :stride [1 1] :num-filter (* 4 ndf) :no-bias true})
        (sym/batch-norm "gbn1" {:data data :fix-gamma true :eps eps})
        (sym/activation "gact1" {:data data :act-type "relu"})
        (sym/deconvolution "g2" {:data data :kernel [4 4] :pad [1 1] :stride [2 2] :num-filter (* 2 ndf) :no-bias true})
        (sym/batch-norm "gbn2" {:data data :fix-gamma true :eps eps})
        (sym/activation "gact2" {:data data :act-type "relu"})
        (sym/deconvolution "g3" {:data data :kernel [4 4] :pad [1 1] :stride [2 2] :num-filter ndf :no-bias true})
        (sym/batch-norm "gbn3" {:data data :fix-gamma true :eps eps})
        (sym/activation "gact3" {:data data :act-type "relu"})
        (sym/deconvolution "g4" {:data data :kernel [4 4] :pad [3 3] :stride [2 2] :num-filter nc :no-bias true})
        (sym/activation "gact4" {:data data :act-type "tanh"}))))
  128
  (do
    (defn discriminator []
      (as-> (sym/variable "data") data
        (sym/convolution "d2" {:data data :kernel [4 4] :pad [3 3] :stride [2 2] :num-filter (* 2 ndf) :no-bias true})
        (sym/batch-norm "dbn2" {:data data :fix-gamma true :eps eps})
        (sym/leaky-re-lu "dact1" {:data data :act_type "leaky" :slope 0.2})
        (sym/convolution "d3" {:data data :kernel [4 4] :pad [2 2] :stride [2 2] :num-filter (* 3 ndf) :no-bias true})
        (sym/batch-norm "dbn3" {:data data :fix-gamma true :eps eps})
        (sym/leaky-re-lu "dact3" {:data data :act_type "leaky" :slope 0.2})
        (sym/convolution "d4" {:data data :kernel [4 4] :pad [0 0] :stride [2 2] :num-filter (* 3 ndf) :no-bias true})
        (sym/batch-norm "dbn4" {:data data :fix-gamma true :eps eps})
        (sym/leaky-re-lu "dact4" {:data data :act_type "leaky" :slope 0.2})
        (sym/convolution "d5" {:data data :kernel [4 4] :pad [1 1] :stride [2 2] :num-filter (* 3 ndf) :no-bias true})
        (sym/batch-norm "dbn5" {:data data :fix-gamma true :eps eps})
        (sym/leaky-re-lu "dact5" {:data data :act_type "leaky" :slope 0.2})
        (sym/convolution "d6" {:data data :kernel [4 4] :pad [1 1] :stride [2 2] :num-filter (* 3 ndf) :no-bias true})
        (sym/batch-norm "dbn6" {:data data :fix-gamma true :eps eps})
        (sym/leaky-re-lu "dact6" {:data data :act_type "leaky" :slope 0.2})
        (sym/convolution "d7" {:data data :kernel [4 4] :pad [0 0] :stride [1 1] :num-filter (* 4 ndf) :no-bias true})
        (sym/flatten "flt" {:data data})
        (sym/fully-connected "fc" {:data data :num-hidden 1 :no-bias false})
        (sym/logistic-regression-output "dloss" {:data data :label label})))

    (defn generator []
      (as-> (sym/variable "rand") data
        (sym/deconvolution "g1" {:data data :kernel [4 4] :pad [0 0] :stride [1 1] :num-filter (* 4 ndf) :no-bias true})
        (sym/batch-norm "gbn1" {:data data :fix-gamma true :eps eps})
        (sym/activation "gact1" {:data data :act-type "relu"})
        (sym/deconvolution "g2" {:data data :kernel [4 4] :pad [1 1] :stride [2 2] :num-filter (* 2 ndf) :no-bias true})
        (sym/batch-norm "gbn2" {:data data :fix-gamma true :eps eps})
        (sym/activation "gact2" {:data data :act-type "relu"})
        (sym/deconvolution "g3" {:data data :kernel [4 4] :pad [1 1] :stride [2 2] :num-filter ndf :no-bias true})
        (sym/batch-norm "gbn3" {:data data :fix-gamma true :eps eps})
        (sym/activation "gact3" {:data data :act-type "relu"})
        (sym/deconvolution "g4" {:data data :kernel [4 4] :pad [0 0] :stride [2 2] :num-filter ndf :no-bias true})
        (sym/batch-norm "gbn4" {:data data :fix-gamma true :eps eps})
        (sym/activation "gact4" {:data data :act-type "relu"})
        (sym/deconvolution "g5" {:data data :kernel [4 4] :pad [2 2] :stride [2 2] :num-filter ndf :no-bias true})
        (sym/batch-norm "gbn5" {:data data :fix-gamma true :eps eps})
        (sym/activation "gact5" {:data data :act-type "relu"})
        (sym/deconvolution "g7" {:data data :kernel [4 4] :pad [3 3] :stride [2 2] :num-filter nc :no-bias true})
        (sym/activation "gact7" {:data data :act-type "tanh"}))))
  (throw (AssertionError. (str "No networks defined for img-size " img-size "."))))
;; -- End Layers --

(defn save-img-gout [i n x]
  (do
    (viz/im-sav {:title (str "gout-" i "-" n) :output-path output-path :x x :flip false})))

(defn save-img-diff [i n x]
  (do
    (viz/im-sav {:title (str "diff-" i "-" n) :output-path output-path :x x :flip false})))

(defn save-img-data [i n batch]
  (do
    (viz/im-sav {:title (str "data-" i "-" n) :output-path output-path :x batch :flip false})))

;; Standardize the input gradients and save them as an image.
(defn calc-diff [i n diff-d]
  (let [diff (ndarray/copy diff-d)
        arr (ndarray/->vec diff)
        mean (/ (apply + arr) (count arr))
        std (let [tmp-a (map #(* (- % mean) (- % mean)) arr)]
              (float (Math/sqrt (/ (apply + tmp-a) (count tmp-a)))))]
    (let [calc-diff (ndarray/+ (ndarray/div (ndarray/- diff mean) std) 0.5)]
      (save-img-diff i n calc-diff))))

(defn train [devs]
  (let [last-train-num (last-saved-model-number)
        _ (println "The last saved trained epoch is " last-train-num)
        mod-d (-> (if last-train-num
                    (do
                      (println "Loading discriminator from checkpoint of epoch " last-train-num)
                      (m/load-checkpoint {:contexts devs
                                          :data-names ["data"]
                                          :label-names ["label"]
                                          :prefix (str model-path "model-d")
                                          :epoch last-train-num
                                          :load-optimizer-states true}))
                    (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]}))
                  (m/bind {:data-shapes (mx-io/provide-data flan-iter)
                           :label-shapes (mx-io/provide-label flan-iter)
                           :inputs-need-grad true})
                  (m/init-params {:initializer (init/normal 0.02)})
                  (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))
        mod-g (-> (if last-train-num
                    (do
                      (println "Loading generator from checkpoint of epoch " last-train-num)
                      (m/load-checkpoint {:contexts devs
                                          :data-names ["rand"]
                                          :label-names [""]
                                          :prefix (str model-path "model-g")
                                          :epoch last-train-num
                                          :load-optimizer-states true}))
                    (m/module (generator) {:contexts devs :data-names ["rand"] :label-names nil}))
                  (m/bind {:data-shapes (mx-io/provide-data rand-noise-iter)})
                  (m/init-params {:initializer (init/normal 0.02)})
                  (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))]

    (println "Training for " num-epoch " epochs...")
    (doseq [i (if last-train-num
                (range (inc last-train-num) (inc (+ last-train-num num-epoch)))
                (range num-epoch))]
      (mx-io/reduce-batches
       flan-iter
       (fn [n batch]
         (let [rbatch (mx-io/next rand-noise-iter)
               dbatch (mapv normalize-rgb-ndarray (mx-io/batch-data batch))
               out-g (-> mod-g
                         (m/forward rbatch)
                         (m/outputs))
               ;; update the discriminator on the fake
               grads-f (mapv #(ndarray/copy (first %))
                             (-> mod-d
                                 (m/forward {:data (first out-g) :label [(ndarray/zeros [batch-size])]})
                                 (m/backward)
                                 (m/grad-arrays)))
               ;; update the discriminator on the real
               grads-r (-> mod-d
                           (m/forward {:data dbatch :label [(ndarray/ones [batch-size])]})
                           (m/backward)
                           (m/grad-arrays))
               ;; accumulate the fake gradients into the real ones
               _ (mapv (fn [real fake]
                         (let [r (first real)]
                           (ndarray/set r (ndarray/+ r fake))))
                       grads-r grads-f)
               _ (m/update mod-d)
               ;; update the generator
               diff-d (-> mod-d
                          (m/forward {:data (first out-g) :label [(ndarray/ones [batch-size])]})
                          (m/backward)
                          (m/input-grads))
               _ (-> mod-g
                     (m/backward (first diff-d))
                     (m/update))]
           ;; on the first batch of each epoch, save sample images and checkpoints
           (when (zero? n)
             (println "iteration = " i "number = " n)
             (save-img-gout i n (ndarray/copy (ffirst out-g)))
             (save-img-data i n (first dbatch))
             (calc-diff i n (ffirst diff-d))
             (m/save-checkpoint mod-g {:prefix (str model-path "model-g") :epoch i :save-opt-states true})
             (m/save-checkpoint mod-d {:prefix (str model-path "model-d") :epoch i :save-opt-states true}))
           (inc n)))))))

(defn -main [& args]
  (let [[dev dev-num] args
        devs (if (= dev ":gpu")
               (mapv #(context/gpu %) (range (Integer/parseInt (or dev-num "1"))))
               (mapv #(context/cpu %) (range (Integer/parseInt (or dev-num "1")))))]
    (println "Running with context devices of" devs)
    (train devs)))
#'mxnet-gan-flan.gan/-main
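The kernel, pad and stride values in the two network definitions can be sanity-checked with the standard size formulas for convolution and transposed convolution. The sketch below is plain Clojure arithmetic and does not call MXNet; it assumes square inputs, dilation 1 and no output adjustment, and conv-out, deconv-out and trace are hypothetical helper names. The [kernel pad stride] triples are copied from the layer definitions in gan.clj.

```clojure
;; Sketch: spatial sizes through the networks, using
;;   convolution:   out = floor((in + 2*pad - kernel) / stride) + 1
;;   deconvolution: out = (in - 1) * stride - 2*pad + kernel

(defn conv-out [in [k p s]]
  (inc (quot (- (+ in (* 2 p)) k) s)))

(defn deconv-out [in [k p s]]
  (+ (- (* (dec in) s) (* 2 p)) k))

;; [kernel pad stride] per layer, keyed by img-size
(def disc-layers
  {28  [[4 3 2] [4 1 2] [4 1 2] [4 0 1]]
   128 [[4 3 2] [4 2 2] [4 0 2] [4 1 2] [4 1 2] [4 0 1]]})

(def gen-layers
  {28  [[4 0 1] [4 1 2] [4 1 2] [4 3 2]]
   128 [[4 0 1] [4 1 2] [4 1 2] [4 0 2] [4 2 2] [4 3 2]]})

(defn trace
  "Sizes after each layer, starting from the input size."
  [step in layers]
  (vec (reductions step in layers)))

(doseq [size [28 128]]
  (println size "discriminator:" (trace conv-out size (disc-layers size)))
  (println size "generator:    " (trace deconv-out 1 (gen-layers size))))
```

Under these assumptions both discriminators reduce the input to 1x1 and both generators grow the 1x1 noise input to exactly img-size. The 128x128 pair adds two stride-2 stages, and its largest intermediate feature maps are 66x66 rather than 16x16.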
2. Custard
2.1. Train
Copy in the input file.
This will run on a GPU.
(ns mxnet-gan-flan.gan)

(def data-dir "/images/")
(def output-path "/out/")
(def model-path "/model/")
(def batch-size 100)
(def num-epoch 201)
(def img-size 28)

(-main ":gpu")
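Rerunning this cell does not necessarily start from scratch: train looks for the newest discriminator checkpoint and, if it finds one, continues from the following epoch. The sketch below reproduces that selection logic on plain strings so it can be tried without MXNet. The helper names and the example file names are hypothetical; the only assumption taken from gan.clj is that a checkpoint file name contains "model-d" and a four-digit epoch number.

```clojure
(require '[clojure.string :as string])

(defn epoch-from-name
  "Four-digit epoch number embedded in a checkpoint file name, or nil."
  [file-name]
  (some-> (re-find #"\d{4}" file-name) Integer/parseInt))

(defn last-saved-epoch
  "Highest epoch among discriminator checkpoints, or nil if there are none."
  [file-names]
  (->> file-names
       (filter #(string/includes? % "model-d"))
       (keep epoch-from-name)
       sort
       last))

(defn epochs-to-run
  "Epoch indices for the next run: continue after the last checkpoint if
  there is one, otherwise start from 0."
  [last-epoch num-epoch]
  (if last-epoch
    (range (inc last-epoch) (inc (+ last-epoch num-epoch)))
    (range num-epoch)))

;; Made-up file names, for illustration only.
(def names ["model-d-0000.params" "model-d-0007.params"
            "model-d-symbol.json" "model-g-0007.params"])

(println (last-saved-epoch names))
(println (take 3 (epochs-to-run (last-saved-epoch names) 201)))
(println (count (epochs-to-run nil 201)))
```

Note that a resumed run trains for another num-epoch epochs on top of the last checkpoint rather than stopping at num-epoch in total.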
2.2. Explore
Copy in the RecordIO input file, created in the Appendix.
tar -zxf generator-checkpoints.tgz
find / -xdev -iname "libcudart*"
The explore function needs modification since we don't display images directly; here it saves them under /results/ via viz/im-sav instead.
(ns mxnet-gan-flan.explore
  (:require [org.apache.clojure-mxnet.io :as mx-io]
            [org.apache.clojure-mxnet.initializer :as init]
            [org.apache.clojure-mxnet.module :as m]
            [org.apache.clojure-mxnet.ndarray :as ndarray]
            [org.apache.clojure-mxnet.optimizer :as opt]
            [org.apache.clojure-mxnet.symbol :as sym]
            [org.apache.clojure-mxnet.shape :as mx-shape]
            [org.apache.clojure-mxnet.util :as util]
            [mxnet-gan-flan.viz :as viz]
            [org.apache.clojure-mxnet.context :as context]
            [think.image.pixel :as pixel]
            [mikera.image.core :as img])
  (:gen-class))

(def batch-size 9)
(def lr 0.0005) ;; learning rate
(def beta1 0.5)
(def exout-path "/results/")
(def model-path "/model/")

(def rand-noise-iter (mx-io/rand-iter [batch-size 100 1 1]))

(defn explore
  "Use this to explore your models that you have trained.
   Use the epoch that you wish to load and the number of pictures that you want to generate."
  ([epoch num] (explore (str model-path "model-g") epoch num))
  ([prefix epoch num]
   (let [mod-g (-> (m/load-checkpoint {:contexts [(context/default-context)]
                                       :data-names ["rand"]
                                       :label-names [""]
                                       :prefix prefix
                                       :epoch epoch})
                   (m/bind {:data-shapes (mx-io/provide-data rand-noise-iter)})
                   (m/init-params {:initializer (init/normal 0.02)})
                   (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))]
     (println "Generating images from " epoch)
     (dotimes [i num]
       (let [rbatch (mx-io/next rand-noise-iter)
             out-g (-> mod-g
                       (m/forward rbatch)
                       (m/outputs))]
         (viz/im-sav {:title (str "explore-" epoch "-" i)
                      :output-path exout-path
                      :x (ffirst out-g)
                      :flip false}))))))

(defn explore-pretrained
  "Use this to explore the pretrained model of flans.
   Specify the number of pictures that you want to generate."
  [num]
  (explore (str "pre-trained/" "model-g") 195 num))
#'mxnet-gan-flan.explore/explore-pretrained
(for [epoch (range 0 201 50)] (explore epoch 1))
List(5) (nil, nil, nil, nil, nil)
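The five nil values are the return values of five explore calls, one per sampled epoch. As a small sketch of which checkpoints that expression loads, assuming the training run saved one generator checkpoint per epoch for epochs 0 through 200 (num-epoch 201, starting with no earlier checkpoint):

```clojure
(def num-epoch 201)                          ; value used in the Train section
(def sampled-epochs (vec (range 0 201 50)))  ; same range as the for expression

(println sampled-epochs)
;; every sampled epoch must be one that the training run actually saved
(println (every? #(< % num-epoch) sampled-epochs))
```

With the defonce default of 200 epochs the last sampled epoch, 200, would have no checkpoint, which may be why the Train section sets num-epoch to 201.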