Ocean Color Science Software

ocssw V2022
callbacks.py
from .TrainingPlot import TrainingPlot
from ..metrics import mdsa, sspb

from tempfile import TemporaryDirectory
from pathlib import Path

import tensorflow as tf
import numpy as np
import time

class PlottingCallback(tf.keras.callbacks.Callback):
    ''' Display a real-time training progress plot '''

    def __init__(self, args, data, model):
        super(PlottingCallback, self).__init__()
        self._step_count = 0
        self.args = args
        self.TP = TrainingPlot(args, model, data)
        self.TP.setup()

    def on_train_batch_end(self, batch, logs=None):
        self._step_count += 1
        # Redraw the plot every (n_iter // n_redraws) training steps
        if (self._step_count % (self.args.n_iter // self.args.n_redraws)) == 0:
            self.TP.update()

    def on_train_end(self, *args, **kwargs):
        self.TP.finish()

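For reference, a minimal sketch of how this callback might be attached to a Keras training loop; `args` (providing n_iter and n_redraws), `data`, `keras_model`, and the training arrays are illustrative placeholders for objects built elsewhere in this package, not names defined in this module.

# Hypothetical usage of PlottingCallback; all names are placeholders.
plot_cb = PlottingCallback(args, data, keras_model)
keras_model.fit(x_train, y_train, epochs=1,
                steps_per_epoch=args.n_iter, callbacks=[plot_cb])
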
class StatsCallback(tf.keras.callbacks.Callback):
    ''' Save performance statistics as the model is trained '''

    def __init__(self, args, data, mdn, metrics=[mdsa, sspb], folder='Results_gpu'):
        super(StatsCallback, self).__init__()
        self._step_count = 0
        self.start_time = time.time()
        self.args = args
        self.data = data
        self.mdn = mdn
        self.metrics = metrics
        self.folder = folder

    def on_train_batch_end(self, batch, logs=None):
        if (self._step_count % (self.args.n_iter // self.args.n_redraws)) == 0:
            # Stack all datasets into a single array, remembering which slice
            # of rows belongs to each dataset
            all_keys = sorted(self.data.keys())
            all_data = [self.data[k]['x'] for k in all_keys]
            all_sums = np.cumsum(list(map(len, [[]] + all_data[:-1])))
            all_idxs = [slice(c, len(d)+c) for c, d in zip(all_sums, all_data)]
            all_data = np.vstack(all_data)

            # Create all estimates, transform back into original units, then split back into the original datasets
            estimates = self.mdn.predict(all_data)
            estimates = {k: estimates[idxs] for k, idxs in zip(all_keys, all_idxs)}
            assert(all([estimates[k].shape == self.data[k]['y'].shape for k in all_keys])), \
                [(estimates[k].shape, self.data[k]['y'].shape) for k in all_keys]

            save_folder = Path(self.folder, self.args.config_name).resolve()
            if not save_folder.exists():
                print(f'\nSaving training results at {save_folder}\n')
                save_folder.mkdir(parents=True, exist_ok=True)

            # Save overall dataset statistics
            round_stats_file = save_folder.joinpath(f'round_{self.args.curr_round}.csv')
            if not round_stats_file.exists() or self._step_count == 0:
                with round_stats_file.open('w+') as fn:
                    fn.write(','.join(['iteration', 'cumulative_time'] + [f'{k}_{m.__name__}' for k in all_keys for m in self.metrics]) + '\n')

            # One bracketed list of per-output scores for every (dataset, metric) pair
            stats = [[str(m(y1, y2)) for y1, y2 in zip(self.data[k]['y'].T, estimates[k].T)] for k in all_keys for m in self.metrics]
            stats = ','.join([f'[{s}]' for s in [','.join(stat) for stat in stats]])
            with round_stats_file.open('a+') as fn:
                fn.write(f'{self._step_count},{time.time()-self.start_time},{stats}\n')

            # Save model estimates
            save_folder = save_folder.joinpath('Estimates')
            if not save_folder.exists():
                save_folder.mkdir(parents=True, exist_ok=True)

            for k in all_keys:
                filename = save_folder.joinpath(f'round_{self.args.curr_round}_{k}.csv')
                if not filename.exists():
                    with filename.open('w+') as fn:
                        fn.write(f'target,{list(self.data[k]["y"][:,0])}\n')

                with filename.open('a+') as fn:
                    fn.write(f'{self._step_count},{list(estimates[k][:,0])}\n')
        self._step_count += 1

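The cumulative-sum bookkeeping above is compact but easy to misread, so here is a small self-contained sketch (with made-up dataset sizes) showing how the stored slices recover each dataset from the stacked array.

import numpy as np

# Toy version of the stacking/splitting logic in on_train_batch_end;
# three "datasets" of 2, 3, and 4 rows (sizes chosen only for illustration).
data = {'a': np.zeros((2, 5)), 'b': np.ones((3, 5)), 'c': 2 * np.ones((4, 5))}

keys    = sorted(data.keys())
parts   = [data[k] for k in keys]
starts  = np.cumsum(list(map(len, [[]] + parts[:-1])))           # [0, 2, 5]
idxs    = [slice(c, len(d) + c) for c, d in zip(starts, parts)]  # [0:2, 2:5, 5:9]
stacked = np.vstack(parts)                                       # shape (9, 5)

# Slicing the stacked array with the stored indices recovers the originals
recovered = {k: stacked[s] for k, s in zip(keys, idxs)}
assert all((recovered[k] == data[k]).all() for k in keys)
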
class ModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    ''' Save models during training, and load the best performing
        on the validation set once training is completed.
        Currently untested.
    '''

    def __init__(self, path):
        Path(path).mkdir(exist_ok=True, parents=True)
        self.tmp_folder = TemporaryDirectory(dir=path)
        self.checkpoint = Path(self.tmp_folder.name).joinpath('checkpoint')
        super(ModelCheckpoint, self).__init__(
            filepath=self.checkpoint, save_weights_only=True,
            monitor='val_MSA', mode='min', save_best_only=True)  # an 'MSA' metric still needs to be added to the compiled model

    def on_train_end(self, *args, **kwargs):
        # Restore the best weights seen during training, then discard the temporary checkpoint
        self.model.load_weights(self.checkpoint)
        self.tmp_folder.cleanup()

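Because monitor='val_MSA' only takes effect if the compiled model reports a metric named 'MSA' and fit() receives validation data, here is a hedged sketch of how that might be wired up. The MSA function below is an assumption: a TensorFlow approximation of this package's mdsa() metric, not part of this module; keras_model, loss_fn, and the data arrays are placeholders.

import tensorflow as tf

def MSA(y_true, y_pred):
    ''' Approximate median symmetric accuracy:
        100 * (exp(median(|ln(y_pred / y_true)|)) - 1).
        TensorFlow has no median op, so the middle element of the
        sorted values is used as the 50th percentile. '''
    logratio = tf.abs(tf.math.log(y_pred / y_true))
    flat     = tf.sort(tf.reshape(logratio, [-1]))
    median   = tf.gather(flat, tf.shape(flat)[0] // 2)
    return 100. * (tf.exp(median) - 1.)

# keras_model, loss_fn, and the arrays are placeholders; passing
# validation_data to fit() is what creates the 'val_MSA' entry that
# ModelCheckpoint monitors ('Weights' is an arbitrary checkpoint directory).
keras_model.compile(optimizer='adam', loss=loss_fn, metrics=[MSA])
keras_model.fit(x_train, y_train, validation_data=(x_val, y_val),
                callbacks=[ModelCheckpoint('Weights')])
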
class DecayHistory(tf.keras.callbacks.Callback):
    ''' Verify tf parameters are being decayed as they should;
        call show_plot() on the object once training is completed '''

    def on_train_begin(self, logs={}):
        self.lr = []
        self.wd = []

    def on_batch_end(self, batch, logs={}):
        # Record the current values rather than references to the underlying
        # tf variables, so earlier entries aren't overwritten by later updates
        self.lr.append(tf.keras.backend.get_value(self.model.optimizer.lr))
        self.wd.append(tf.keras.backend.get_value(self.model.optimizer.weight_decay))

    def show_plot(self):
        import matplotlib.pyplot as plt
        plt.plot(self.lr, label='learning rate')
        plt.plot(self.wd, label='weight decay')
        plt.xlabel('step')
        plt.ylabel('param value')
        plt.legend()
        plt.show()
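
A brief usage sketch, assuming an optimizer that exposes both lr and weight_decay attributes (e.g. an AdamW variant); keras_model and the training arrays are placeholders.

# Hypothetical usage; all names below are placeholders.
decay_cb = DecayHistory()
keras_model.fit(x_train, y_train, epochs=5, callbacks=[decay_cb])
decay_cb.show_plot()   # plot the recorded learning-rate and weight-decay curves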