Ocean Color Science Software (ocssw V2022)
TrainingPlot.py
from ..utils import get_labels, line_messages, ignore_warnings
from ..meta import get_sensor_bands
from ..metrics import rmse, rmsle, mape, mae, leqznan, sspb, mdsa, performance
from ..plot_utils import add_identity

from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec as GridSubplot
from matplotlib.patches import Ellipse
from collections import defaultdict as dd
from sklearn import preprocessing
from pathlib import Path

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import os


class TrainingPlot:
    def __init__(self, args, model, data):
        self.args  = args
        self.model = model
        self.data  = data

        # Sample a limited number of training samples to plot
        n_samples = len(self.data['train']['x'])
        self._idx  = np.random.choice(range(n_samples), min(n_samples, 10000), replace=False)

    def setup(self):
        self.train_test   = np.append(self.data['train']['x_t'][self._idx], self.data['test']['x_t'], 0)
        self.train_losses = dd(list)
        self.test_losses  = dd(list)
        self.model_losses = []

        # Location of 0/-1 in the transformed space
        self.zero_line = self.model.scalery.inverse_transform(np.zeros((1, self.data['train']['y_t'].shape[-1])))
        self.neg_line  = self.model.scalery.inverse_transform(np.zeros((1, self.data['train']['y_t'].shape[-1])) - 1)

        if self.args.darktheme:
            plt.style.use('dark_background')
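
        # Figure layout: the first row and the last two rows each hold two wide
        # panels (training metrics / network loss, and the probability / uncertainty
        # plots), while the remaining rows hold n_col panels for the per-target
        # 1-1 scatter plots.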
        n_ext = 3  # extra rows, in addition to 1-1 scatter plots
        n_col = min(5, self.data['test']['y'].shape[1])
        n_row = n_ext + (n_col + n_col - 1) // n_col
        fig   = plt.figure(figsize=(5*n_col, 2*n_row))
        meta  = enumerate( GridSpec(n_row, 1, hspace=0.35) )
        conts = [GridSubplot(1, 2 if i in [0, n_row-1, n_row-2] else n_col, subplot_spec=o, wspace=0.3 if i else 0.45) for i, o in meta]
        axs   = [plt.Subplot(fig, sub) for container in conts for sub in container]
        axs   = axs[:n_col+2] + axs[-4:]
        [fig.add_subplot(ax) for ax in axs]

        self.axes   = [ax.twinx() for ax in axs[:2]] + axs
        self.labels = get_labels(get_sensor_bands(self.args.sensor, self.args), self.model.output_slices, n_col)[:n_col]

        plt.ion()
        plt.show()
        plt.pause(1e-9)
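
        # When args.animate is set, each update() call after the first is grabbed
        # as a frame and written to Animations/MDN.mp4 via the ffmpeg file writer.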
        if self.args.animate:
            ani_path = Path('Animations')
            ani_tmp  = ani_path.joinpath('tmp')
            ani_tmp.mkdir(parents=True, exist_ok=True)
            list(map(os.remove, ani_tmp.glob('*.png')))  # Delete any temporary animation files left over from prior runs

            # '-tune zerolatency' fixes an issue where Firefox won't play the mp4
            # '-vf pad=...' ensures height/width are divisible by 2 (required by .h264 - https://stackoverflow.com/questions/20847674/ffmpeg-libx264-height-not-divisible-by-2)
            extra_args = ["-tune", "zerolatency", "-vf", "pad=width=ceil(iw/2)*2:height=ceil(ih/2)*2:color=white"]
            ani_writer = self.ani_writer = animation.writers['ffmpeg_file'](fps=3, extra_args=extra_args)
            ani_writer.setup(fig, ani_path.joinpath('MDN.mp4').as_posix(), dpi=100, frame_prefix=ani_tmp.joinpath('_').as_posix(), clear_temp=False)

    @ignore_warnings
    def update(self, plot_metrics=[mdsa, rmsle]):
        model = self.model
        if hasattr(model, 'session'):
            (prior, mu, sigma), est, avg = model.session.run([model.coefs, model.most_likely, model.avg_estimate], feed_dict={model.x: self.train_test})
            train_loss = model.session.run(model.neg_log_pr, feed_dict={model.x: self.data['train']['x_t'][self._idx], model.y: self.data['train']['y_t'][self._idx]})
            test_loss  = model.session.run(model.neg_log_pr, feed_dict={model.x: self.data['test' ]['x_t'], model.y: self.data['test' ]['y_t']})
        else:
            # mix = model.model.layers[-1]
            tt_out = model(self.train_test)
            coefs  = prior, mu, sigma = model.get_coefs(tt_out)
            est    = model._get_top_estimate(coefs).numpy()
            avg    = model._get_avg_estimate(coefs).numpy()
            train_loss = model.loss(self.data['train']['y_t'][self._idx], model(self.data['train']['x_t'][self._idx])).numpy()
            test_loss  = model.loss(self.data['test' ]['y_t'], model(self.data['test' ]['x_t'])).numpy()
            prior  = prior.numpy()
            mu     = mu.numpy()
            sigma  = sigma.numpy()

        est = model.scalery.inverse_transform(est)
        avg = model.scalery.inverse_transform(avg)
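
        # The network was run on the concatenated train+test inputs, so the
        # estimates are split back into their train and test portions here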
        n_xtrain  = len(self._idx)
        train_est = est[:n_xtrain ]
        train_avg = avg[:n_xtrain ]
        test_est  = est[ n_xtrain:]
        test_avg  = avg[ n_xtrain:]

        for metric in plot_metrics:
            self.train_losses[metric.__name__].append([metric(y1, y2) for y1, y2 in zip(self.data['train']['y'][self._idx].T, train_est.T)])
            self.test_losses[ metric.__name__].append([metric(y1, y2) for y1, y2 in zip(self.data['test' ]['y'].T, test_est.T)])

        self.model_losses.append([train_loss, leqznan(test_est), test_loss])
        test_probs = np.max( prior, 1)[n_xtrain:]
        test_mixes = np.argmax(prior, 1)[n_xtrain:]

        if model.verbose:
            messages = zip( [performance( lbl, y1, y2) for lbl, y1, y2 in zip(self.labels, self.data['test']['y'].T, test_est.T)],
                            [performance('avg', y1, y2) for lbl, y1, y2 in zip(self.labels, self.data['test']['y'].T, test_avg.T)])
            self.messages = [m for msg in messages for m in msg]
            line_messages(self.messages, nbars=2)

        net_loss, zero_cnt, test_loss = np.array(self.model_losses).T
        [ax.cla() for ax in self.axes]

        # Top two plots, showing training progress via the chosen metrics
        for axi, (ax, metric) in enumerate(zip(self.axes[:len(plot_metrics)], plot_metrics)):
            name = metric.__name__
            line = ax.plot(np.array(self.train_losses[name]), ls='--', alpha=0.5)
            ax.set_prop_cycle(plt.cycler('color', [l.get_color() for l in line]))
            ax.plot(np.array(self.test_losses[name]), alpha=0.8)
            ax.set_ylabel(name, fontsize=8)
            ax.set_yscale('log')

            if axi == 0:
                n_targets = self.data['train']['y_t'].shape[1]
                ax.legend(self.labels, bbox_to_anchor=(1.22, 1.1 + .1*(n_targets//6 + 1)),
                          ncol=min(6, n_targets), fontsize=8, loc='center', title='Training')

        # Overall network loss, with the test-loss minimum marked
        axi = len(plot_metrics)
        self.axes[axi].plot(net_loss, ls='--', color='gray')
        self.axes[axi].plot(test_loss, color='w' if self.args.darktheme else 'k')
        self.axes[axi].plot([np.argmin(test_loss)], [np.min(test_loss)], 'rx')
        self.axes[axi].set_ylabel('Network Loss', fontsize=8)
        self.axes[axi].tick_params(labelsize=8)
        axi += 1

        # Count of estimates which are <= 0
        self.axes[axi].plot(zero_cnt, ls='--', color='w' if self.args.darktheme else 'k')
        self.axes[axi].set_ylabel('Est <= 0 Count', fontsize=8)
        self.axes[axi].tick_params(labelsize=8)
        axi += 1

        # Middle plots, showing 1-1 scatter of estimates against measurements
        for yidx, lbl in enumerate(self.labels):
            ax   = self.axes[axi]
            axi += 1

            # Color each point by the mixture component with the highest prior probability
            ax.scatter(self.data['test']['y'][:, yidx], test_est[:, yidx], 10, c=test_mixes/prior.shape[1], cmap='jet', alpha=.5, zorder=5)
            ax.axhline(self.zero_line[0, yidx], ls='--', color='w' if self.args.darktheme else 'k', alpha=.5)
            # ax.axhline(self.neg_line[0, yidx], ls='-.', color='w' if self.args.darktheme else 'k', alpha=.5)
            add_identity(ax, ls='--', color='w' if self.args.darktheme else 'k', zorder=6)

            ax.tick_params(labelsize=5)
            ax.set_title(lbl, fontsize=8)
            ax.set_xscale('log')
            ax.set_yscale('log')
            minlim = max(min(self.data['test']['y'][:, yidx].min(), test_est[:, yidx].min()), 1e-6)
            maxlim = min(max(self.data['test']['y'][:, yidx].max(), test_est[:, yidx].max()), 2000)

            if np.all(np.isfinite([minlim, maxlim])):
                ax.set_ylim((minlim, maxlim))
                ax.set_xlim((minlim, maxlim))

            if yidx == 0:  # (yidx % n_col) == 0:
                ax.set_ylabel('Estimate', fontsize=8)

            if yidx == 0:  # (yidx // n_col) == (n_row-(n_ext+1)):
                ax.set_xlabel('Measurement', fontsize=8)

        # Bottom plot showing the likelihood (highest component probability) of each test sample
        self.axes[axi].hist(test_probs)
        self.axes[axi].set_xlabel('Likelihood')
        self.axes[axi].set_ylabel('Frequency')
        axi += 1

        # Stacked histogram of the prior probabilities, one dataset per mixture component
        self.axes[axi].hist(prior, stacked=True, bins=20)

        # Shows two dimensions of a few gaussians
        # circle = Ellipse((valid_mu[0], valid_mu[-1]), valid_si[0], valid_si[-1])
        # circle.set_alpha(.5)
        # circle.set_facecolor('g')
        # self.axes[axi].add_artist(circle)
        # self.axes[axi].plot([valid_mu[0]], [valid_mu[-1]], 'r.')
        # self.axes[axi].set_xlim((-2, 2))  # (-min(valid_si[0], valid_si[-1]), max(valid_si[0], valid_si[-1]))
        # self.axes[axi].set_ylim((-2, 2))  # (-min(valid_si[0], valid_si[-1]), max(valid_si[0], valid_si[-1]))

        # Bottom plot meshing together all gaussians into a probability-weighted heatmap.
        # Sigmas are of questionable validity, due to data scaling interference.
        axi += 1
        KEY = list(model.output_slices.keys())[0]
        IDX = model.output_slices[KEY].start

        # Keep only the test-set mixture parameters for the first output variable (KEY),
        # transformed back to the original target space
        sigma = sigma[n_xtrain:, ...]
        sigma = model.scalery.inverse_transform(sigma.diagonal(0, -2, -1).reshape((-1, mu.shape[-1]))).reshape((sigma.shape[0], -1, sigma.shape[-1]))[..., IDX][None, ...]
        mu    = mu[n_xtrain:, ...]
        mu    = model.scalery.inverse_transform(mu.reshape((-1, mu.shape[-1]))).reshape((mu.shape[0], -1, mu.shape[-1]))[..., IDX][None, ...]
        prior = prior[None, n_xtrain:]
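
        # Evaluate each component's Gaussian density over a log-spaced grid of target
        # values:  N(y; mu, sigma) = exp(-(y - mu)^2 / (2 sigma^2)) / sqrt(2 pi sigma^2),
        # weighted by its prior. Ztop keeps the most probable component for each sample,
        # while Z keeps the prior-weighted sum of the remaining components.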
        Y   = np.logspace(np.log10(self.data['test']['y'][:, IDX].min()*.5), np.log10(self.data['test']['y'][:, IDX].max()*1.5), 100)[::-1, None, None]
        var = 2 * sigma ** 2
        num = np.exp(-(Y - mu) ** 2 / var)
        Z   = (prior * (num / (np.pi * var) ** 0.5))
        I, J = np.ogrid[:Z.shape[0], :Z.shape[1]]
        mpr  = np.argmax(prior, 2)
        Ztop = Z[I, J, mpr]
        Z[I, J, mpr] = 0
        Z     = Z.sum(2)
        Ztop += 1e-5
        Z    /= Ztop.sum(0)
        Ztop /= Ztop.sum(0)
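
        # Hide grid cells well below each column's maximum, and blank out samples
        # whose secondary components carry little total probability (zp < 0.2),
        # so only the meaningful parts of the mixture are drawn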
        zp = prior.copy()
        I, J = np.ogrid[:zp.shape[0], :zp.shape[1]]
        zp[I, J, mpr] = 0
        zp = zp.sum(2)[0]

        Z[Z < (Z.max(0) * 0.9)] = 0
        Z  = Z.T
        zi = zp < 0.2
        Z[zi] = np.array([np.nan] * Z.shape[1])
        Z  = Z.T
        Z[Z == 0] = np.nan
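
        # Sort the test samples by their measured value so the heatmap's x-axis
        # increases monotonically with the in situ measurement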
        ydxs, ysort = np.array(sorted(enumerate(self.data['test']['y'][:, IDX]), key=lambda v: v[1])).T
        Z    = Z[:, ydxs.astype(np.int32)]
        Ztop = Ztop[:, ydxs.astype(np.int32)]

        if np.any(np.isfinite(Ztop)):
            self.axes[axi].pcolormesh(np.arange(Z.shape[1]), Y,
                preprocessing.MinMaxScaler((0, 1)).fit_transform(Ztop), cmap='inferno', shading='gouraud')
        if np.any(np.isfinite(Z)):
            self.axes[axi].pcolormesh(np.arange(Z.shape[1]), Y, Z, cmap='BuGn_r', shading='gouraud', alpha=0.7)
        # self.axes[axi].colorbar()
        # self.axes[axi].set_yscale('symlog', linthreshy=y_valid[:, IDX].min()*.5)
        self.axes[axi].set_yscale('log')
        self.axes[axi].plot(ysort)  # , color='red')
        self.axes[axi].set_ylabel(KEY)
        self.axes[axi].set_xlabel('in situ index (sorted by %s)' % KEY)
        axi += 1

        # Same as the last plot, but showing only the 20 most uncertain samples
        pc   = prior[0, ydxs.astype(np.int32)]
        pidx = np.argsort(pc.max(1))
        pidx = np.sort(pidx[:20])
        Z    = Z[:, pidx]
        Ztop = Ztop[:, pidx]

        if np.any(np.isfinite(Ztop)):
            self.axes[axi].pcolormesh(np.arange(Z.shape[1]), Y,
                preprocessing.MinMaxScaler((0, 1)).fit_transform(Ztop), cmap='inferno')
        if np.any(np.isfinite(Z)):
            self.axes[axi].pcolormesh(np.arange(Z.shape[1]), Y, Z, cmap='BuGn_r', alpha=0.7)

        self.axes[axi].set_yscale('log')
        self.axes[axi].plot(ysort[pidx])  # , color='red')
        self.axes[axi].set_ylabel(KEY)
        self.axes[axi].set_xlabel('in situ index (sorted by %s)' % KEY)

        plt.pause(1e-9)

        # Store the current plot as a frame for the animation
        if len(self.model_losses) > 1 and self.args.animate:
            self.ani_writer.grab_frame()

            if ((len(self.model_losses) % 5) == 0) or (len(self.model_losses) == int(self.args.n_iter)):
                self.ani_writer._run()

    def finish(self):
        if self.args.animate:
            self.ani_writer.finish()
        # input('continue?')
        plt.ioff()
        plt.close()

        if self.model.verbose:
            print('\n' * (len(self.messages) + 1))
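
TrainingPlot is driven by an external training loop rather than run on its own. The following is a minimal, hypothetical sketch of the intended call pattern (not part of this file); it assumes `args` provides the attributes used above (sensor, darktheme, animate, n_iter), `model` is an MDN-style model from this package, and `data` is a dict of 'train'/'test' splits, each containing 'x', 'x_t', 'y', and 'y_t' arrays.

# Hypothetical driver loop -- illustrative only
plot = TrainingPlot(args, model, data)
plot.setup()                                     # build the figure, axes, and (optionally) the ffmpeg writer

update_interval = 10                             # assumed refresh cadence, in training iterations
for iteration in range(int(args.n_iter)):
    # ... perform one optimization step on `model` here ...
    if (iteration % update_interval) == 0:
        plot.update(plot_metrics=[mdsa, rmsle])  # refresh loss curves, 1-1 scatters, and uncertainty panels

plot.finish()                                    # close the animation writer and the interactive figure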