1 from .TrainingPlot
import TrainingPlot
2 from ..metrics
import mdsa, sspb
4 from tempfile
import TemporaryDirectory
5 from pathlib
import Path
7 import tensorflow
as tf
12 ''' Display a real-time training progress plot '''
15 super(PlottingCallback, self).
__init__()
32 ''' Save performance statistics as the model is trained '''
34 def __init__(self, args, data, mdn, metrics=[mdsa, sspb], folder='Results_gpu'):
35 super(StatsCallback, self).
__init__()
46 all_keys = sorted(self.
data.keys())
47 all_data = [self.
data[k][
'x']
for k
in all_keys]
48 all_sums = np.cumsum(
list(map(len, [[]] + all_data[:-1])))
49 all_idxs = [slice(c, len(d)+c)
for c,d
in zip(all_sums, all_data)]
50 all_data = np.vstack(all_data)
53 estimates = self.
mdn.predict(all_data)
54 estimates = {k: estimates[idxs]
for k, idxs
in zip(all_keys, all_idxs)}
55 assert(all([estimates[k].shape == self.
data[k][
'y'].shape
for k
in all_keys])), \
56 [(estimates[k].shape, self.
data[k][
'y'].shape)
for k
in all_keys]
58 save_folder = Path(self.
folder, self.
args.config_name).resolve()
59 if not save_folder.exists():
60 print(f
'\nSaving training results at {save_folder}\n')
61 save_folder.mkdir(parents=
True, exist_ok=
True)
64 round_stats_file = save_folder.joinpath(f
'round_{self.args.curr_round}.csv')
65 if not round_stats_file.exists()
or self.
_step_count == 0:
66 with round_stats_file.open(
'w+')
as fn:
67 fn.write(
','.join([
'iteration',
'cumulative_time'] + [f
'{k}_{m.__name__}' for k
in all_keys
for m
in self.
metrics]) +
'\n')
69 stats = [[
str(m(y1, y2))
for y1,y2
in zip(self.
data[k][
'y'].T, estimates[k].T)]
for k
in all_keys
for m
in self.
metrics]
70 stats =
','.join([f
'[{s}]' for s
in [
','.join(stat)
for stat
in stats]])
71 with round_stats_file.open(
'a+')
as fn:
72 fn.write(f
'{self._step_count},{time.time()-self.start_time},{stats}\n')
75 save_folder = save_folder.joinpath(
'Estimates')
76 if not save_folder.exists():
77 save_folder.mkdir(parents=
True, exist_ok=
True)
80 filename = save_folder.joinpath(f
'round_{self.args.curr_round}_{k}.csv')
81 if not filename.exists():
82 with filename.open(
'w+')
as fn:
83 fn.write(f
'target,{list(self.data[k]["y"][:,0])}\n')
85 with filename.open(
'a+')
as fn:
86 fn.write(f
'{self._step_count},{list(estimates[k][:,0])}\n')
92 ''' Save models during training, and load the best performing
93 on the validation set once training is completed.
98 Path(path).mkdir(exist_ok=
True, parents=
True)
101 super(ModelCheckpoint, self).
__init__(
102 filepath=self.
checkpoint, save_weights_only=
True,
103 monitor=
'val_MSA', mode=
'min', save_best_only=
True)
112 ''' Verify tf parameters are being decayed as they should;
113 call show_plot() on object once training is completed '''
120 self.
lr.append(self.model.optimizer.lr)
121 self.
wd.append(self.model.optimizer.weight_decay)
124 import matplotlib.pyplot
as plt
125 plt.plot(self.
lr, label=
'learning rate')
126 plt.plot(self.
wd, label=
'weight decay')
128 plt.ylabel(
'param value')