@@ -300,6 +300,7 @@ def run(
300300 fittingplotname = None ,
301301 interactplotname = None ,
302302 estname = None ,
303+ plot_style = "WTP" ,
303304 ):
304305 """Run the estimation.
305306
@@ -348,11 +349,13 @@ def run(
348349 If ``None``, it will be the current time +
349350 ``"_estimate"``.
350351 Default: ``None``
352+ plot_style : str, optional
353+ Plot stlye. The default is "WTP".
351354 """
352355 if self .setup .dummy :
353356 raise ValueError (
354357 "Estimate: for parameter estimation"
355- + " you can't use a dummy paramter."
358+ " you can't use a dummy paramter."
356359 )
357360 act_time = timemodule .strftime ("%Y-%m-%d_%H-%M-%S" )
358361
@@ -405,42 +408,45 @@ def run(
405408 else :
406409 rank = 0
407410
411+ # initialize the sampler
412+ sampler = spotpy .algorithms .sceua (
413+ self .setup ,
414+ dbname = dbname ,
415+ dbformat = "csv" ,
416+ parallel = parallel ,
417+ save_sim = True ,
418+ db_precision = np .float64 ,
419+ )
420+ # start the estimation with the sce-ua algorithm
408421 if run :
409- # initialize the sampler
410- sampler = spotpy .algorithms .sceua (
411- self .setup ,
412- dbname = dbname ,
413- dbformat = "csv" ,
414- parallel = parallel ,
415- save_sim = True ,
416- db_precision = np .float64 ,
417- )
418- # start the estimation with the sce-ua algorithm
419422 sampler .sample (rep , ngs = 10 , kstop = 100 , pcento = 1e-4 , peps = 1e-3 )
420423
421- if rank == 0 :
422- # save best parameter-set
424+ if rank == 0 :
425+ if run :
423426 self .result = sampler .getdata ()
424- para_opt = spotpy .analyser .get_best_parameterset (
425- self .result , maximize = False
427+ else :
428+ self .result = np .genfromtxt (
429+ dbname + ".csv" , delimiter = "," , names = True
426430 )
427- void_names = para_opt .dtype .names
428- para = []
429- header = []
430- for name in void_names :
431- para .append (para_opt [0 ][name ])
432- header .append (name [3 :])
433- self .estimated_para [header [- 1 ]] = para [- 1 ]
434- np .savetxt (paraname , para , header = " " .join (header ))
435-
436- if rank == 0 :
431+ para_opt = spotpy .analyser .get_best_parameterset (
432+ self .result , maximize = False
433+ )
434+ void_names = para_opt .dtype .names
435+ para = []
436+ header = []
437+ for name in void_names :
438+ para .append (para_opt [0 ][name ])
439+ header .append (name [3 :])
440+ self .estimated_para [header [- 1 ]] = para [- 1 ]
441+ np .savetxt (paraname , para , header = " " .join (header ))
437442 # plot the estimation-results
438443 plotter .plotparatrace (
439444 result = self .result ,
440445 parameternames = paranames ,
441446 parameterlabels = paralabels ,
442447 stdvalues = self .estimated_para ,
443448 plotname = traceplotname ,
449+ style = plot_style ,
444450 )
445451 plotter .plotfit_steady (
446452 setup = self .setup ,
@@ -450,8 +456,11 @@ def run(
450456 radnames = self .radnames ,
451457 extra = self .extra_kw_names ,
452458 plotname = fittingplotname ,
459+ style = plot_style ,
460+ )
461+ plotter .plotparainteract (
462+ self .result , paralabels , interactplotname , style = plot_style
453463 )
454- plotter .plotparainteract (self .result , paralabels , interactplotname )
455464
456465 def sensitivity (
457466 self ,
@@ -462,6 +471,7 @@ def sensitivity(
462471 plotname = None ,
463472 traceplotname = None ,
464473 sensname = None ,
474+ plot_style = "WTP" ,
465475 ):
466476 """Run the sensitivity analysis.
467477
@@ -501,11 +511,13 @@ def sensitivity(
501511 If ``None``, it will be the current time +
502512 ``"_estimate"``.
503513 Default: ``None``
514+ plot_style : str, optional
515+ Plot stlye. The default is "WTP".
504516 """
505517 if len (self .setup .para_names ) == 1 and not self .setup .dummy :
506518 raise ValueError (
507519 "Sensitivity: for estimation with only one parameter"
508- + " you have to use a dummy paramter."
520+ " you have to use a dummy paramter."
509521 )
510522 if rep is None :
511523 rep = spotpylib .fast_rep (
@@ -592,11 +604,14 @@ def sensitivity(
592604 header = " " .join (paranames )
593605 np .savetxt (sensname , sens_est ["ST" ], header = header )
594606 np .savetxt (sensname1 , sens_est ["S1" ], header = header )
595- plotter .plotsensitivity (paralabels , sens_est , plotname )
607+ plotter .plotsensitivity (
608+ paralabels , sens_est , plotname , style = plot_style
609+ )
596610 plotter .plotparatrace (
597611 data ,
598612 parameternames = paranames ,
599613 parameterlabels = paralabels ,
600614 stdvalues = None ,
601615 plotname = traceplotname ,
616+ style = plot_style ,
602617 )
0 commit comments