55from shutil import copyfile
66import logging
77from math import ceil
8+ import numpy as np
9+ import json
10+ import GPUtil
811from neuralstyle .utils import filename , fileext
912from neuralstyle .imagemagick import (convert , resize , shape , assertshape , choptiles , feather , smush , composite ,
1013 extractalpha , mergealpha )
2932 "-num_iterations" , "500"
3033 ]
3134 },
35+ "gatys-multiresolution" : {},
3236 "chen-schmidt" : {
3337 "folder" : "/app/style-swap" ,
3438 "command" : "th style-swap.lua" ,
4650 }
4751}
4852
53+ # Load file with GPU configuration
54+ with open ("gpuconfig.json" , "r" ) as f :
55+ GPUCONFIG = json .load (f )
56+
4957
5058def styletransfer (contents , styles , savefolder , size = None , alg = "gatys" , weights = None , stylescales = None ,
51- maxtilesize = 400 , tileoverlap = 100 , algparams = None ):
59+ tileoverlap = 100 , algparams = None ):
5260 """General style transfer routine over multiple sets of options"""
5361 # Check arguments
5462 if alg not in ALGORITHMS .keys ():
5563 raise ValueError ("Unrecognized algorithm %s, must be one of %s" % (alg , str (list (ALGORITHMS .keys ()))))
5664
5765 # Plug default options
58- if alg != "gatys" :
66+ if alg != "gatys" and alg != "gatys-multiresolution" :
5967 if weights is not None :
6068 LOGGER .warning ("Only gatys algorithm accepts style weights. Ignoring style weight parameters" )
6169 weights = [None ]
@@ -64,8 +72,6 @@ def styletransfer(contents, styles, savefolder, size=None, alg="gatys", weights=
6472 weights = [5.0 ]
6573 if stylescales is None :
6674 stylescales = [1.0 ]
67- if maxtilesize is None :
68- maxtilesize = 400
6975 if tileoverlap is None :
7076 tileoverlap = 100
7177 if algparams is None :
@@ -75,13 +81,13 @@ def styletransfer(contents, styles, savefolder, size=None, alg="gatys", weights=
7581 for content , style , weight , scale in product (contents , styles , weights , stylescales ):
7682 outfile = outname (savefolder , content , style , alg , scale , weight )
7783 # If the desired size is smaller than the maximum tile size, use a direct neural style
78- if fitsingletile (targetshape (content , size ), maxtilesize ):
84+ if fitsingletile (targetshape (content , size ), alg ):
7985 styletransfer_single (content = content , style = style , outfile = outfile , size = size , alg = alg , weight = weight ,
8086 stylescale = scale , algparams = algparams )
8187 # Else use a tiling strategy
8288 else :
83- neuraltile (content = content , style = style , outfile = outfile , size = size , maxtilesize = maxtilesize ,
84- overlap = tileoverlap , alg = alg , weight = weight , stylescale = scale , algparams = algparams )
89+ neuraltile (content = content , style = style , outfile = outfile , size = size , overlap = tileoverlap , alg = alg ,
90+ weight = weight , stylescale = scale , algparams = algparams )
8591
8692
8793def styletransfer_single (content , style , outfile , size = None , alg = "gatys" , weight = 5.0 , stylescale = 1.0 , algparams = None ):
@@ -101,6 +107,8 @@ def styletransfer_single(content, style, outfile, size=None, alg="gatys", weight
101107 algfile = workdir .name + "/" + "algoutput.png"
102108 if alg == "gatys" :
103109 gatys (rgbfile , stylepng , algfile , size , weight , stylescale , algparams )
110+ elif alg == "gatys-multiresolution" :
111+ gatys_multiresolution (rgbfile , stylepng , algfile , size , weight , stylescale , algparams )
104112 elif alg in ["chen-schmidt" , "chen-schmidt-inverse" ]:
105113 chenschmidt (alg , rgbfile , stylepng , algfile , size , stylescale , algparams )
106114 # Enforce correct size
@@ -111,8 +119,8 @@ def styletransfer_single(content, style, outfile, size=None, alg="gatys", weight
111119 mergealpha (algfile , alphafile , outfile )
112120
113121
114- def neuraltile (content , style , outfile , size = None , maxtilesize = 400 , overlap = 100 , alg = "gatys" , weight = 5.0 ,
115- stylescale = 1.0 , algparams = None ):
122+ def neuraltile (content , style , outfile , size = None , overlap = 100 , alg = "gatys" , weight = 5.0 , stylescale = 1 .0 ,
123+ algparams = None ):
116124 """Strategy to generate a high resolution image by running style transfer on overlapping image tiles"""
117125 LOGGER .info ("Starting tiling strategy" )
118126 if algparams is None :
@@ -123,7 +131,7 @@ def neuraltile(content, style, outfile, size=None, maxtilesize=400, overlap=100,
123131 fullshape = targetshape (content , size )
124132
125133 # Compute number of tiles required to map all the image
126- xtiles , ytiles = tilegeometry (fullshape , maxtilesize , overlap )
134+ xtiles , ytiles = tilegeometry (fullshape , alg , overlap )
127135
128136 # First scale image to target resolution
129137 firstpass = workdir .name + "/" + "lowres.png"
@@ -187,6 +195,69 @@ def gatys(content, style, outfile, size, weight, stylescale, algparams):
187195 tmpout .close ()
188196
189197
198+ def gatys_multiresolution (content , style , outfile , size , weight , stylescale , algparams , startres = 256 ):
199+ """Runs a multiresolution version of Gatys et al method
200+
201+ The multiresolution strategy starts by generating a small image, then using that image as initializer
202+ for higher resolution images. This procedure is repeated up to the tilesize.
203+
204+ Once the maximum tile size attainable by L-BFGS is reached, more iterations are run by using Adam. This allows
205+ to produce larger images using this method than the basic Gatys.
206+
207+ References:
208+ * Gatys et al - Controlling Perceptual Factors in Neural Style Transfer (https://arxiv.org/abs/1611.07865)
209+ * https://gist.github.com/jcjohnson/ca1f29057a187bc7721a3a8c418cc7db
210+ """
211+ # Multiresolution strategy: list of rounds, each round composed of a optimization method and a number of
212+ # upresolution steps.
213+ # Using "adam" as optimizer means that Adam will be used when necessary to attain higher resolutions
214+ strategy = [
215+ ["lbfgs" , 7 ],
216+ ["lbfgs" , 7 ],
217+ ["lbfgs" , 7 ],
218+ ["lbfgs" , 7 ],
219+ ["lbfgs" , 7 ]
220+ ]
221+ LOGGER .info ("Starting gatys-multiresolution with strategy " + str (strategy ))
222+
223+ # Initialization
224+ workdir = TemporaryDirectory ()
225+ maxres = targetshape (content , size )[0 ]
226+ if maxres < startres :
227+ LOGGER .warning ("Target resolution (%d) might too small for the multiresolution method to work well" % maxres )
228+ startres = maxres / 2.0
229+ seed = None
230+ tmpout = workdir .name + "/tmpout.png"
231+
232+ # Iterate over rounds
233+ for roundnumber , (optimizer , steps ) in enumerate (strategy ):
234+ LOGGER .info ("gatys-multiresolution round %d with %s optimizer and %d steps" % (roundnumber , optimizer , steps ))
235+ roundmax = min (maxtile ("gatys" ), maxres ) if optimizer == "lbfgs" else maxres
236+ resolutions = np .linspace (startres , roundmax , steps , dtype = int )
237+ iters = 1000
238+ for stepnumber , res in enumerate (resolutions ):
239+ stepopt = "adam" if res > maxtile ("gatys" ) else "lbfgs"
240+ LOGGER .info ("Step %d, resolution %d, optimizer %s" % (stepnumber , res , stepopt ))
241+ passparams = algparams [:]
242+ passparams .extend ([
243+ "-num_iterations" , iters ,
244+ "-tv_weight" , "0" ,
245+ "-print_iter" , "0" ,
246+ "-optimizer" , stepopt
247+ ])
248+ if seed is not None :
249+ passparams .extend ([
250+ "-init" , "image" ,
251+ "-init_image" , seed
252+ ])
253+ gatys (content , style , tmpout , res , weight , stylescale , passparams )
254+ seed = workdir .name + "/seed.png"
255+ copyfile (tmpout , seed )
256+ iters = max (iters / 2.0 , 100 )
257+
258+ convert (tmpout , outfile )
259+
260+
190261def chenschmidt (alg , content , style , outfile , size , stylescale , algparams ):
191262 """Runs Chen and Schmidt fast style-transfer algorithm
192263
@@ -250,16 +321,20 @@ def correctshape(result, original, size=None):
250321 assertshape (result , targetshape (original , size ))
251322
252323
253- def tilegeometry (imshape , maxtilesize = 400 , overlap = 50 ):
324+ def tilegeometry (imshape , alg , overlap = 50 ):
254325 """Given the shape of an image, computes the number of X and Y tiles to cover it"""
326+ maxtilesize = maxtile (alg )
255327 xtiles = ceil (float (imshape [0 ] - maxtilesize ) / float (maxtilesize - overlap ) + 1 )
256328 ytiles = ceil (float (imshape [1 ] - maxtilesize ) / float (maxtilesize - overlap ) + 1 )
257329 return xtiles , ytiles
258330
259331
260- def fitsingletile (imshape , maxtilesize ):
261- """Returns whether a given image shape will fit in a single tile or not"""
262- return all ([x <= maxtilesize for x in imshape ])
332+ def fitsingletile (imshape , alg ):
333+ """Returns whether a given image shape will fit in a single tile or not.
334+
335+ This depends on the algorithm used and the GPU available in the system"""
336+ mx = maxtile (alg )
337+ return mx * mx >= np .prod (imshape )
263338
264339
265340def targetshape (content , size = None ):
@@ -272,3 +347,24 @@ def targetshape(content, size=None):
272347 return contentshape
273348 else :
274349 return [size , int (size * contentshape [1 ] / contentshape [0 ])]
350+
351+
352+ def gpuname ():
353+ """Returns the model name of the first available GPU"""
354+ gpus = GPUtil .getGPUs ()
355+ if len (gpus ) == 0 :
356+ raise ValueError ("No GPUs detected in the system" )
357+ return gpus [0 ].name
358+
359+
360+ def maxtile (alg = "gatys" ):
361+ """Returns the recommended configuration maximum tile size, based on the available GPU and algorithm to be run
362+
363+ The size returned should be understood as the maximum tile size for a square tile. If non-square tiles are used,
364+ a maximum tile of the same number of pixels should be used.
365+ """
366+ gname = gpuname ()
367+ if gname not in GPUCONFIG :
368+ LOGGER .warning ("Unknown GPU model %s, will use default tiling parameters" )
369+ gname = "default"
370+ return GPUCONFIG [gname ][alg ]
0 commit comments