from .meta        import get_sensor_bands, SENSOR_LABELS, ANCILLARY, PERIODIC
from .parameters  import update, hypers, flags, get_args
from .__version__ import __version__

from collections import defaultdict as dd
from importlib   import import_module
from datetime    import datetime as dt
from pathlib     import Path
from tqdm        import trange

import pickle as pkl
import numpy  as np
import pandas as pd
import hashlib, re, warnings, functools, sys, zipfile

def ignore_warnings(func):
    ''' Decorator to silence all warnings (Runtime, User, Deprecation, etc.) '''
    @functools.wraps(func)
    def helper(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore')
            return func(*args, **kwargs)
    return helper

def find_wavelength(k, waves, validate=True, tol=5):
    ''' Index of closest wavelength '''
    waves = np.array(waves)
    w     = np.atleast_1d(k)
    i     = np.abs(waves - w[:, None]).argmin(1)
    assert(not validate or (np.abs(w - waves[i]).max() <= tol)), \
        f'Needed {k}, but closest was {waves[i]} in {waves} ({np.abs(w - waves[i]).max()} > {tol})'
    return i.reshape(np.array(k).shape)

def closest_wavelength(k, waves, validate=True, tol=5):
    ''' Value of closest wavelength '''
    waves = np.array(waves)
    return waves[find_wavelength(k, waves, validate=validate, tol=tol)]
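
# Example usage (a minimal sketch; the band values below are illustrative only):
#   bands = [443, 483, 561, 655]
#   closest_wavelength(560, bands)   # -> 561
#   find_wavelength(560, bands)      # -> 2 (index of 561)
#   closest_wavelength(900, bands)   # raises AssertionError: no band within the 5nm tolerance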

def safe_int(v):
    ''' Parse int if possible, and return None otherwise '''
    try:                            return int(v)
    except (TypeError, ValueError): return None

def get_wvl(nc_data, key):
    ''' Get all wavelengths associated with the given key, available within the netcdf '''
    wvl = [safe_int(v.replace(key, '')) for v in nc_data.variables.keys() if key in v]
    return np.array(sorted([w for w in wvl if w is not None]))

def line_messages(messages, nbars=1):
    '''
    Allow multiline message updates via tqdm.
    Need to call print() after the tqdm loop,
    equal to the number of messages which were
    printed via this function (to reset cursor).

    nbars is the number of tqdm bars the line
    messages come after.

    Usage:
        nbars = 2
        for i in trange(5):
            for j in trange(5, leave=False):
                messages = [i, i/2, i*2]
                line_messages(messages, nbars)
        for _ in range(len(messages) + nbars - 1): print()
    '''
    for _ in range(nbars): print()
    for m in messages:     print('\033[K' + str(m))
    sys.stdout.write('\x1b[A'.join([''] * (nbars + len(messages) + 1)))

def get_labels(wavelengths, slices, n_out=None):
    '''
    Helper to get the label for each target output. Assumes
    that any variable in <slices> which has more than a
    single slice index will have an associated wavelength
    label.

    Usage:
        wavelengths = [443, 483, 561, 655]
        slices = {'bbp':slice(0,4), 'chl':slice(4,5), 'tss':slice(5,6)}
        n_out  = 5
        labels = get_labels(wavelengths, slices, n_out)
        # labels -> ['bbp443', 'bbp483', 'bbp561', 'bbp655', 'chl']
    '''
    return [k + (f'{wavelengths[i]:.0f}' if (v.stop - v.start) > 1 else '')
                for k, v in sorted(slices.items(), key=lambda s: s[1].start)
                for i    in range(v.stop - v.start)][:n_out]

def compress(path, overwrite=False):
    ''' Compress a folder into a .zip archive '''
    if overwrite or not path.with_suffix('.zip').exists():
        with zipfile.ZipFile(path.with_suffix('.zip'), 'w', zipfile.ZIP_DEFLATED) as zf:
            for item in path.rglob('*'):
                zf.write(item, item.relative_to(path))

def uncompress(path, overwrite=False):
    ''' Uncompress a .zip archive '''
    if overwrite or not path.exists():
        if path.with_suffix('.zip').exists():
            with zipfile.ZipFile(path.with_suffix('.zip'), 'r') as zf:
                zf.extractall(path)

class CustomUnpickler(pkl.Unpickler):
    ''' Ensure the classes are found, without requiring an import '''
    _transformers = [p.stem for p in Path(__file__).parent.joinpath('transformers').glob('*Transformer.py')]

    def find_class(self, module, name):
        if name in ['WindowsPath', 'PosixPath']:
            return Path

        if name in self._transformers:
            module   = Path(__file__).parent.stem
            imported = import_module(f'{module}.transformers.{name}')
            return getattr(imported, name)

        elif name == 'TransformerPipeline':
            from .transformers import TransformerPipeline
            return TransformerPipeline

        return super().find_class(module, name)

def store_pkl(filename, output):
    ''' Helper to write pickle file '''
    with Path(filename).open('wb') as f:
        pkl.dump(output, f)
    return output


def read_pkl(filename):
    ''' Helper to read pickle file '''
    with Path(filename).open('rb') as f:
        return CustomUnpickler(f).load()

def cache(filename, recache=False):
    ''' Decorator for caching function outputs '''
    path = Path(filename)

    def wrapper(function):
        def inner(*args, **kwargs):
            if not recache and path.exists():
                return read_pkl(path)
            return store_pkl(path, function(*args, **kwargs))
        return inner
    return wrapper
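
# Example usage (a minimal sketch; 'expensive.pkl' and load_expensive are illustrative names):
#   @cache('expensive.pkl')
#   def load_expensive():
#       return some_large_computation()
#   result = load_expensive()   # computed and pickled on first call, read from disk afterwards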

def using_feature(args, flag):
    '''
    Certain hyperparameter flags have a yet undecided default value,
    which means there are two possible names: using the feature, or
    not using it. This method simply combines both into a single
    boolean signal, which indicates whether to add the feature.

    e.g.
        use_flag = hasattr(args, 'use_ratio') and args.use_ratio
        no_flag  = hasattr(args, 'no_ratio') and not args.no_ratio
        signal   = use_flag or no_flag          # if true, we add ratios
        ...
        signal   = using_feature(args, 'ratio') # if true, we add ratios
    '''
    flag = flag.replace('use_', '').replace('no_', '')
    assert(hasattr(args, f'use_{flag}') or hasattr(args, f'no_{flag}')), f'"{flag}" flag not found'
    return getattr(args, f'use_{flag}', False) or not getattr(args, f'no_{flag}', True)

def split_data(x_data, other_data=[], n_train=0.5, n_valid=0, seed=None, shuffle=True):
    '''
    Split the given data into training, validation, and testing
    subsets, randomly shuffling the original data order.
    '''
    if not isinstance(other_data, list): other_data = [other_data]

    data   = [d.iloc if hasattr(d, 'iloc') else d for d in [x_data] + other_data]
    random = np.random.RandomState(seed)
    idxs   = np.arange(len(x_data))
    if shuffle: random.shuffle(idxs)

    # Allow n_train / n_valid to be given as fractions of the available data
    if 0 < n_train <= 1: n_train = int(n_train * len(idxs))
    if 0 < n_valid <= 1: n_valid = int(n_valid * len(idxs))
    assert((n_train + n_valid) <= len(x_data)), \
        f'Too many training/validation samples requested: {n_train}, {n_valid} ({len(x_data)} available)'

    train = [d[ idxs[:n_train] ]                  for d in data]
    valid = [d[ idxs[n_train:n_valid + n_train] ] for d in data]
    test  = [d[ idxs[n_train + n_valid:] ]        for d in data]
    return train, valid, test
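
# Example usage (a minimal sketch; x and y are illustrative arrays):
#   x = np.random.rand(100, 5)
#   y = np.random.rand(100, 2)
#   (x_trn, y_trn), (x_val, y_val), (x_tst, y_tst) = split_data(x, [y], n_train=0.6, n_valid=0.2, seed=42)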

def mask_land(data, bands, threshold=0.1, verbose=False):
    ''' Modified Normalized Difference Water Index, or NDVI if 1500nm+ is not available '''
    # Closest available bands to nominal green / red / NIR / SWIR centers
    green = closest_wavelength(560,  bands, validate=False)
    red   = closest_wavelength(700,  bands, validate=False)
    nir   = closest_wavelength(900,  bands, validate=False)
    swir  = closest_wavelength(1600, bands, validate=False)

    b1, b2 = (green, swir) if swir > 1500 else (red, nir) if red != nir else (min(bands), max(bands))
    i1, i2 = find_wavelength(b1, bands, validate=False), find_wavelength(b2, bands, validate=False)
    n_diff = lambda a, b: np.ma.masked_invalid((a - b) / (a + b))
    if verbose: print(f'Using bands {b1} & {b2} for land masking')
    return n_diff(data[..., i1], data[..., i2]).filled(fill_value=threshold - 1) <= threshold
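
# Example usage (a minimal sketch; the cube shape and band list are illustrative):
#   bands = [483, 561, 655, 865, 1609]
#   cube  = np.random.rand(100, 100, len(bands))   # (height, width, bands) reflectance
#   land  = mask_land(cube, bands)                 # True where the index flags a land pixel
#   cube[land] = np.nan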

def _get_tile_wavelengths(nc_data, key, sensor, allow_neg=True, landmask=False, args=None):
    ''' Return the Rrs/rhos data within the netcdf file, for wavelengths of the given sensor '''
    has_key = lambda k: any([k in v for v in nc_data.variables])
    wvl_key = f'{key}_' if has_key(f'{key}_') or key != 'Rrs' else 'Rw'

    if has_key(wvl_key):
        avail = get_wvl(nc_data, wvl_key)
        bands = [closest_wavelength(b, avail) for b in get_sensor_bands(sensor, args)]
        div   = np.pi if wvl_key == 'Rw' else 1
        data  = np.ma.stack([nc_data[f'{wvl_key}{b}'][:] / div for b in bands], axis=-1)

        if not allow_neg: data[data <= 0] = np.nan
        if landmask:      data[ mask_land(data, bands) ] = np.nan

        return bands, data.filled(fill_value=np.nan)
    return [], np.array([])

def get_tile_data(filenames, sensor, allow_neg=True, rhos=False, anc=False, **kwargs):
    ''' Gather the correct Rrs/rhos bands from a given scene, as well as ancillary features if necessary '''
    from netCDF4 import Dataset

    filenames = np.atleast_1d(filenames)
    features  = ['rhos' if rhos else 'Rrs'] + (ANCILLARY if anc or rhos else [])
    data      = {}

    if rhos and '-rho' not in sensor: sensor += '-rho'
    args = get_args(sensor=sensor, **kwargs)

    for filename in filenames:
        with Dataset(filename, 'r') as nc_data:
            if 'geophysical_data' in nc_data.groups.keys():
                nc_data = nc_data['geophysical_data']

            for feature in features:
                if feature not in data:
                    if feature in ['Rrs', 'rhos']:
                        bands, band_data = _get_tile_wavelengths(nc_data, feature, sensor, allow_neg, landmask=rhos, args=args)

                        if len(bands):
                            assert(len(band_data.shape) == 3), \
                                f'Different shape than expected: {band_data.shape}'
                            data[feature] = band_data

                    elif feature in nc_data.variables:
                        var = nc_data[feature][:]
                        assert(len(var.shape) == 2), f'Different shape than expected: {var.shape}'

                        if feature in PERIODIC:
                            assert(var.min() >= -180 and var.max() <= 180), \
                                f'Need to adjust transformation for variables not within [-180,180]: {feature}=[{var.min()}, {var.max()}]'
                            data[feature] = np.stack([
                                np.sin(2 * np.pi * (var + 180) / 360),
                                np.cos(2 * np.pi * (var + 180) / 360),
                            ], axis=-1)
                        else: data[feature] = var

    # Time difference is not available within the netcdf, so default it to zero
    if 'time_diff' in features:
        assert(features[0] in data), f'Missing {features[0]} data: {list(data.keys())}'
        data['time_diff'] = np.zeros_like(data[features[0]][:, :, 0])

    assert(len(data) == len(features)), f'Missing features: Found {list(data.keys())}, Expecting {features}'
    return bands, np.dstack([data[f] for f in features])

def generate_config(args, create=True, verbose=True):
    '''
    Create a config file for the current settings, and store in
    a folder location determined by certain parameters:
        MDN/model_loc/sensor/model_lbl/model_uid/config
    "model_uid" is computed within this function, but a value can
    also be passed in manually via args.model_uid in order to allow
    previous MDN versions to run.
    '''
    root = Path(__file__).parent.resolve().joinpath(args.model_loc, args.sensor, args.model_lbl)

    # Allow the model uid to be set manually, so prior versions can still be run
    if hasattr(args, 'model_uid'):
        if args.verbose: print(f'Using manually set model uid: {args.model_uid}')
        return root.joinpath(args.model_uid)

    # Parameters which influence the model uid
    dependents = [getattr(act, 'dest', '') for group in [hypers, update] for act in group._group_actions]
    dependents+= ['x_scalers', 'y_scalers']

    # Flag-style parameters (either set or not set)
    partials = [getattr(act, 'dest', '') for group in [flags] for act in group._group_actions]

    config  = [f'Version: {__version__}', '', 'Dependencies']
    config += [''.join(['-'] * len(config[-1]))]
    others  = ['', 'Configuration']
    others += [''.join(['-'] * len(others[-1]))]

    for k, v in sorted(args.__dict__.items(), key=lambda z: z[0]):
        if k in ['x_scalers', 'y_scalers']:
            # Format the scaler tuples, appending any available config info as a comment
            cinfo = lambda s, sarg, skw: getattr(s, 'config_info', lambda *a, **k: '')(*sarg, **skw)
            cfmt  = lambda *cargs: f' # {cinfo(*cargs)}' if cinfo(*cargs) else ''
            v     = '\n\t' + '\n\t'.join([f'{(s[0].__name__,) + s[1:]}{cfmt(*s)}' for s in v])

            config.append(f'{k:<18}: {v}')
        elif k in dependents: config.append(f'{k:<18}: {v}')
        else:                 others.append(f'{k:<18}: {v}')

    config = '\n'.join(config)
    others = '\n'.join(others)

    # Hash the config (ignoring the patch version) to generate the model uid
    ver_re = r'(Version\: \d+\.\d+)(?:\.\d+\n)'
    h_str  = re.sub(ver_re, r'\1.0\n', config)
    uid    = hashlib.sha256(h_str.encode('utf-8')).hexdigest()
    folder = root.joinpath(uid)
    c_file = folder.joinpath('config')

    if verbose:
        print(f'Using model path {folder}')

    if create:
        folder.mkdir(parents=True, exist_ok=True)

        if not c_file.exists():
            with c_file.open('w+') as f:
                f.write(f'Created: {dt.now()}\n{config}\n{others}')

    elif not c_file.exists() and verbose:
        print('\nCould not find config file with the following parameters:')
        print('\t' + config.replace('\n', '\n\t'), '\n')

    return folder

def _load_datasets(keys, locs, wavelengths, allow_missing=False):
    '''
    Load data from [<locs>] using <keys> as the columns.
    Only loads data which has all the bands defined by
    <wavelengths> (if necessary, e.g. for Rrs or bbp).
    First key is assumed to be the x_data, remaining keys
    are the y_data.
      - allow_missing=True will allow datasets which are missing bands
        to be included in the returned data

    Usage:
        # Here, data/loc/Rrs.csv, data/loc/Rrs_wvl.csv, data/loc/bbp.csv,
        # and data/chl.csv all exist, with the correct wavelengths available
        # for Rrs and bbp (which is determined by Rrs_wvl.csv)
        keys = ['Rrs', 'bbp', '../chl']
        locs = 'data/loc'
        wavelengths = [443, 483, 561, 655]
        _load_datasets(keys, locs, wavelengths)
            # -> [Rrs443, Rrs483, Rrs561, Rrs655],
            #    [bbp443, bbp483, bbp561, bbp655, chl],
            #    {'bbp':slice(0,4), 'chl':slice(4,5)}
    '''
    def loadtxt(name, loc, required_wvl):
        ''' Error handling wrapper over np.loadtxt, with the addition of wavelength selection '''
        dloc = Path(loc).joinpath(f'{name}.csv')

        # TSS / TSM / SPM are used interchangeably
        if 'tss' in name and not dloc.exists():
            dloc = Path(loc).joinpath(f'{name.replace("tss","tsm")}.csv')

            if not dloc.exists():
                dloc = Path(loc).joinpath(f'{name.replace("tsm","spm")}.csv')

        # CDOM falls back to absorption (ag) data if not directly available
        if 'cdom' in name and not dloc.exists():
            dloc = Path(loc).joinpath('ag.csv')

        try:
            required_wvl = np.array(required_wvl).flatten()
            assert(dloc.exists()), (f'Key {name} does not exist at {loc} ({dloc})')

            data = np.loadtxt(dloc, delimiter=',', dtype=float if name not in ['../Dataset', '../meta', '../datetime'] else str, comments=None)
            if len(data.shape) == 1: data = data[:, None]

            if data.shape[1] > 1 and data.dtype.type is not np.str_:

                # Allow missing bands to be filled with nan, rather than dropping the dataset
                if allow_missing:
                    new_data = [[np.nan] * len(data)] * len(required_wvl)
                    wvls     = np.loadtxt(Path(loc).joinpath(f'{dloc.stem}_wvl.csv'), delimiter=',')[:, None]
                    idxs     = np.abs(wvls - np.atleast_2d(required_wvl)).argmin(0)
                    valid    = np.abs(wvls - np.atleast_2d(required_wvl)).min(0) < 2

                    for j, (i, v) in enumerate(zip(idxs, valid)):
                        if v: new_data[j] = data[:, i]
                    data = np.array(new_data).T

                else:
                    data = data[:, get_valid(dloc.stem, loc, required_wvl)]

            # CDOM (when read from ag.csv) keeps only the ag reference band; assumed here to be ~440nm
            if 'cdom' in name and dloc.stem == 'ag':
                data = data[:, find_wavelength(440, required_wvl, validate=False)].flatten()[:, None]
            return data

        except Exception as e:
            if name not in ['Rrs']:
                print(f'\n\tError fetching {name} from {loc}:\n{e}')
            return np.array([]).reshape((0, 0))

    def get_valid(name, loc, required_wvl, margin=2):
        ''' Dataset at <loc> must have all bands in <required_wvl> within <margin>nm '''
        if 'HYPER' in str(loc): margin = 1

        # First, validate that all required wavelengths are within the margin of an available band
        wvls  = np.loadtxt(Path(loc).joinpath(f'{name}_wvl.csv'), delimiter=',')[:, None]
        check = np.array([np.abs(wvls - w).min() <= margin for w in required_wvl])
        assert(check.all()), '\n\t\t'.join([
            f'{name} is missing {(~check).sum()} wavelengths:',
            f'Needed {required_wvl}', f'Found {wvls.flatten()}',
            f'Missing {required_wvl[~check]}', ''])

        # Next, ensure only bands matching the required wavelengths are used
        valid = np.array([True] * len(required_wvl))
        if len(wvls) != len(required_wvl):
            valid = np.abs(wvls - np.atleast_2d(required_wvl)).min(1) <= margin
            assert(valid.sum() == len(required_wvl)), [wvls[valid].flatten(), required_wvl]

        # Finally, ensure the order of the bands matches the required wavelengths
        if not all([w1 == w2 for w1, w2 in zip(wvls[valid], required_wvl)]):
            valid = [np.abs(wvls.flatten() - w).argmin() for w in required_wvl]
            assert(len(np.unique(valid)) == len(valid) == len(required_wvl)), [valid, wvls[valid].flatten(), required_wvl]
        return valid

    locs = [Path(loc).resolve() for loc in np.atleast_1d(locs)]
    print('\n-------------------------')
    print(f'Loading data for sensor {locs[0].parts[-1]}, and targets {[v.replace("../","") for v in keys[1:]]}')
    if allow_missing:
        print('Allowing data regardless of whether all bands exist')

    x_data = []
    y_data = []
    l_data = []
    for loc in locs:
        try:
            loc_data = [loadtxt(key, loc, wavelengths) for key in keys]
            print(f'\tN={len(loc_data[0]):>5} | {loc.parts[-1]} / {loc.parts[-2]} ({[np.isfinite(ld).all(1).sum() if ld.dtype.type is not np.str_ else len(ld) for ld in loc_data[1:]]})')
            assert(all([len(l) in [len(loc_data[0]), 0] for l in loc_data])), dict(zip(keys, map(np.shape, loc_data)))

            if all([l.shape[1] == 0 for l in loc_data[(1 if len(loc_data) > 1 else 0):]]):
                print(f'Skipping dataset {loc}: missing all features')
                continue

            x_data += [loc_data.pop(0)]
            y_data += [loc_data]
            l_data += list(zip([loc.parent.name] * len(x_data[-1]), np.arange(len(x_data[-1]))))

        except Exception as e:
            # Allow invalid datasets if there are multiple to be fetched
            print(f'\nError fetching {loc}:\n\t{e}')
            if len(np.atleast_1d(locs)) == 1:
                raise e

    assert(len(x_data) > 0 or len(locs) == 0), 'No datasets are valid with the given wavelengths'
    assert(all([x.shape[1] == x_data[0].shape[1] for x in x_data])), f'Differing number of {keys[0]} wavelengths: {[x.shape for x in x_data]}'

    # Determine the number of output features for each key, padding missing data with nan
    slices = []
    for i, key in enumerate(keys[1:]):
        shapes = [y[i].shape[1] for y in y_data]
        slices.append(max(shapes))

        for x, y in zip(x_data, y_data):
            if y[i].shape[1] == 0:
                y[i] = np.full((x.shape[0], max(shapes)), np.nan)
        assert(all([y[i].shape[1] == y_data[0][i].shape[1] for y in y_data])), f'{key} shape mismatch: {[y.shape for y in y_data]}'

    # Drop any target features which have no samples available
    drop = []
    for i, s in enumerate(slices):
        if s == 0:
            print(f'Dropping {keys[i+1]}: feature has no samples available')
            drop.append(i)

    slices = np.cumsum([0] + [s for i, s in enumerate(slices) if i not in drop])
    keys   = [k for i, k in enumerate(keys[1:]) if i not in drop]
    for j, y in enumerate(y_data):
        y_data[j] = [z for i, z in enumerate(y) if i not in drop]

    l_data = np.vstack(l_data)
    x_data = np.vstack(x_data)
    y_data = np.vstack([np.hstack(y) for y in y_data])
    assert(slices[-1] == y_data.shape[1]), [slices, y_data.shape]
    assert(y_data.shape[0] == x_data.shape[0]), [x_data.shape, y_data.shape]

    slices = {k.replace('../', ''): slice(slices[i], s) for i, (k, s) in enumerate(zip(keys, slices[1:]))}
    print(f'\tTotal prior to filtering: {len(x_data)}')

    # Fit an exponential to ad / ag spectra, in order to remove samples which are too noisy
    for product in ['ad', 'ag']:
        if product in slices:
            from .metrics import mdsa
            from scipy.optimize import curve_fit

            exponential = lambda x, a, b, c: a * np.exp(-b * x) + c
            remove      = np.zeros_like(y_data[:, 0]).astype(bool)

            for i, sample in enumerate(y_data):
                sample = sample[slices[product]]
                assert(len(sample) > 5), f'Number of bands should be larger, when fitting exponential: {product}, {sample.shape}'
                assert(len(sample) == len(wavelengths)), f'Sample size / wavelengths mismatch: {len(sample)} vs {len(wavelengths)}'

                if np.all(np.isfinite(sample)) and np.min(sample) > -0.1:
                    try:
                        x = np.array(wavelengths) - np.min(wavelengths)
                        params, _  = curve_fit(exponential, x, sample, bounds=((1e-3, 1e-3, 0), (1e2, 1e0, 1e1)))
                        new_sample = exponential(x, *params)

                        # Require < 10% error between the original sample and the fitted exponential
                        if mdsa(sample[None, :], new_sample[None, :]) < 10:
                            y_data[i, slices[product]] = new_sample
                        else: remove[i] = True
                    except:   remove[i] = True

            # Set removed rows to nan (rather than dropping them), so other products remain aligned
            x_data[remove] = np.nan
            y_data[remove] = np.nan
            l_data[remove] = np.nan
            print(f'Removed {remove.sum()} / {len(remove)} samples due to poor quality {product} spectra')
            assert((~remove).sum()), f'All data removed due to {product} spectra quality...'

    return x_data, y_data, slices, l_data

def _filter_invalid(x_data, y_data, slices, allow_nan_inp=False, allow_nan_out=False, other=[]):
    '''
    Filter the given data to only include samples which are valid. By
    default, valid samples include all which are not nan, and greater
    than zero (for all target features).
    - allow_nan_inp=True can be set to allow a sample as valid if _any_
      of a sample's input x features are not nan and greater than zero.
    - allow_nan_out=True can be set to allow a sample as valid if _any_
      of a sample's target y features are not nan and greater than zero.
    - "other" is an optional set of parameters which will be pruned with the
      test sets (i.e. passing a list of indices will return the indices which
      remain valid).

    Multiple data sets can also be passed simultaneously as a list to the
    respective parameters, in order to filter the same samples out of all
    data sets (e.g. OLI and S2B data, containing same samples but different
    bands, can be filtered so they end up with the same samples relative to
    each other).
    '''
    # Allow multiple sets to be given, and align them all to the same sample subset
    if type(x_data) is not list: x_data = [x_data]
    if type(y_data) is not list: y_data = [y_data]
    if type(other)  is not list: other  = [other]

    both_data  = [x_data, y_data]
    set_length = [len(fullset) for fullset in both_data]
    set_shape  = [[len(subset) for subset in fullset] for fullset in both_data]

    assert(np.all([length == len(x_data) for length in set_length])), \
        f'Mismatching number of subsets: {set_length}'
    assert(np.all([[shape == len(fullset[0]) for shape in shapes]
                    for shapes, fullset in zip(set_shape, both_data)])), \
        f'Mismatching number of samples: {set_shape}'
    assert(len(other) == 0 or all([len(o) == len(x_data[0]) for o in other])), \
        f'Mismatching number of samples within other data: {[len(o) for o in other]}'

    # Mark a sample as valid only if its features are finite and positive; input
    # features greater than 10 are also treated as invalid
    valid = np.ones(len(x_data[0])).astype(bool)
    for i, fullset in enumerate(both_data):
        for subset in fullset:
            subset[np.isnan(subset)] = -999.
            subset[np.logical_or(subset <= 1e-8, not i and (subset >= 10))] = np.nan
            has_nan = np.any if (i and allow_nan_out) or (not i and allow_nan_inp) else np.all
            valid   = np.logical_and(valid, has_nan(np.isfinite(subset), 1))

    x_data = [x[valid] for x in x_data]
    y_data = [y[valid] for y in y_data]
    print(f'Removed {(~valid).sum()} invalid samples ({valid.sum()} remaining)')
    assert(valid.sum()), 'All samples have nan or negative values'

    if len(other):
        return x_data, y_data, [np.array(o)[valid] for o in other]
    return x_data, y_data
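
# Example usage (a minimal sketch with illustrative arrays):
#   x = np.random.rand(10, 4)
#   y = np.random.rand(10, 2)
#   y[0, 0] = np.nan                               # sample 0 becomes invalid
#   (x,), (y,) = _filter_invalid(x, y, slices={})  # -> 9 samples remain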

def get_data(args):
    ''' Main function for gathering datasets '''
    np.random.seed(args.seed)
    sensor   = args.sensor.split('-')[0]
    products = args.product.split(',')
    bands    = get_sensor_bands(args.sensor, args)

    # Using a simulated dataset
    if using_feature(args, 'sim'):
        assert(not using_feature(args, 'ratio')), 'Too much memory needed for simulated+ratios'
        data_folder = ['790']
        data_keys   = ['Rrs'] + products
        data_path   = Path(args.sim_loc)

    else:
        if products[0] == 'all':
            products = ['chl', 'tss', 'cdom', 'ad', 'ag', 'aph']
        data_folder = []
        data_keys   = []
        data_path   = Path(args.data_loc)
        get_dataset = lambda path, p: Path(path.as_posix().replace(f'/{sensor}', '').replace(f'/{p}.csv', '')).stem

        for product in products:
            if product in ['chl', 'tss', 'cdom', 'pc']:
                product = f'../{product}'

            # Find all datasets which contain the given product for this sensor
            safe_prod = product.replace('*', '[*]')
            datasets  = [get_dataset(path, product) for path in data_path.glob(f'*/{sensor}/{safe_prod}.csv')]
            datasets  = [d for d in datasets if d not in ['PACE']]

            if getattr(args, 'subset', ''):
                datasets = [d for d in datasets if d in args.subset.split(',')]

            data_folder += datasets
            data_keys   += [product]

    # Get unique entries, while maintaining insertion order
    order_unique = lambda a: [a[i] for i in sorted(np.unique(a, return_index=True)[1])]
    data_folder  = order_unique(data_folder)
    data_keys    = order_unique(data_keys)
    assert(len(data_folder)), f'No datasets found for {products} within {data_path}/*/{sensor}'
    assert(len(data_keys)),   f'No variables found for {products} within {data_path}/*/{sensor}'

    sensor_loc = [data_path.joinpath(f, sensor) for f in data_folder]
    x_data, y_data, slices, sources = _load_datasets(data_keys, sensor_loc, bands,
        allow_missing=('-nan' in args.sensor) or (getattr(args, 'align', None) is not None))

    # Scale the CDOM (ag) values by a constant factor
    if 'cdom' in slices:
        y_data[:, slices['cdom']] *= 0.18

    # Allow data from one sensor to be aligned with other sensors, so samples match across sensors
    if getattr(args, 'align', None) is not None:
        assert('-nan' not in args.sensor), 'Cannot allow all samples via "-nan" while also aligning to other sensors'
        align = args.align.split(',')

        if 'all' in align:
            align = [s for s in SENSOR_LABELS.keys() if s != 'HYPER']
        align_loc = [[data_path.joinpath(f, a.split('-')[0]) for f in data_folder] for a in align]

        print(f'\nLoading alignment data for {align}...')
        x_align, y_align, slices_align, sources_align = map(list,
            zip(*[_load_datasets(data_keys, loc, get_sensor_bands(a, args), allow_missing=True)
                  for a, loc in zip(align, align_loc)]))

        x_data = [x_data] + x_align
        y_data = [y_data] + y_align

    # Filter phycocyanin values which are outside a plausible range
    if 'pc' in slices:
        above = y_data[..., slices['pc']].flatten() > 1000
        below = y_data[..., slices['pc']].flatten() < 0.1
        y_data[above | below, slices['pc']] = np.nan

    # Filter samples with invalid values
    if '-nan' not in args.sensor:
        (x_data, *_), (y_data, *_), (sources, *_) = _filter_invalid(x_data, y_data, slices, other=[sources],
            allow_nan_out=not using_feature(args, 'sim') and len(data_keys) > 2)

    # Correct TChl to chl, if requested (flag name assumed here)
    if using_feature(args, 'tchlfix'):
        assert(not using_feature(args, 'sim')), 'Simulated data does not need TChl correction'
        y_data = _fix_tchl(y_data, sources, slices, data_path)

    print('\nFinal counts:')
    print('\n'.join([f'\tN={num:>5} | {loc}' for loc, num in zip(*np.unique(sources[:, 0], return_counts=True))]))
    print(f'\tTotal: {len(sources)}')
    return x_data, y_data, slices, sources
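
# Example usage (a minimal sketch; the sensor / product values are illustrative):
#   args = get_args(sensor='OLI', product='chl')
#   x_data, y_data, slices, sources = get_data(args)
#   chl = y_data[:, slices['chl']]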

def _fix_tchl(y_data, sources, slices, data_path, debug=False):
    ''' Very roughly correct chl for pheopigments '''
    old = y_data.copy()
    dataset_name, sample_idx = sources.T
    sample_idx = sample_idx.astype(int)

    fix = np.ones(len(y_data)).astype(bool)

    # Only certain subsets of the Sundar dataset should be corrected
    set_idx = np.where(dataset_name == 'Sundar')[0]
    dataset = np.loadtxt(data_path.joinpath('Sundar', 'Dataset.csv'), delimiter=',', dtype=str)[sample_idx[set_idx]]
    fix[set_idx[dataset == 'ACIX_Krista']] = False
    fix[set_idx[dataset == 'ACIX_Moritz']] = False

    # SeaBASS2 samples from certain locations are excluded from the correction
    set_idx = np.where(dataset_name == 'SeaBASS2')[0]
    meta    = pd.read_csv(data_path.joinpath('SeaBASS2', 'meta.csv')).iloc[sample_idx[set_idx]]
    lonlats = meta[['east_longitude', 'west_longitude', 'north_latitude', 'south_latitude']].apply(lambda v: v.apply(lambda v2: v2.split('||')[0]))
    lonlats = lonlats.apply(lambda v: pd.to_numeric(v.apply(lambda v2: v2.split('::')[1].replace('[deg]', '')), 'coerce'))
    lonlats = lonlats[['east_longitude', 'north_latitude']].to_numpy()

    fix[set_idx[np.logical_and(lonlats[:, 0] < -117, lonlats[:, 1] > 32)]] = False
    fix[y_data[:, 0] > 80] = False
    print(f'Correcting {fix.sum()} / {len(fix)} samples')

    # Polynomial correction coefficients, applied to powers of the original chl value
    coef = [0.04, 0.776, 0.015, -0.00046, 0.000004]
    y_data[fix, slices['chl']] = np.sum(np.array(coef) * y_data[fix, slices['chl']] ** np.arange(len(coef)), 1, keepdims=True)

    if debug:
        import matplotlib.pyplot as plt
        from .plot_utils import add_identity
        plt.scatter(old, y_data)
        plt.xscale('log')
        plt.yscale('log')
        add_identity(plt.gca(), color='k', ls='--')
        plt.xlim((y_data[y_data > 0].min() / 10, y_data.max() * 10))
        plt.ylim((y_data[y_data > 0].min() / 10, y_data.max() * 10))
        plt.show()

    return y_data