MXNet GANs with Flans

1. Setup

1.1. Functions

We'll put the functions here for later re-use via code imports. This cell can be run on its own to test that the code is valid.

From viz.clj.

viz.clj Setup (Clojure)
(ns mxnet-gan-flan.viz
  (:require [org.apache.clojure-mxnet.ndarray :as ndarray]
            [org.apache.clojure-mxnet.shape :as mx-shape]
            [org.apache.clojure-mxnet.io :as mx-io])
  (:import (nu.pattern OpenCV)
           (org.opencv.core Core CvType Mat Size)
           (org.opencv.imgproc Imgproc)
           (org.opencv.highgui Highgui)))

;;; Viz stuff
(OpenCV/loadShared)

(defn clip [x]
  (->> x
       (mapv #(* 255 %))
       (mapv #(cond
                (< % 0) 0
                (> % 255) 255
                :else (int %)))
       (mapv #(.byteValue %))))

(defn get-img [raw-data channels height width flip]
  (let [totals (* height width)
        img (if (> channels 1)
              ;; rgb image
              (let [[ra ga ba] (doall (partition totals raw-data))
                    rr (new Mat height width (CvType/CV_8U))
                    gg (new Mat height width (CvType/CV_8U))
                    bb (new Mat height width (CvType/CV_8U))
                    result (new Mat height width (CvType/CV_8U))]
                (do
                  (.put rr 0 0 (byte-array ra))
                  (.put gg 0 0 (byte-array ga))
                  (.put bb 0 0 (byte-array ba)))
                (Core/merge (java.util.ArrayList. [bb gg rr]) result)
                result)
              ;; gray image
              (let [result (new Mat height width (CvType/CV_8U))
                    _ (.put result (int 0) (int 0) (byte-array raw-data))]
                result))]
    (do
      (if flip
        (let [result (new Mat)
              _ (Core/flip img result (int 0))]
          result)
        img))))

(defn im-sav [{:keys [title output-path x flip]
               :or {flip false} :as g-mod}]
  (let [shape (mx-shape/->vec (ndarray/shape x))
        _ (assert (== 4 (count shape)))
        [n c h w] shape
        totals (* h w)
        raw-data (byte-array (clip (ndarray/to-array x)))
        row (.intValue (Math/sqrt n))
        col row
        line-arrs (into [] (partition (* col c totals) raw-data))
        line-mats (mapv (fn [line]
                          (let [img-arr (into [] (partition 
                                                  (* c totals) line))
                                col-mats (new Mat)
                                src (mapv (fn [arr] 
                                            (get-img (into [] arr) 
                                                     c h w flip)) img-arr)
                                _ (Core/hconcat 
                                   (java.util.ArrayList. src) col-mats)]
                           col-mats))
                        line-arrs)
        result (new Mat)
        resized-img (new Mat)
        _ (Core/vconcat (java.util.ArrayList. line-mats) result)]
    (do
      (Imgproc/resize result resized-img (new Size (* (.width result) 1.5) 
                                              (* (.height result) 1.5)))
      (Highgui/imwrite (str output-path title ".jpg") resized-img)
      (Thread/sleep 1000))))

(println "")

An input file is needed before the next functions can be defined successfully, so copy in the RecordIO input file created in the Appendix.

cp flan-128.rec /flan-28.rec
cp flan-128.rec /flan-128.rec

From gan.clj, up to the main function. The network design for either 28x28 or 128x128 images is selected automatically from img-size. The 128x128 network currently dies with cudaMalloc out-of-memory errors. Mitigations tried so far (see the sketch after this list):

  • calling ndarray/waitall after each iteration
  • setting the MXNet memory-pool environment variable to its maximum
  • using a P100 to avoid the CUDA 9.0/K80 issue
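A minimal sketch of the waitall idea, for reference. Here train-batch-with-sync and train-on-batch are hypothetical names standing in for the per-batch fn passed to mx-io/reduce-batches in the train function below; the only point is where the synchronization call sits.

(comment
  (defn train-batch-with-sync [n batch]
    (let [next-n (train-on-batch n batch)] ;; hypothetical per-batch training step
      (ndarray/waitall) ;; block until queued async GPU work finishes before the next batch allocates
      next-n)))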
gan.clj Setup (Clojure)
(ns mxnet-gan-flan.gan
  (:require [clojure.java.io :as io]
            [clojure.java.shell :refer [sh]]
            [org.apache.clojure-mxnet.executor :as executor]
            [org.apache.clojure-mxnet.eval-metric :as eval-metric]
            [org.apache.clojure-mxnet.io :as mx-io]
            [org.apache.clojure-mxnet.initializer :as init]
            [org.apache.clojure-mxnet.module :as m]
            [org.apache.clojure-mxnet.ndarray :as ndarray]
            [org.apache.clojure-mxnet.optimizer :as opt]
            [org.apache.clojure-mxnet.symbol :as sym]
            [org.apache.clojure-mxnet.shape :as mx-shape]
            [org.apache.clojure-mxnet.util :as util]
            [mxnet-gan-flan.viz :as viz]
            [org.apache.clojure-mxnet.context :as context]
            [think.image.pixel :as pixel]
            [mikera.image.core :as img])
  (:gen-class))

;; based off of https://medium.com/@julsimon/generative-adversarial-networks-on-apache-mxnet-part-1-b6d39e6b5df1

;; Use defonce so these defaults can be overridden elsewhere without being clobbered when this file is reloaded.
(defonce data-dir "/images/")
(defonce output-path "/out/")
(defonce model-path "/model/")
(defonce batch-size 100)
(defonce num-epoch 200)
(defonce img-size 128)

(io/make-parents (str output-path "gout"))
(io/make-parents (str model-path "test"))

(defn last-saved-model-number []
  (some->> model-path
           clojure.java.io/file
           file-seq
           (filter #(.isFile %))
           (map #(.getName %))
           (filter #(clojure.string/includes? % "model-d"))
           (map #(re-seq #"\d{4}" %))
           (map first)
           (map #(when % (Integer/parseInt %)))
           (sort)
           (last)))
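;; MXNet checkpoints are written as <prefix>-<epoch>.params with the epoch
;; zero-padded to four digits (e.g. model-d-0010.params), which is what the
;; \d{4} regex above relies on.
(comment
  (last-saved-model-number) ;; => 10 when model-d-0010.params is the newest checkpoint
  )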

(def flan-iter (mx-io/image-record-iter {:path-imgrec 
                                         (str "flan-" img-size ".rec")
                                         :data-shape [3 img-size img-size]
                                         :batch-size batch-size
                                         :shuffle true}))

(defn normalize-rgb [x]
  (/ (- x 128.0) 128.0))

(defn normalize-rgb-ndarray [nda]
  (let [nda-shape (ndarray/shape-vec nda)
        new-values (mapv #(normalize-rgb %) (ndarray/->vec nda))]
    (ndarray/array new-values nda-shape)))

(defn denormalize-rgb [x]
  (+ (* x 128.0) 128.0))
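;; Sanity check: normalize-rgb maps [0, 255] onto roughly [-1, 1] (the same
;; range as the generator's tanh output), and denormalize-rgb inverts it.
(comment
  (normalize-rgb 0)                       ;; => -1.0
  (normalize-rgb 255)                     ;; => 0.9921875
  (denormalize-rgb (normalize-rgb 200)))  ;; => 200.0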

(defn clip [x]
  (cond
    (< x 0) 0
    (> x 255) 255
    :else (int x)))

(defn postprocess-image [img]
  (let [datas (ndarray/->vec img)
        image-shape (mx-shape/->vec (ndarray/shape img))
        spatial-size (* (get image-shape 2) (get image-shape 3))
        pics (doall (partition (* 3 spatial-size) datas))
        pixels  (mapv
                 (fn [pic]
                   (let [[rs gs bs] (doall (partition spatial-size pic))
                         this-pixels (mapv (fn [r g b]
                                             (pixel/pack-pixel
                                               (int (clip 
                                                 (denormalize-rgb r)))
                                               (int (clip 
                                                 (denormalize-rgb g)))
                                               (int (clip 
                                                 (denormalize-rgb b)))
                                               (int 255)))
                                           rs gs bs)]
                     this-pixels))
                 pics)
        new-pixels (into [] (flatten pixels))
        new-image (img/new-image (* 1 (get image-shape 3)) 
                                 (* batch-size (get image-shape 2)))
        _  (img/set-pixels new-image (int-array new-pixels))]
    new-image))

(defn postprocess-write-img [img filename]
  (img/write (-> (postprocess-image img)
                 (img/zoom 1.5)) filename "png"))

(def rand-noise-iter (mx-io/rand-iter [batch-size 100 1 1]))

;; -- Start Layers --
;; Network definitions for 28x28 or 128x128 images, selected by img-size below.
;; Adapted from
;; https://github.com/apache/incubator-mxnet/blob/master/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj

(def ndf 28) ;; base number of filters; happens to match the 28x28 image height/width
(def nc 3) ;; number of channels
(def eps (float (+ 1e-5  1e-12)))
(def lr  0.0005) ;; learning rate
(def beta1 0.5)

(def label (sym/variable "label"))

(case img-size
  28 (do
    (defn discriminator []
      (as-> (sym/variable "data") data
        (sym/convolution "d1" {:data data :kernel [4 4] :pad [3 3] 
          :stride [2 2] :num-filter ndf :no-bias true})
        (sym/batch-norm "dbn1" {:data data :fix-gamma true :eps eps})
        (sym/leaky-re-lu "dact1" {:data data :act-type "leaky" :slope 0.2})

        (sym/convolution "d2" {:data data :kernel [4 4] :pad [1 1] 
          :stride [2 2] :num-filter (* 2 ndf) :no-bias true})
        (sym/batch-norm "dbn2" {:data data :fix-gamma true :eps eps})
        (sym/leaky-re-lu "dact1" {:data data :act_type "leaky" :slope 0.2})

        (sym/convolution "d3" {:data data :kernel [4 4] :pad [1 1] 
          :stride [2 2] :num-filter (* 3 ndf) :no-bias true})
        (sym/batch-norm "dbn3" {:data data :fix-gamma true :eps eps})
        (sym/leaky-re-lu "dact3" {:data data :act_type "leaky" :slope 0.2})

        (sym/convolution "d4" {:data data :kernel [4 4] :pad [0 0] 
          :stride [1 1] :num-filter (* 4 ndf) :no-bias true})
        (sym/flatten "flt" {:data data})

        (sym/fully-connected "fc" {:data data :num-hidden 1 :no-bias false})
        (sym/logistic-regression-output "dloss" {:data data :label label})))
    (defn generator []
      (as-> (sym/variable "rand") data
        (sym/deconvolution "g1" {:data data :kernel [4 4]  :pad [0 0] 
        :stride [1 1] :num-filter (* 4 ndf) :no-bias true})
        (sym/batch-norm "gbn1" {:data data :fix-gamma true :eps eps})
        (sym/activation "gact1" {:data data :act-type "relu"})

        (sym/deconvolution "g2" {:data data :kernel [4 4] :pad [1 1] 
        :stride [2 2] :num-filter (* 2 ndf) :no-bias true})
        (sym/batch-norm "gbn2" {:data data :fix-gamma true :eps eps})
        (sym/activation "gact2" {:data data :act-type "relu"})

        (sym/deconvolution "g3" {:data data :kernel [4 4] :pad [1 1] 
        :stride [2 2] :num-filter ndf :no-bias true})
        (sym/batch-norm "gbn3" {:data data :fix-gamma true :eps eps})
        (sym/activation "gact3" {:data data :act-type "relu"})

        (sym/deconvolution "g4" {:data data :kernel [4 4] :pad [3 3] 
        :stride [2 2] :num-filter nc :no-bias true})
        (sym/activation "gact4" {:data data :act-type "tanh"}))))
  128 (do
    (defn discriminator []
      (as-> (sym/variable "data") data
        (sym/convolution "d2" {:data data :kernel [4 4] :pad [3 3] 
          :stride [2 2] :num-filter (* 2 ndf) :no-bias true})
        (sym/batch-norm "dbn2" {:data data :fix-gamma true :eps eps})
        (sym/leaky-re-lu "dact1" {:data data :act_type "leaky" :slope 0.2})

        (sym/convolution "d3" {:data data :kernel [4 4] :pad [2 2] 
          :stride [2 2] :num-filter (* 3 ndf) :no-bias true})
        (sym/batch-norm "dbn3" {:data data :fix-gamma true :eps eps})
        (sym/leaky-re-lu "dact3" {:data data :act_type "leaky" :slope 0.2})

        (sym/convolution "d4" {:data data :kernel [4 4] :pad [0 0] 
          :stride [2 2] :num-filter (* 3 ndf) :no-bias true})
        (sym/batch-norm "dbn4" {:data data :fix-gamma true :eps eps})
        (sym/leaky-re-lu "dact4" {:data data :act_type "leaky" :slope 0.2})

        (sym/convolution "d5" {:data data :kernel [4 4] :pad [1 1] 
          :stride [2 2] :num-filter (* 3 ndf) :no-bias true})
        (sym/batch-norm "dbn5" {:data data :fix-gamma true :eps eps})
        (sym/leaky-re-lu "dact5" {:data data :act_type "leaky" :slope 0.2})

        (sym/convolution "d6" {:data data :kernel [4 4] :pad [1 1] 
          :stride [2 2] :num-filter (* 3 ndf) :no-bias true})
        (sym/batch-norm "dbn6" {:data data :fix-gamma true :eps eps})
        (sym/leaky-re-lu "dact6" {:data data :act_type "leaky" :slope 0.2})

        (sym/convolution "d7" {:data data :kernel [4 4] :pad [0 0] 
          :stride [1 1] :num-filter (* 4 ndf) :no-bias true})
        (sym/flatten "flt" {:data data})

        (sym/fully-connected "fc" {:data data :num-hidden 1 :no-bias false})
        (sym/logistic-regression-output "dloss" {:data data :label label})))
    (defn generator []
      (as-> (sym/variable "rand") data
        (sym/deconvolution "g1" {:data data :kernel [4 4]  :pad [0 0] 
          :stride [1 1] :num-filter (* 4 ndf) :no-bias true})
        (sym/batch-norm "gbn1" {:data data :fix-gamma true :eps eps})
        (sym/activation "gact1" {:data data :act-type "relu"})

        (sym/deconvolution "g2" {:data data :kernel [4 4] :pad [1 1] 
          :stride [2 2] :num-filter (* 2 ndf) :no-bias true})
        (sym/batch-norm "gbn2" {:data data :fix-gamma true :eps eps})
        (sym/activation "gact2" {:data data :act-type "relu"})

        (sym/deconvolution "g3" {:data data :kernel [4 4] :pad [1 1] 
          :stride [2 2] :num-filter ndf :no-bias true})
        (sym/batch-norm "gbn3" {:data data :fix-gamma true :eps eps})
        (sym/activation "gact3" {:data data :act-type "relu"})

        (sym/deconvolution "g4" {:data data :kernel [4 4] :pad [0 0] 
          :stride [2 2] :num-filter ndf :no-bias true})
        (sym/batch-norm "gbn4" {:data data :fix-gamma true :eps eps})
        (sym/activation "gact4" {:data data :act-type "relu"})

        (sym/deconvolution "g5" {:data data :kernel [4 4] :pad [2 2] 
          :stride [2 2] :num-filter ndf :no-bias true})
        (sym/batch-norm "gbn5" {:data data :fix-gamma true :eps eps})
        (sym/activation "gact5" {:data data :act-type "relu"})

        (sym/deconvolution "g7" {:data data :kernel [4 4] :pad [3 3] 
          :stride [2 2] :num-filter nc :no-bias true})
        (sym/activation "gact7" {:data data :act-type "tanh"}))))
  (throw (AssertionError. 
          (str "No networks defined for img-size " img-size "."))))

;; -- End Layers -- 

(defn save-img-gout [i n x]
  (do
    (viz/im-sav {:title (str "gout-" i "-" n)
                 :output-path output-path
                 :x x
                 :flip false})))

(defn save-img-diff [i n x]
  (do (viz/im-sav {:title (str "diff-" i "-" n)
                   :output-path output-path
                   :x x
                   :flip false})))

(defn save-img-data [i n batch]
  (do (viz/im-sav {:title (str "data-" i "-" n)
                   :output-path output-path
                   :x batch
                   :flip false})))

(defn calc-diff [i n diff-d]
  (let [diff (ndarray/copy diff-d)
        arr (ndarray/->vec diff)
        mean (/ (apply + arr) (count arr))
        std (let [tmp-a (map #(* (- % mean) (- % mean)) arr)]
              (float (Math/sqrt (/ (apply + tmp-a) (count tmp-a)))))]
    (let [calc-diff (ndarray/+ (ndarray/div (ndarray/- diff mean) std) 0.5)]
      (save-img-diff i n calc-diff))))

(defn train [devs]
  (let [last-train-num (last-saved-model-number)
        _ (println "The last saved trained epoch is " last-train-num)
        mod-d  (-> (if last-train-num
                     (do
                       (println 
                         "Loading discriminator from checkpoint of epoch " 
                         last-train-num)
                       (m/load-checkpoint {:contexts devs
                                           :data-names ["data"]
                                           :label-names ["label"]
                                           :prefix (str model-path "model-d")
                                           :epoch last-train-num
                                           :load-optimizer-states true}))
                     (m/module (discriminator) {:contexts devs 
                                                :data-names ["data"] 
                                                :label-names ["label"]}))
                   (m/bind {:data-shapes (mx-io/provide-data flan-iter)
                            :label-shapes (mx-io/provide-label flan-iter)
                            :inputs-need-grad true})
                   (m/init-params {:initializer (init/normal 0.02)})
                   (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr 
                                                            :wd 0.0 
                                                            :beta1 beta1})}))
        mod-g (-> (if last-train-num
                    (do
                     (println "Loading generator from checkpoint of epoch " 
                              last-train-num)
                     (m/load-checkpoint {:contexts devs
                                         :data-names ["rand"]
                                         :label-names [""]
                                         :prefix (str model-path "model-g")
                                         :epoch last-train-num
                                         :load-optimizer-states true}))
                    (m/module (generator) {:contexts devs 
                                           :data-names ["rand"] 
                                           :label-names nil}))
                  (m/bind {:data-shapes (mx-io/provide-data rand-noise-iter)})
                  (m/init-params {:initializer (init/normal 0.02)})
                  (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr 
                                                           :wd 0.0 
                                                           :beta1 beta1})}))]

    (println "Training for " num-epoch " epochs...")
    (doseq [i (if last-train-num
                (range (inc last-train-num) (inc (+ last-train-num 
                                                    num-epoch)))
                (range num-epoch))]
      (mx-io/reduce-batches flan-iter
                            (fn [n batch]
                              (let [rbatch (mx-io/next rand-noise-iter)
                                    dbatch (mapv normalize-rgb-ndarray 
                                                 (mx-io/batch-data batch))
                                    out-g (-> mod-g
                                              (m/forward rbatch)
                                              (m/outputs))
                                    ;; update the discriminator on the fake batch
                                    grads-f (mapv #(ndarray/copy (first %)) 
                                              (-> mod-d 
                                                (m/forward {
                                                  :data (first out-g) 
                                                  :label [(ndarray/zeros 
                                                         [batch-size])]})
                                                (m/backward)
                                                (m/grad-arrays)))
                                    ;; update the discriminator on the real batch
                                    grads-r (-> mod-d
                                                (m/forward {
                                                  :data dbatch 
                                                  :label [(ndarray/ones 
                                                         [batch-size])]})
                                                (m/backward)
                                                (m/grad-arrays))
                                    _ (mapv (fn [real fake] 
                                              (let [r (first real)]
                                                (ndarray/set 
                                                  r (ndarray/+ r fake)))) 
                                         grads-r grads-f)
                                    _ (m/update mod-d)
                                   ;; update the generator
                                    diff-d (-> mod-d
                                             (m/forward {
                                               :data (first out-g) 
                                               :label [(ndarray/ones 
                                                      [batch-size])]})
                                             (m/backward)
                                             (m/input-grads))
                                    _ (-> mod-g
                                          (m/backward (first diff-d))
                                          (m/update))]
                                (when (zero? n)
                                  (println "iteration = " i  "number = " n)
                                  (save-img-gout 
                                    i n (ndarray/copy (ffirst out-g)))
                                  (save-img-data i n (first dbatch))
                                  (calc-diff i n (ffirst diff-d))
                                  (m/save-checkpoint mod-g {
                                    :prefix (str model-path "model-g") 
                                    :epoch i :save-opt-states true})
                                  (m/save-checkpoint mod-d {
                                    :prefix (str model-path "model-d") 
                                    :epoch i :save-opt-states true}))
                                (inc n)))))))

(defn -main [& args]
  (let [[dev dev-num] args
        devs (if (= dev ":gpu")
               (mapv #(context/gpu %) (range (Integer/parseInt 
                                              (or dev-num "1"))))
               (mapv #(context/cpu %) (range (Integer/parseInt 
                                              (or dev-num "1")))))]
    (println "Running with context devices of" devs)
    (train devs)))
mxnet-gan-flan.gan/-main

2. Custard

2.1. Train

Copy in the input file.

This will run on a GPU.
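For reference, -main's first argument selects the device type and the second the number of devices (defaulting to one). A couple of hypothetical alternative invocations:

(comment
  (-main ":gpu" "2") ;; two GPU contexts: (context/gpu 0) and (context/gpu 1)
  (-main ":cpu"))    ;; a single CPU context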

(ns mxnet-gan-flan.gan)
(def data-dir "/images/")
(def output-path "/out/")
(def model-path "/model/")
(def batch-size 100)
(def num-epoch 201)
(def img-size 28)
  



(-main ":gpu")
tar -zcf /results/generator-checkpoints.tgz \
  model/model-g-*0.* model/model-g-symbol.json
generator-checkpoints.tgz

2.2. Explore

Copy in the generator checkpoints archive produced by the Train step and extract it. The find is a quick check that the CUDA runtime library is visible.

tar -zxf generator-checkpoints.tgz
find / -xdev -iname "libcudart*"

The explore function needs some modification, since we don't display images directly; instead it writes the generated images to the results directory.

Explore (Clojure)


(ns mxnet-gan-flan.explore
  (:require [org.apache.clojure-mxnet.io :as mx-io]
            [org.apache.clojure-mxnet.initializer :as init]
            [org.apache.clojure-mxnet.module :as m]
            [org.apache.clojure-mxnet.ndarray :as ndarray]
            [org.apache.clojure-mxnet.optimizer :as opt]
            [org.apache.clojure-mxnet.symbol :as sym]
            [org.apache.clojure-mxnet.shape :as mx-shape]
            [org.apache.clojure-mxnet.util :as util]
            [mxnet-gan-flan.viz :as viz]
            [org.apache.clojure-mxnet.context :as context]
            [think.image.pixel :as pixel]
            [mikera.image.core :as img])
  (:gen-class))

(def batch-size 9)
(def lr  0.0005) ;; learning rate
(def beta1 0.5)

(def exout-path "/results/")
(def model-path "/model/")

(def rand-noise-iter (mx-io/rand-iter [batch-size 100 1 1]))

(defn explore
  "Use this to explore your models that you have trained.
   Use the epoch that you wish to load and the number of 
   pictures that you want to generate."
  ([epoch num]
   (explore (str model-path "model-g") epoch num))
  ([prefix epoch num]

   (let [mod-g (-> (m/load-checkpoint {:contexts [(context/default-context)]
                                       :data-names ["rand"]
                                       :label-names [""]
                                       :prefix prefix
                                       :epoch epoch})
                   (m/bind {:data-shapes (mx-io/provide-data 
                                          rand-noise-iter)})
                   (m/init-params {:initializer (init/normal 0.02)})
                   (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr 
                                                            :wd 0.0 
                                                            :beta1 beta1})}))]
     
     (println "Generating images from " epoch)
     (dotimes [i num]
       (let [rbatch (mx-io/next rand-noise-iter)
             out-g (-> mod-g
                       (m/forward rbatch)
                       (m/outputs))]
         (viz/im-sav {:title (str "explore-" epoch "-" i)
                      :output-path exout-path
                      :x (ffirst out-g)
                      :flip false}))))))
  
(defn explore-pretrained
  "Use this to explore the pretrained model of flans.
   Specify the number of pictures that you want to generate."
  [num]
  (explore (str "pre-trained/" "model-g") 195 num))
mxnet-gan-flan.explore/explore-pretrained
(for [epoch (range 0 201 50)]
  (explore epoch 1))
List(5) (nil, nil, nil, nil, nil)