Ocean Color Science Software

ocssw V2022
product_estimation.py
from pathlib import Path
from sklearn import preprocessing
from tqdm import trange
from collections import defaultdict as dd

import numpy as np
import pickle as pkl
import hashlib

from .model import MDN
from .meta import get_sensor_bands, SENSOR_LABEL, ANCILLARY, PERIODIC
from .utils import get_labels, get_data, generate_config, using_feature, split_data, _load_datasets, compress
from .metrics import performance, mdsa, sspb, msa
from .plot_utils import plot_scatter
from .benchmarks import run_benchmarks
from .parameters import get_args
from .transformers import TransformerPipeline, generate_scalers

def get_estimates(args, x_train=None, y_train=None, x_test=None, y_test=None, output_slices=None, dataset_labels=None, x_sim=None, y_sim=None, return_model=False, return_coefs=False):
    '''
    Estimate all target variables for the given x_test. If a model doesn't
    already exist, create one with the given training data.
    '''
    # Add x/y scalers to the args object
    generate_scalers(args, x_train, x_test)

    if args.verbose:
        print(f'\nUsing {len(args.wavelengths)} wavelength(s) in the range [{args.wavelengths[0]}, {args.wavelengths[-1]}]')
        if x_train is not None: print_dataset_stats(x=x_train, label='Train')
        if y_train is not None: print_dataset_stats(y=y_train, label='Train')
        if x_test  is not None: print_dataset_stats(x=x_test,  label='Test')
        if y_test  is not None: print_dataset_stats(y=y_test,  label='Test')

    # Add a few additional variables to be stored in the generated config file
    if x_train is not None: setattr(args, 'data_xtrain_shape', x_train.shape)
    if y_train is not None: setattr(args, 'data_ytrain_shape', y_train.shape)
    if x_test  is not None: setattr(args, 'data_xtest_shape',  x_test.shape)
    if y_test  is not None: setattr(args, 'data_ytest_shape',  y_test.shape)
    if dataset_labels is not None:
        sets_str  = ','.join(sorted(map(str, np.unique(dataset_labels))))
        sets_hash = hashlib.sha256(sets_str.encode('utf-8')).hexdigest()
        setattr(args, 'datasets_hash', sets_hash)

    model_path = generate_config(args, create=x_train is not None)
    args.config_name = model_path.name

    predict_kwargs = {
        'avg_est'             : getattr(args, 'avg_est', False),
        'threshold'           : getattr(args, 'threshold', None),
        'confidence_interval' : getattr(args, 'CI', None),
        'use_gpu'             : getattr(args, 'use_gpu', False),
        'chunk_size'          : getattr(args, 'chunk_size', 1e4),
        'return_coefs'        : True,
    }

    x_full,  y_full  = x_train, y_train
    x_valid, y_valid = None, None

    outputs = dd(list)
    for round_num in trange(args.n_rounds, disable=args.verbose or (args.n_rounds == 1) or args.silent):
        args.curr_round = round_num
        curr_round_seed = args.seed + round_num if args.seed is not None else None
        np.random.seed(curr_round_seed)

        # 75% of rows used in bagging
        if using_feature(args, 'bagging') and x_train is not None and args.n_rounds > 1:
            (x_train, y_train), (x_valid, y_valid) = split_data(x_full, y_full, n_train=0.75, seed=curr_round_seed)

        datasets = {k: dict(zip(['x', 'y'], v)) for k, v in {
            'train' : [x_train, y_train],
            'valid' : [x_valid, y_valid],
            'test'  : [x_test,  y_test],
            'full'  : [x_full,  y_full],
            'sim'   : [x_sim,   y_sim],
        }.items() if v[0] is not None}

        model_kwargs = {
            'n_mix'      : args.n_mix,
            'hidden'     : [args.n_hidden] * args.n_layers,
            'lr'         : args.lr,
            'l2'         : args.l2,
            'n_iter'     : args.n_iter,
            'batch'      : args.batch,
            'imputations': args.imputations,
            'epsilon'    : args.epsilon,
            'scalerx'    : TransformerPipeline([S(*args, **kwargs) for S, args, kwargs in args.x_scalers]),
            'scalery'    : TransformerPipeline([S(*args, **kwargs) for S, args, kwargs in args.y_scalers]),
            'model_path' : model_path.joinpath(f'Round_{round_num}'),
            'no_load'    : args.no_load,
            'no_save'    : args.no_save,
            'seed'       : curr_round_seed,
            'verbose'    : args.verbose,
        }

        model = MDN(**model_kwargs)
        model.fit(x_train, y_train, output_slices, args=args, datasets=datasets)

        if return_model:
            outputs['model'].append(model)

        if return_coefs:
            outputs['scalerx'].append(model.scalerx)
            outputs['scalery'].append(model.scalery)

        if x_test is not None:
            (estimates, *confidence), coefs = model.predict(x_test, **predict_kwargs)
            outputs['estimates'].append(estimates)

            if return_coefs:
                outputs['coefs'].append(coefs)

            if len(confidence):
                upper, lower = confidence
                outputs['upper_bound'].append(upper)
                outputs['lower_bound'].append(lower)

            if args.verbose and y_test is not None:
                median = np.median(outputs['estimates'], axis=0)
                labels = get_labels(args.wavelengths, output_slices, n_out=y_test.shape[1])
                for lbl, y1, y2 in zip(labels, y_test.T, median.T):
                    print( performance(f'{lbl:>7s} Median', y1, y2) )
                print(f'--- Done round {round_num} ---\n')

        if hasattr(model, 'session'): model.session.close()

    # Create compressed model archive
    compress(model_path)

    if len(outputs) == 1:
        outputs = list(outputs.values())[0]
    return outputs, model.output_slices

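# Illustrative usage sketch for get_estimates (not part of the original module; the
# variable names, sensor, and product key below are hypothetical, and get_args() fills
# in the remaining defaults):
#
#   args = get_args(sensor='OLCI', use_cmdline=False)
#   (x_train, y_train), (x_test, y_test) = split_data(x_data, y_data, n_train=0.5, seed=args.seed)
#   estimates, slices = get_estimates(args, x_train, y_train, x_test, y_test)
#   median_estimates  = np.median(estimates, 0)              # median over the n_rounds models
#   chl               = median_estimates[..., slices['chl']]
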
def apply_model(x_test, use_cmdline=True, **kwargs):
    ''' Apply a model (defined by kwargs and default parameters) to x_test '''
    args = get_args(kwargs, use_cmdline=use_cmdline)
    preds, idxs = get_estimates(args, x_test=x_test)
    return np.median(preds, 0), idxs

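# Sketch: applying a trained (or newly trained) model to a plain array of Rrs samples.
# 'my_rrs.csv' is a hypothetical file whose columns match the chosen sensor's bands.
#
#   rrs = np.loadtxt('my_rrs.csv', delimiter=',')
#   preds, slices = apply_model(rrs, sensor='OLCI', use_cmdline=False)
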
def image_estimates(data, sensor=None, function=apply_model, rhos=False, anc=False, **kwargs):
    '''
    Takes data of shape [Height, Width, Wavelengths] and returns the outputs of the
    given function for that image, in the same [H, W] shape.
    rhos and anc models are not yet available.
    '''
    def ensure_feature_dim(v):
        if len(v.shape) == 1:
            v = v[:, None]
        return v

    if isinstance(data, list):
        assert(all([data[0].shape == d.shape for d in data])), (
            f'Not all inputs have the same shape: {[d.shape for d in data]}')
        data = np.dstack(data)

    assert(sensor is not None), (
        f'Must pass sensor name to image_estimates function. Options are: {list(SENSOR_LABEL.keys())}')
    assert(sensor in SENSOR_LABEL), (
        f'Requested sensor {sensor} unknown. Must be one of: {list(SENSOR_LABEL.keys())}')
    assert(len(data.shape) == 3), (
        f'Expected data to have 3 dimensions (height, width, feature). Found shape: {data.shape}')

    args = get_args(sensor=sensor, **kwargs)
    expected_features = len(get_sensor_bands(sensor, args)) + (len(ANCILLARY) + len(PERIODIC) if anc or rhos else 0)
    assert(data.shape[-1] == expected_features), (
        f'Got {data.shape[-1]} features; expected {expected_features} features for sensor {sensor}')

    im_shape = data.shape[:-1]
    im_data  = np.ma.masked_invalid(data.reshape((-1, data.shape[-1])))
    im_mask  = np.any(im_data.mask, axis=1)
    im_data  = im_data[~im_mask]
    estimate = function(im_data, sensor=sensor, **kwargs) if im_data.size else np.zeros((0, 1))

    # Need to handle functions which return extra information (e.g. a dictionary mapping output feature slices)
    remaining = None
    if isinstance(estimate, tuple):
        estimate, *remaining = estimate

    estimate = ensure_feature_dim(estimate)
    est_mask = np.tile(im_mask[:, None], (1, estimate.shape[-1]))
    est_data = np.ma.array(np.zeros(est_mask.shape) * np.nan, mask=est_mask, hard_mask=True)
    est_data.data[~im_mask] = estimate
    est_data = est_data.reshape(im_shape + est_data.shape[-1:])

    # Let the user handle the extra information of the function they passed, if there was any
    if remaining is not None and len(remaining):
        if len(remaining) == 1:
            remaining = remaining[0]
        return est_data, remaining
    return est_data

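# Sketch: estimating products for a full scene. `rrs_cube` is a hypothetical numpy array
# of shape (height, width, n_bands) containing Rrs for the chosen sensor; invalid (NaN)
# pixels stay masked in the returned array. With the default apply_model, the output-slice
# mapping is returned as the second value.
#
#   product_map, slices = image_estimates(rrs_cube, sensor='OLCI', use_cmdline=False)
#   chl_map = product_map[..., slices['chl']]
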
def print_dataset_stats(**kwargs):
    ''' Print datasets shape & min / max stats per feature '''
    label = kwargs.pop('label', '')
    for k, arr in kwargs.items():
        if arr is not None:
            print(f'\n{label} {k.title()}'.strip() + '\n\t'.join([''] + [f'{k}: {v}'.replace("'", "") for k, v in {
                'Shape'   : np.array(arr).shape,
                'N Valid' : getattr(np.isfinite(arr).sum(0), 'min' if np.array(arr).shape[1] > 10 else 'tolist')(),
                'Minimum' : [f'{a:>6.2f}' for a in np.nanmin(arr, 0)],
                'Maximum' : [f'{a:>6.2f}' for a in np.nanmax(arr, 0)],
            }.items()]), '\n')

            if hasattr(arr, 'head'):
                print('First sample:')
                print(arr.head(1).to_string(index=False), '\n---------------------------\n')

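# Sketch: the keyword name becomes the printed heading, so a call such as
#   print_dataset_stats(rrs=x_test, label='Input')
# reports the shape, valid-sample count, and per-feature min/max under "Input Rrs".
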
def generate_estimates(args, bands, x_train, y_train, x_test, y_test, slices, locs=None):
    ''' Train an MDN and add its median estimates to the benchmark results for each product '''
    estimates, slices = get_estimates(args, x_train, y_train, x_test, y_test, slices)
    estimates  = np.median(estimates, 0)
    benchmarks = run_benchmarks(args.sensor, x_test, y_test, x_train, y_train, slices, args)
    for p in slices:
        if p not in benchmarks: benchmarks[p] = {}
        benchmarks[p].update({'MDN': estimates[..., slices[p]]})
    return benchmarks

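# Sketch of the returned structure (the product key 'chl' is an example only):
#
#   benchmarks = generate_estimates(args, bands, x_train, y_train, x_test, y_test, slices)
#   benchmarks['chl']['MDN']    # MDN estimates for the 'chl' output slice, stored
#                               # alongside whatever run_benchmarks produced for that product
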
def main():
    args = get_args()

    # If a file was given, estimate the product for the Rrs contained within
    if args.filename:
        filename = Path(args.filename)
        assert(filename.exists()), f'Expecting "{filename}" to be a path to Rrs data, but it does not exist.'

        bands = get_sensor_bands(args.sensor, args)
        if filename.is_file(): x_test = np.loadtxt(args.filename, delimiter=',')
        else:                  x_test, *_ = _load_datasets(['Rrs'], [filename], bands)

        print(f'Generating estimates for {len(x_test)} data points ({x_test.shape})')
        print_dataset_stats(rrs=x_test, label='Input')

        estimates, slices = get_estimates(args, x_test=x_test)
        estimates = np.median(estimates, 0)
        print_dataset_stats(estimates=estimates, label='MDN')

        labels    = get_labels(bands, slices, estimates.shape[1])
        estimates = np.append([labels], estimates, 0).astype(str)
        filename  = filename.parent.joinpath(f'MDN_{filename.stem}.csv').as_posix()

        print(f'Saving estimates at location "{filename}"')
        np.savetxt(filename, estimates, delimiter=',', fmt='%s')

    # Save data used with the given args
    elif args.save_data:
        x_data, y_data, slices, locs = get_data(args)

        valid  = np.any(np.isfinite(x_data), 1)  # Remove samples which are completely NaN
        x_data = x_data[valid].astype(str)
        y_data = y_data[valid].astype(str)
        locs   = np.array(locs)[valid].astype(str)
        wvls   = list(get_sensor_bands(args.sensor, args).astype(int).astype(str))
        lbls   = get_labels(get_sensor_bands(args.sensor, args), slices, y_data.shape[1])
        data_full = np.append(np.append(locs, y_data, 1), x_data, 1)
        data_full = np.append([['dataset', 'index'] + lbls + wvls], data_full, 0)
        filename  = f'{args.sensor}_data_full.csv'
        np.savetxt(filename, data_full, delimiter=',', fmt='%s')
        print(f'Saved data with shape {data_full.shape} to {filename}')


    # Train a model with partial data, and benchmark on the remaining data
    elif args.benchmark:

        if args.dataset == 'sentinel_paper':
            setattr(args, 'fix_tchl', True)
            setattr(args, 'seed', 1234)

        np.random.seed(args.seed)

        bands   = get_sensor_bands(args.sensor, args)
        n_train = 0.5 if args.dataset != 'sentinel_paper' else 1000
        x_data, y_data, slices, locs = get_data(args)

        (x_train, y_train), (x_test, y_test) = split_data(x_data, y_data, n_train=n_train, seed=args.seed)
        print(x_train.shape, x_test.shape)
        benchmarks = generate_estimates(args, bands, x_train, y_train, x_test, y_test, slices, locs)
        labels     = get_labels(bands, slices, y_test.shape[1])
        products   = args.product.split(',')
        plot_scatter(y_test, benchmarks, bands, labels, products, args.sensor)


    # Otherwise, train a model with all data (if not already existing)
    else:
        x_data, y_data, slices, locs = get_data(args)
        get_estimates(args, x_data, y_data, output_slices=slices, dataset_labels=locs[:, 0])
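
# main() reads its configuration from parameters.get_args(), so a minimal way to run this
# module directly would be a standard entry-point guard (a sketch; not part of the original
# listing):
#
#   if __name__ == '__main__':
#       main()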