1 from .metrics
import slope, sspb, mdsa, rmsle
2 from .meta
import get_sensor_label
3 from .utils
import closest_wavelength, ignore_warnings
4 from collections
import defaultdict
as dd
5 from pathlib
import Path
11 Add 1 to 1 diagonal line to a plot.
12 https://stackoverflow.com/questions/22104256/does-matplotlib-have-a-function-for-drawing-diagonal-lines-in-axis-coordinates
14 Usage: add_identity(plt.gca(), color='k', ls='--')
16 line_kwargs[
'label'] = line_kwargs.get(
'label',
'_nolegend_')
17 identity, = ax.plot([], [], *line_args, **line_kwargs)
20 low_x, high_x = ax.get_xlim()
21 low_y, high_y = ax.get_ylim()
22 lo = max(low_x, low_y)
23 hi = min(high_x, high_y)
24 identity.set_data([lo, hi], [lo, hi])
27 ax.callbacks.connect(
'xlim_changed', callback)
28 ax.callbacks.connect(
'ylim_changed', callback)
31 'transform' : ax.transAxes,
32 'textcoords' :
'offset points',
33 'xycoords' :
'axes fraction',
34 'fontname' :
'monospace',
40 ax.annotate(
r'$\mathbf{1:1}$', xy=(0.87,0.99), size=11, **ann_kwargs)
44 def _create_metric(metric, y_true, y_est, longest=None, label=None):
45 ''' Create a position-aligned string which shows the performance via a single metric '''
46 if label ==
None: label = metric.__name__.replace(
'SSPB',
'\\beta').replace(
'MdSA',
'\\varepsilon\\thinspace')
48 if longest ==
None: longest = len(label)
50 ispct = metric.__qualname__
in [
'mape',
'sspb',
'mdsa']
51 diff = longest-len(label.replace(
'^',
''))
52 space =
r''.join([
r'\ ']*diff + [
r'\thinspace']*diff)
53 prec = (1
if abs(metric(y_true, y_est)) < 100
and metric.__name__
not in [
'N']
else 0)
if ispct
or metric.__name__
in [
'N']
else 3
55 stat = f
'{metric(y_true, y_est):.{prec}f}'
56 perc =
r'$\small{\mathsf{\%}}$' if ispct
else ''
57 return rf
'$\mathtt{{{label}}}{space}:$ {stat}{perc}'
59 def _create_stats(y_true, y_est, metrics, title=None):
60 ''' Create stat box strings for all metrics, assuming there is only a single target feature '''
61 longest = max([len(metric.__name__.replace(
'SSPB',
'Bias').replace(
'MdSA',
'Error').replace(
'^',
''))
for metric
in metrics])
62 statbox = [_create_metric(m, y_true, y_est, longest=longest)
for m
in metrics]
65 statbox = [rf
'$\mathbf{{\underline{{{title}}}}}$'] + statbox
68 def _create_multi_feature_stats(y_true, y_est, metrics, labels=None):
69 ''' Create stat box strings for a single metric, assuming there are multiple target features '''
71 labels = [f
'Feature {i}' for i
in range(y_true.shape[1])]
72 assert(len(labels) == y_true.shape[1] == y_est.shape[1]), f
'Number of labels does not match number of features: {labels} - {y_true.shape}'
74 title = metrics[0].__name__.replace(
'SSPB',
'Bias').replace(
'MdSA',
'Error')
75 longest = max([len(label.replace(
'^',
''))
for label
in labels])
76 statbox = [_create_metric(metrics[0], y1, y2, longest=longest, label=lbl)
for y1, y2, lbl
in zip(y_true.T, y_est.T, labels)]
77 statbox = [rf
'$\mathbf{{\underline{{{title}}}}}$'] + statbox
80 def add_stats_box(ax, y_true, y_est, metrics=[mdsa, sspb, slope], bottom_right=False, bottom=False, right=False, x=0.025, y=0.97, fontsize=16, label=None):
81 ''' Add a text box containing a variety of performance statistics, to the given axis '''
82 import matplotlib.pyplot
as plt
83 plt.rc(
'text', usetex=
True)
84 plt.rcParams[
'mathtext.default']=
'regular'
86 create_box = _create_stats
if len(y_true.shape) == 1
or y_true.shape[1] == 1
else _create_multi_feature_stats
87 stats_box =
'\n'.join( create_box(y_true, y_est, metrics, label) )
89 'transform' : ax.transAxes,
90 'textcoords' :
'offset points',
91 'xycoords' :
'axes fraction',
92 'fontname' :
'monospace',
98 'facecolor' :
'white',
99 'edgecolor' :
'black',
104 ann = ax.annotate(stats_box, xy=(x,y), size=fontsize, **ann_kwargs)
106 bottom |= bottom_right
107 right |= bottom_right
110 if bottom
or right
or bottom_right:
111 plt.gcf().canvas.draw()
112 bbox_orig = ann.get_tightbbox(plt.gcf().canvas.renderer).transformed(ax.transAxes.inverted())
117 new_y = bbox_orig.y1 - bbox_orig.y0 + (1 - y)
121 new_x = 1 - (bbox_orig.x1 - bbox_orig.x0) + x
124 ann.xy = (new_x, new_y)
128 def draw_map(*lonlats, scale=0.2, world=False, us=True, eu=False, labels=[], ax=None, gray=False, res='i', **scatter_kws):
129 ''' Helper function to plot locations on a global map '''
130 import matplotlib.pyplot
as plt
131 from matplotlib.transforms
import Bbox
132 from mpl_toolkits.axes_grid1.inset_locator
import TransformedBbox, BboxPatch, BboxConnector
133 from mpl_toolkits.axes_grid1.inset_locator
import zoomed_inset_axes, inset_axes
134 from mpl_toolkits.basemap
import Basemap
135 from itertools
import chain
140 WORLD_MAP = {
'cyl': [-90, 85, -180, 180]}
142 'cyl' : [24, 49, -126, -65],
143 'lcc' : [23, 48, -121, -64],
146 'cyl' : [34, 65, -12, 40],
147 'lcc' : [30.5, 64, -10, 40],
150 def mark_inset(ax, ax2, m, m2, MAP, loc1=(1, 2), loc2=(3, 4), **kwargs):
152 https://stackoverflow.com/questions/41610834/basemap-projection-geos-controlling-mark-inset-location
153 Patched mark_inset to work with Basemap.
154 Reason: Basemap converts Geographic (lon/lat) to Map Projection (x/y) coordinates
156 Additionally: set connector locations separately for both axes:
157 loc1 & loc2: tuple defining start and end-locations of connector 1 & 2
159 axzoom_geoLims = (MAP[
'cyl'][2:], MAP[
'cyl'][:2])
160 rect = TransformedBbox(Bbox(np.array(m(*axzoom_geoLims)).T), ax.transData)
161 pp = BboxPatch(rect, fill=
False, **kwargs)
163 p1 = BboxConnector(ax2.bbox, rect, loc1=loc1[0], loc2=loc1[1], **kwargs)
165 p1.set_clip_on(
False)
166 p2 = BboxConnector(ax2.bbox, rect, loc1=loc2[0], loc2=loc2[1], **kwargs)
168 p2.set_clip_on(
False)
174 kwargs = {
'projection':
'cyl',
'resolution': res}
177 kwargs = {
'projection':
'lcc',
'lat_0':30,
'lon_0':-98,
'resolution': res}
180 kwargs = {
'projection':
'lcc',
'lat_0':48,
'lon_0':27,
'resolution': res}
182 raise Exception(
'Must plot world, US, or EU')
184 kwargs.update(dict(zip([
'llcrnrlat',
'urcrnrlat',
'llcrnrlon',
'urcrnrlon'], MAP[
'lcc' if 'lcc' in MAP
else 'cyl'])))
185 if ax
is None: f = plt.figure(figsize=(PLOT_WIDTH, PLOT_HEIGHT), edgecolor=
'w')
186 m = Basemap(ax=ax, **kwargs)
187 ax = m.ax
if m.ax
is not None else plt.gca()
190 m.readshapefile(Path(__file__).parent.joinpath(
'map_files',
'st99_d00').as_posix(), name=
'states', drawbounds=
True, color=
'k', linewidth=0.5, zorder=11)
191 m.fillcontinents(color=(0,0,0,0), lake_color=
'#9abee0', zorder=9)
193 m.drawrivers(linewidth=0.2, color=
'blue', zorder=9)
194 m.drawcountries(color=
'k', linewidth=0.5)
196 m.drawcountries(color=
'w')
199 if us
or eu: m.shadedrelief(scale=0.3
if world
else 1)
202 m.arcgisimage(service=
'World_Imagery', xpixels = 2000, verbose=
True)
216 colors = [
'aqua',
'orangered',
'xkcd:tangerine',
'xkcd:fresh green',
'xkcd:clay',
'magenta',
'xkcd:sky blue',
'xkcd:greyish blue',
'xkcd:goldenrod', ]
217 markers = [
'o',
'^',
's',
'*',
'v',
'X',
'.',
'x',]
219 assert(len(labels) == len(lonlats)), [len(labels), len(lonlats)]
220 for i, (label, lonlat)
in enumerate(zip(labels, lonlats)):
221 lonlat = np.atleast_2d(lonlat)
222 if 'color' not in scatter_kws
or mod_cr:
223 scatter_kws[
'color'] = colors[i]
224 scatter_kws[
'marker'] = markers[i]
226 ax.scatter(*m(lonlat[:,0], lonlat[:,1]), label=label, zorder=12, **scatter_kws)
227 ax.legend(loc=
'lower left', prop={
'weight':
'bold',
'size':8}).set_zorder(20)
230 for lonlat
in lonlats:
232 lonlat = np.atleast_2d(lonlat)
233 s = ax.scatter(*m(lonlat[:,0], lonlat[:,1]), zorder=12, **scatter_kws)
235 hide_kwargs = {
'axis':
'both',
'which':
'both'}
236 hide_kwargs.update(dict([(k,
False)
for k
in [
'bottom',
'top',
'left',
'right',
'labelleft',
'labelbottom']]))
237 ax.tick_params(**hide_kwargs)
239 for axis
in [
'top',
'bottom',
'left',
'right']:
240 ax.spines[axis].set_linewidth(1.5)
241 ax.spines[axis].set_zorder(50)
247 loc = (0.25, -0.1)
if eu
else (0.35, -0.01)
248 ax_ins = inset_axes(ax, width=PLOT_WIDTH*size, height=PLOT_HEIGHT*size, loc=
'center', bbox_to_anchor=loc, bbox_transform=ax.transAxes, axes_kwargs={
'zorder': 5})
250 scatter_kws.update({
's': 6})
251 m2 =
draw_map(*lonlats, labels=labels, ax=ax_ins, **scatter_kws)
253 mark_inset(ax, ax_ins, m, m2, US_MAP, loc1=(1,1), loc2=(2,2), edgecolor=
'grey', zorder=3)
254 mark_inset(ax, ax_ins, m, m2, US_MAP, loc1=[3,3], loc2=[4,4], edgecolor=
'grey', zorder=0)
258 ax_ins = inset_axes(ax, width=PLOT_WIDTH*size, height=PLOT_HEIGHT*size, loc=
'center', bbox_to_anchor=(0.75, -0.05), bbox_transform=ax.transAxes, axes_kwargs={
'zorder': 5})
260 scatter_kws.update({
's': 6})
261 m2 =
draw_map(*lonlats, us=
False, eu=
True, labels=labels, ax=ax_ins, **scatter_kws)
263 mark_inset(ax, ax_ins, m, m2, EU_MAP, loc1=(1,1), loc2=(2,2), edgecolor=
'grey', zorder=3)
264 mark_inset(ax, ax_ins, m, m2, EU_MAP, loc1=[3,3], loc2=[4,4], edgecolor=
'grey', zorder=0)
270 ''' Helper function to allow defaultdicts whose default value returned is the queried key '''
273 ''' DefaultDict which allows the key as the default value '''
274 def __missing__(self, key):
275 if self.default_factory
is None:
277 val = self[key] = self.default_factory(key)
283 def plot_scatter(y_test, benchmarks, bands, labels, products, sensor, title=None, methods=None, n_col=3, img_outlbl=''):
284 import matplotlib.patheffects
as pe
285 import matplotlib.ticker
as ticker
286 import matplotlib.pyplot
as plt
287 import seaborn
as sns
289 folder = Path(
'scatter_plots')
290 folder.mkdir(exist_ok=
True, parents=
True)
293 'chl' :
'Chl\\textit{a}',
295 'aph' :
'\\textit{a}_{ph}',
297 'cdom':
'\\textit{a}_{CDOM}',
301 'chl' :
'[mg m^{-3}]',
302 'pc' :
'[mg m^{-3}]',
303 'tss' :
'[g m^{-3}]',
312 products = [p
for i,p
in enumerate(np.atleast_1d(products))
if i < y_test.shape[-1]]
314 plt.rc(
'text', usetex=
True)
315 plt.rcParams[
'mathtext.default']=
'regular'
320 if len(labels) > 3
and 'chl' not in products:
322 'default' : [443, 482, 561, 655],
326 target = [
closest_wavelength(w, bands)
for w
in product_bands.get(products[0], product_bands[
'default'])]
327 plot_label = [w
in target
for w
in bands]
328 plot_order = [
'MDN',
'QAA',
'GIOP']
331 plot_label = [
True] * len(labels)
335 if plot_order
is None:
336 if 'chl' in products
and len(products) == 1:
337 benchmarks = benchmarks[
'chl']
338 if 'MLP' in benchmarks:
339 plot_order = [
'MDN',
'MLP',
'SVM',
'XGB',
'KNN',
'OC3']
341 plot_order = [
'MDN',
'Smith_Blend',
'OC6',
'Mishra_NDCI',
'Gons_2band',
'Gilerson_2band']
342 elif len(products) > 1
and any(k
in products
for k
in [
'chl',
'tss',
'cdom']):
343 plot_order = {k:v
for k,v
in {
344 'chl' : [
'MDN',
'Gilerson_2band'],
345 'tss' : [
'MDN',
'SOLID'],
346 'cdom' : [
'MDN',
'Ficek'],
347 }.items()
if k
in products}
348 plot_label = [
True] * len(plot_order)
350 n_col = len(plot_order)
351 plot_order = {p: [
'MDN']
for p
in products}
352 plot_label = [
True] * 4
353 labels = [(p,label)
for label
in labels
for p
in products
if p
in label]
354 print(
'Plotting labels:', [l
for i,l
in enumerate(labels)
if plot_label[i]])
355 assert(len(labels) == y_test.shape[-1]), [len(labels), y_test.shape]
359 n_col = max(n_col, sum(plot_label))
365 fig, axes = plt.subplots(n_row, n_col, figsize=(fig_size*n_col, fig_size*n_row+1))
366 axes = [ax
for axs
in np.atleast_1d(axes).T
for ax
in np.atleast_1d(axs)]
367 colors = [
'xkcd:sky blue',
'xkcd:tangerine',
'xkcd:fresh green',
'xkcd:greyish blue',
'xkcd:goldenrod',
'xkcd:clay',
'xkcd:bluish purple',
'xkcd:reddish']
369 print(
'Order:', plot_order)
370 print(f
'Plot size: {n_row} x {n_col}')
374 full_ax = fig.add_subplot(111, frameon=
False)
375 full_ax.tick_params(labelcolor=
'none', top=
False, bottom=
False, left=
False, right=
False, pad=10)
379 estimate_label =
'Measured'
381 y_pre = estimate_label.replace(
'-',
'\\textbf{-}')
382 plabel = f
'{product_labels[products[0]]} {product_units[products[0]]}'
383 xlabel = fr
'$\mathbf{{{x_pre} {plabel}}}$'
384 ylabel = fr
'$\mathbf{{{y_pre}}}$'+
'' +fr
'$\mathbf{{ {plabel}}}$'
385 if not isinstance(plot_order, dict):
386 full_ax.set_xlabel(xlabel.replace(
' ',
'\ '), fontsize=24, labelpad=10)
387 full_ax.set_ylabel(ylabel.replace(
' ',
'\ '), fontsize=24, labelpad=10)
389 full_ax.set_ylabel(fr
'$\mathbf{{{x_pre}}}$'.replace(
' ',
'\ '), fontsize=24, labelpad=15)
393 title = fr
'$\mathbf{{\underline{{\large{{{s_lbl}}}}}}}$' +
'\n' + fr
'$\small{{\mathit{{N\small{{=}}}}{n_pts}}}$'
399 for plt_idx, (label, y_true)
in enumerate(zip(labels, y_test.T)):
400 if not plot_label[plt_idx]:
continue
402 product, title = label
403 plabel = f
'{product_labels[product]} {product_units[product]}'
405 for est_idx, est_lbl
in enumerate(plot_order[product]
if isinstance(plot_order, dict)
else plot_order):
407 if isinstance(plot_order, dict)
and est_lbl
not in benchmarks[product]:
408 axes[curr_idx].tick_params(labelcolor=
'none', top=
False, bottom=
False, left=
False, right=
False)
409 axes[curr_idx].axis(
'off')
413 y_est = benchmarks[product][est_lbl]
if isinstance(plot_order, dict)
else benchmarks[est_lbl][..., plt_idx]
415 cidx =
int(curr_idx / n_row)
419 first_row = (curr_idx % n_row) == 0
421 last_row = ((curr_idx+1) % n_row) == 0
422 first_col = (curr_idx % n_col) == 0
423 last_col = ((curr_idx+1) % n_col) == 0
424 print(curr_idx, first_row, last_row, first_col, last_col, est_lbl, product, plabel)
425 y_est_log = np.log10(y_est).flatten()
426 y_true_log = np.log10(y_true).flatten()
429 l_kws = {
'color': color,
'path_effects': [pe.Stroke(linewidth=4, foreground=
'k'), pe.Normal()],
'zorder': 22,
'lw': 1}
430 s_kws = {
'alpha': 0.4,
'color': color}
433 [i.set_linewidth(5)
for i
in ax.spines.values()]
437 est_lbl = est_lbl.replace(
'Mishra_',
'').replace(
'Gons_2band',
'Gons').replace(
'Gilerson_2band',
'GI2B').replace(
'Smith_',
'').replace(
'Cao_XGB',
'BST')
438 est_lbl = est_lbl.replace(
'QAA_CDOM',
'QAA\ CDOM')
439 if product
not in [
'chl',
'tss',
'cdom',
'pc']
and last_col:
441 ax2.tick_params(labelcolor=
'none', top=
False, bottom=
False, left=
False, right=
False, pad=0)
443 ax2.set_yticklabels([])
444 ax2.set_ylabel(fr
'$\mathbf{{{bands[plt_idx]:.0f}nm}}$', fontsize=22)
448 loc = ticker.LinearLocator(numticks=
int(round(maxv-minv+1.5)))
449 fmt = ticker.FuncFormatter(
lambda i, _:
r'$10$\textsuperscript{%i}'%i)
451 ax.set_ylim((minv, maxv))
452 ax.set_xlim((minv, maxv))
453 ax.xaxis.set_major_locator(loc)
454 ax.yaxis.set_major_locator(loc)
455 ax.xaxis.set_major_formatter(fmt)
456 ax.yaxis.set_major_formatter(fmt)
459 if not last_row: ax.set_xticklabels([])
462 elif isinstance(plot_order, dict):
463 ylabel = fr
'$\mathbf{{{y_pre}}}$'+
'' +fr
'$\mathbf{{ {plabel}}}$' +
'\n' + fr
'$\small{{\mathit{{N\small{{=}}}}{np.isfinite(y_true_log).sum()}}}$'
464 ax.set_xlabel(ylabel.replace(
' ',
'\ '), fontsize=fs)
466 valid = np.logical_and(np.isfinite(y_true_log), np.isfinite(y_est_log))
468 sns.regplot(y_true_log[valid], y_est_log[valid], ax=ax, scatter_kws=s_kws, line_kws=l_kws, fit_reg=
True, truncate=
False, robust=
True, ci=
None)
469 kde = sns.kdeplot(y_true_log[valid], y_est_log[valid], shade=
False, ax=ax, bw=
'scott', n_levels=4, legend=
False, gridsize=100, color=
'#555')
472 invalid = np.logical_and(np.isfinite(y_true_log), ~np.isfinite(y_est_log))
474 ax.scatter(y_true_log[invalid], [minv]*(invalid).sum(), color=
'r', alpha=0.4, label=
r'$\mathbf{%s\ invalid}$' % (invalid).sum())
475 ax.legend(loc=
'lower right', prop={
'weight':
'bold',
'size': 18})
480 add_stats_box(ax, y_true[valid], y_est[valid], metrics=[mdsa, sspb, slope], fontsize=18)
483 if first_row
or not plot_bands
or (isinstance(plot_order, dict)):
486 ax.set_title(
r'$\small{\textit{(Cao\ et\ al.\ 2020)}}$' +
'\n' + fr
'$\mathbf{{\large{{{est_lbl}}}}}$', fontsize=fs, linespacing=0.95)
488 elif est_lbl ==
'Ficek':
490 ax.set_title(fr
'$\mathbf{{\large{{{est_lbl}}}}}$' +
r'$\small{\textit{\ (et\ al.\ 2011)}}$', fontsize=fs, linespacing=0.95)
492 elif est_lbl ==
'Mannino':
494 ax.set_title(fr
'$\mathbf{{\large{{{est_lbl}}}}}$' +
r'$\small{\textit{\ (et\ al.\ 2008)}}$', fontsize=fs, linespacing=0.95)
496 elif est_lbl ==
'Novoa':
498 ax.set_title(fr
'$\mathbf{{\large{{{est_lbl}}}}}$' +
r'$\small{\textit{\ (et\ al.\ 2017)}}$', fontsize=fs, linespacing=0.95)
500 elif est_lbl ==
'GI2B':
501 ax.set_title(fr
'$\mathbf{{\large{{Gilerson}}}}$' +
r'$\small{\textit{\ (et\ al.\ 2010)}}$', fontsize=fs, linespacing=0.95)
503 elif est_lbl ==
'MDN': ax.set_title(fr
'$\mathbf{{{est_lbl}\ {product_labels[product]}}}$', fontsize=fs)
504 else: ax.set_title(fr
'$\mathbf{{\large{{{est_lbl}}}}}$', fontsize=fs)
506 ax.tick_params(labelsize=fs)
507 ax.grid(
'on', alpha=0.3)
509 u_label =
",".join([o.split(
'_')[0]
for o
in plot_order])
if len(plot_order) < 10
else f
'{n_row}x{n_col}'
510 filename = folder.joinpath(f
'{img_outlbl}{",".join(products)}_{sensor}_{n_pts}test_{u_label}.png')
513 plt.savefig(filename.as_posix(), dpi=100, bbox_inches=
'tight', pad_inches=0.1,)