from pathlib import Path
from sklearn import preprocessing
from tqdm import trange
from collections import defaultdict as dd

import numpy as np
import hashlib

from .model import MDN
from .meta import get_sensor_bands, SENSOR_LABEL, ANCILLARY, PERIODIC
from .utils import get_labels, get_data, generate_config, using_feature, split_data, _load_datasets, compress
from .metrics import performance, mdsa, sspb, msa
from .plot_utils import plot_scatter
from .benchmarks import run_benchmarks
from .parameters import get_args
from .transformers import TransformerPipeline, generate_scalers
def get_estimates(args, x_train=None, y_train=None, x_test=None, y_test=None, output_slices=None, dataset_labels=None, x_sim=None, y_sim=None, return_model=False, return_coefs=False):
    '''
    Estimate all target variables for the given x_test. If a model doesn't
    already exist, creates a model with the given training data.
    '''
    print(f'\nUsing {len(args.wavelengths)} wavelength(s) in the range [{args.wavelengths[0]}, {args.wavelengths[-1]}]')
    if x_train is not None: setattr(args, 'data_xtrain_shape', x_train.shape)
    if y_train is not None: setattr(args, 'data_ytrain_shape', y_train.shape)
    if x_test  is not None: setattr(args, 'data_xtest_shape',  x_test.shape)
    if y_test  is not None: setattr(args, 'data_ytest_shape',  y_test.shape)
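
    # Fingerprint the training dataset selection: the hash below ties the
    # saved model configuration to the exact set of dataset labels used
    # (an inference from the code, not documented in the source).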
    if dataset_labels is not None:
        sets_str  = ','.join(sorted(map(str, np.unique(dataset_labels))))
        sets_hash = hashlib.sha256(sets_str.encode('utf-8')).hexdigest()
        setattr(args, 'datasets_hash', sets_hash)
    # Reconstructed: `model_path` is produced by the imported generate_config
    # helper before its name is stored on the args object.
    model_path = generate_config(args, create=x_train is not None)
    args.config_name = model_path.name
    predict_kwargs = {
        'avg_est'             : getattr(args, 'avg_est', False),
        'threshold'           : getattr(args, 'threshold', None),
        'confidence_interval' : getattr(args, 'CI', None),
        'use_gpu'             : getattr(args, 'use_gpu', False),
        'chunk_size'          : getattr(args, 'chunk_size', 1e4),
        'return_coefs'        : True,
    }
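
    # When a confidence interval is requested via `CI`, model.predict also
    # returns upper/lower bounds alongside the estimates; the result is
    # unpacked below as `(estimates, *confidence), coefs`.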
    x_full,  y_full  = x_train, y_train
    x_valid, y_valid = None, None

    # Per-round results accumulator (reconstructed as a defaultdict of lists,
    # based on the `outputs[...].append` calls below).
    outputs = dd(list)
    for round_num in trange(args.n_rounds, disable=args.verbose or (args.n_rounds == 1) or args.silent):
        args.curr_round = round_num
        curr_round_seed = args.seed + round_num if args.seed is not None else None
        np.random.seed(curr_round_seed)
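
        # Bagging: each round re-draws a random 75% training subset from the
        # full data, holding out the remaining 25% as a validation set.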
        if using_feature(args, 'bagging') and x_train is not None and args.n_rounds > 1:
            (x_train, y_train), (x_valid, y_valid) = split_data(x_full, y_full, n_train=0.75, seed=curr_round_seed)
        datasets = {k: dict(zip(['x', 'y'], v)) for k, v in {
            'train' : [x_train, y_train],
            'valid' : [x_valid, y_valid],
            'test'  : [x_test,  y_test],
            'full'  : [x_full,  y_full],
            'sim'   : [x_sim,   y_sim],
        }.items() if v[0] is not None}
        model_kwargs = {
            'hidden'     : [args.n_hidden] * args.n_layers,
            'n_iter'     : args.n_iter,
            'imputations': args.imputations,
            'epsilon'    : args.epsilon,
            'scalerx'    : TransformerPipeline([S(*args, **kwargs) for S, args, kwargs in args.x_scalers]),
            'scalery'    : TransformerPipeline([S(*args, **kwargs) for S, args, kwargs in args.y_scalers]),
            'model_path' : model_path.joinpath(f'Round_{round_num}'),
            'no_load'    : args.no_load,
            'no_save'    : args.no_save,
            'seed'       : curr_round_seed,
            'verbose'    : args.verbose,
        }
        model = MDN(**model_kwargs)
        model.fit(x_train, y_train, output_slices, args=args, datasets=datasets)
        if return_model:
            outputs['model'].append(model)

        if return_coefs:
            outputs['scalerx'].append(model.scalerx)
            outputs['scalery'].append(model.scalery)
        if x_test is not None:
            (estimates, *confidence), coefs = model.predict(x_test, **predict_kwargs)
            outputs['estimates'].append(estimates)

            if return_coefs:
                outputs['coefs'].append(coefs)

            if len(confidence):
                upper, lower = confidence
                outputs['upper_bound'].append(upper)
                outputs['lower_bound'].append(lower)
            if args.verbose and y_test is not None:
                median = np.median(outputs['estimates'], axis=0)
                labels = get_labels(args.wavelengths, output_slices, n_out=y_test.shape[1])
                for lbl, y1, y2 in zip(labels, y_test.T, median.T):
                    # Per-product performance summary (loop body reconstructed
                    # using the imported `performance` metric).
                    print(performance(lbl, y1, y2))
                print(f'--- Done round {round_num} ---\n')
        if hasattr(model, 'session'): model.session.close()
    if len(outputs) == 1:
        outputs = list(outputs.values())[0]
    return outputs, model.output_slices
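

# Example usage of get_estimates (a sketch: assumes `args` comes from get_args()
# and x/y arrays are pre-loaded; 'chl' is a hypothetical product key in `slices`):
#   estimates, slices = get_estimates(args, x_train, y_train, x_test, y_test)
#   chl_preds = np.median(estimates, 0)[:, slices['chl']]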


def apply_model(x_test, use_cmdline=True, **kwargs):
    ''' Apply a model (defined by kwargs and default parameters) to x_test '''
    args = get_args(kwargs, use_cmdline=use_cmdline)
    # Reconstructed: the estimate call implied by the return statement below.
    preds, idxs = get_estimates(args, x_test=x_test)
    return np.median(preds, 0), idxs
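

# Example (a sketch; 'OLCI' is an assumed sensor key in SENSOR_LABEL):
#   estimates, slices = apply_model(x_test, use_cmdline=False, sensor='OLCI')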


def image_estimates(data, sensor=None, function=apply_model, rhos=False, anc=False, **kwargs):
    '''
    Takes data of shape [Height, Width, Wavelengths] and returns the outputs of the
    given function for that image, in the same [H, W] shape.
    rhos and anc models are not yet available.
    '''
    def ensure_feature_dim(v):
        # Body reconstructed: promote 1-D outputs to [N, 1].
        if len(v.shape) == 1:
            v = v[:, None]
        return v
    if isinstance(data, list):
        assert(all([data[0].shape == d.shape for d in data])), (
            f'Not all inputs have the same shape: {[d.shape for d in data]}')
        data = np.dstack(data)
    assert(sensor is not None), (
        f'Must pass sensor name to image_estimates function. Options are: {list(SENSOR_LABEL.keys())}')
    assert(sensor in SENSOR_LABEL), (
        f'Requested sensor {sensor} unknown. Must be one of: {list(SENSOR_LABEL.keys())}')
    assert(len(data.shape) == 3), (
        f'Expected data to have 3 dimensions (height, width, feature). Found shape: {data.shape}')
    args = get_args(sensor=sensor, **kwargs)
    expected_features = len(get_sensor_bands(sensor, args)) + (len(ANCILLARY) + len(PERIODIC) if anc or rhos else 0)
    assert(data.shape[-1] == expected_features), (
        f'Got {data.shape[-1]} features; expected {expected_features} features for sensor {sensor}')
    im_shape = data.shape[:-1]
    im_data  = np.ma.masked_invalid(data.reshape((-1, data.shape[-1])))
    im_mask  = np.any(im_data.mask, axis=1)
    im_data  = im_data[~im_mask]
    estimate = function(im_data, sensor=sensor, **kwargs) if im_data.size else np.zeros((0, 1))
    # The given function may return extra information (e.g. the output slices)
    # alongside the estimates; separate it out if so.
    remaining = None
    if isinstance(estimate, tuple):
        estimate, *remaining = estimate
    estimate = ensure_feature_dim(estimate)
    est_mask = np.tile(im_mask[:, None], (1, estimate.shape[-1]))
    est_data = np.ma.array(np.zeros(est_mask.shape) * np.nan, mask=est_mask, hard_mask=True)
    est_data.data[~im_mask] = estimate
    est_data = est_data.reshape(im_shape + est_data.shape[-1:])
    if remaining is not None and len(remaining):
        if len(remaining) == 1:
            remaining = remaining[0]
        return est_data, remaining
    return est_data
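

# Example usage of image_estimates (a sketch; 'OLCI' is an assumed sensor key,
# and `band_images` a hypothetical list of [H, W] arrays, one per wavelength):
#   cube = np.dstack(band_images)   # [H, W, Wavelengths]
#   product_map, slices = image_estimates(cube, sensor='OLCI', use_cmdline=False)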


# NOTE: the original def line is elided in the source; the function name and
# **kwargs signature below are reconstructed from the body.
def print_dataset_stats(**kwargs):
    ''' Print datasets shape & min / max stats per feature '''
    label = kwargs.pop('label', '')
    for k, arr in kwargs.items():
        if arr is not None:
            print(f'\n{label} {k.title()}'.strip() + '\n\t'.join([''] + [f'{k}: {v}'.replace("'", "") for k, v in {
                'Shape'   : np.array(arr).shape,
                'N Valid' : getattr(np.isfinite(arr).sum(0), 'min' if np.array(arr).shape[1] > 10 else 'tolist')(),
                'Minimum' : [f'{a:>6.2f}' for a in np.nanmin(arr, 0)],
                'Maximum' : [f'{a:>6.2f}' for a in np.nanmax(arr, 0)],
            }.items()]))
            if hasattr(arr, 'head'):
                print('First sample:')
                print(arr.head(1).to_string(index=False), '\n---------------------------\n')


# Signature reconstructed from the call site in main() below.
def generate_estimates(args, bands, x_train, y_train, x_test, y_test, slices, locs=None):
    estimates, slices = get_estimates(args, x_train, y_train, x_test, y_test, slices)
    estimates  = np.median(estimates, 0)
    benchmarks = run_benchmarks(args.sensor, x_test, y_test, x_train, y_train, slices, args)

    # Add the MDN estimates to the benchmark results for each product slice.
    for p in slices:
        if p not in benchmarks: benchmarks[p] = {}
        benchmarks[p].update({'MDN' : estimates[..., slices[p]]})
    return benchmarks


# Entry point; the def line and the branch conditions marked below are
# reconstructed around the surviving branch bodies.
def main():
    args = get_args()

    # If a file was given, estimate the product for the Rrs contained within
    if args.filename:
        filename = Path(args.filename)
        assert(filename.exists()), f'Expecting "{filename}" to be a path to Rrs data, but it does not exist.'

        bands = get_sensor_bands(args.sensor, args)
        if filename.is_file(): x_test = np.loadtxt(args.filename, delimiter=',')
        else:                  x_test, *_ = _load_datasets(['Rrs'], [filename], bands)

        print(f'Generating estimates for {len(x_test)} data points ({x_test.shape})')

        estimates, slices = get_estimates(args, x_test=x_test)
        estimates = np.median(estimates, 0)

        labels    = get_labels(bands, slices, estimates.shape[1])
        estimates = np.append([labels], estimates, 0).astype(str)
        filename  = filename.parent.joinpath(f'MDN_{filename.stem}.csv').as_posix()

        print(f'Saving estimates at location "{filename}"')
        np.savetxt(filename, estimates, delimiter=',', fmt='%s')
    # Save the merged dataset as CSV (branch flag reconstructed).
    elif args.save_data:
        x_data, y_data, slices, locs = get_data(args)

        valid  = np.any(np.isfinite(x_data), 1)  # drop samples with no finite features
        x_data = x_data[valid].astype(str)
        y_data = y_data[valid].astype(str)
        locs   = np.array(locs)[valid].astype(str)

        # Header labels reconstructed: product labels for y, band names for x.
        lbls = get_labels(get_sensor_bands(args.sensor, args), slices, y_data.shape[1])
        wvls = [str(w) for w in get_sensor_bands(args.sensor, args)]

        data_full = np.append(np.append(locs, y_data, 1), x_data, 1)
        data_full = np.append([['dataset', 'index'] + lbls + wvls], data_full, 0)
        filename  = f'{args.sensor}_data_full.csv'
        np.savetxt(filename, data_full, delimiter=',', fmt='%s')
        print(f'Saved data with shape {data_full.shape} to {filename}')
    # Train a model on a partial split and benchmark on the remainder
    # (branch flag reconstructed).
    elif args.benchmark:
        if args.dataset == 'sentinel_paper':
            setattr(args, 'fix_tchl', True)
            setattr(args, 'seed', 1234)

        np.random.seed(args.seed)

        bands   = get_sensor_bands(args.sensor, args)  # reconstructed; `bands` is used below
        n_train = 0.5 if args.dataset != 'sentinel_paper' else 1000
        x_data, y_data, slices, locs = get_data(args)

        (x_train, y_train), (x_test, y_test) = split_data(x_data, y_data, n_train=n_train, seed=args.seed)
        print(x_train.shape, x_test.shape)

        benchmarks = generate_estimates(args, bands, x_train, y_train, x_test, y_test, slices, locs)
        labels     = get_labels(bands, slices, y_test.shape[1])
        products   = args.product.split(',')
        plot_scatter(y_test, benchmarks, bands, labels, products, args.sensor)
    # Otherwise, train a model using all available data (if one doesn't already exist).
    else:
        x_data, y_data, slices, locs = get_data(args)
        get_estimates(args, x_data, y_data, output_slices=slices, dataset_labels=locs[:, 0])