Go to the documentation of this file.
Mixture Density Network (MDN) ocean color water quality product retrieval
MDN is a machine learning algorithm trained to estimate various water quality
products from remote sensing reflectance (Rrs). This package includes the model
that retrieves Chlorophyll-a (chl), Total Suspended Solids (tss), Colored
Dissolved Organic Matter at 440 nm (cdom), and Phycocyanin concentration (pc)
from HICO imagery.
To use this package, activate the provided virtual environment and run the script:
$ source venv/Scripts/activate
$ python create_images.py
The code base is also available at:
https://github.com/BrandonSmithJ/MDN
https://github.com/STREAM-RS/STREAM-RS
Brandon Smith, NASA Goddard Space Flight Center, October 2021
20 from netCDF4
import Dataset
21 from pathlib
import Path
24 from MDN.utils import get_sensor_bands, closest_wavelength
25 from MDN
import image_estimates
27 import matplotlib.colors
as colors
28 import matplotlib.pyplot
as plt
def gamma_stretch(data, gamma=2):
    ''' Apply gamma stretching to brighten imagery.

    Parameters
    ----------
    data : array-like of float
        Image values, expected to lie in [0, 1]. Values outside that range
        are clipped to [0, 1] before stretching.
    gamma : float, optional
        Stretch factor. Each value is raised to the power ``1 / gamma``, so
        the default of 2 gives a square-root stretch.

    Returns
    -------
    numpy.ndarray of uint8
        The stretched values scaled to [0, 255].
    '''
    # Clip first: negative values would become NaN under the fractional power,
    # and values above 1 would overflow the uint8 cast.
    stretched = np.clip(data, 0., 1.) ** (1. / gamma)
    return (255. * stretched).astype(np.uint8)
39 def extract_data(image, avail_bands, req_bands, allow_neg=False, key='Rrs'):
40 ''' Extract the requested bands from a given NetCDF object '''
# Parameters (as used on the visible lines):
#   image       - indexable by variable name; entries are looked up as
#                 f'{key}_{band}' and read in full with [:]. Presumably the
#                 'geophysical_data' group opened in __main__ (the call site is
#                 not visible in this listing).
#   avail_bands - band wavelengths present in `image`. Not referenced on any
#                 visible line; presumably used by omitted source line 43 (confirm).
#   req_bands   - band wavelengths to extract.
#   allow_neg   - not referenced on any visible line; presumably gates the
#                 "<= 0 -> NaN" step below (see the NOTE there).
#   key         - variable-name prefix (default 'Rrs').
# Returns: a plain ndarray with the bands stacked along the last axis and masked
#          values replaced by NaN.
42 def extract(requested):
# Stack each band's variable along a new trailing axis, keeping masks (np.ma).
# NOTE(review): `bands` is not defined on any visible line (source line 43 is
# omitted from this listing). Presumably it maps `requested` onto entries of
# `avail_bands`, e.g. via closest_wavelength (confirm).
44 return np.ma.stack([image[f
'{key}_{band}'][:]
for band
in bands], axis=-1)
47 extracted = extract(req_bands)
# Replace non-positive values with NaN.
# NOTE(review): as listed, this line is unconditional, yet `allow_neg` is otherwise
# unused. The guarding condition (presumably `if not allow_neg:`) is likely on an
# omitted line (source lines 48-50); confirm.
51 extracted[extracted <= 0] = np.nan
# filled() converts the masked array into a regular ndarray and writes NaN into
# every masked position.
54 return extracted.filled(fill_value=np.nan)
59 ''' Plot a given product on the axis using vmin/vmax as the
60 colorbar min/max, and rgb as the visible background '''
# NOTE(review): the `def` line (source line 58) is omitted from this listing. The
# member list gives the signature as plot_product(ax, title, product, rgb, vmin, vmax).
# NOTE(review): `rgb` is not referenced on any visible line. The background drawing
# described in the docstring is presumably on omitted source lines 61-62 (confirm).
# BUG(review): `key` is not a parameter of this function; the title parameter is
# named `title`. This only works when run as a script, because the __main__ loop
# binds a module-global `key` (the same value it passes as `title`). When this
# file is imported instead, `key` is undefined and this line raises NameError.
# It should read `ax.set_title(title.upper())`.
63 ax.set_title(key.upper())
# Map product values onto the colormap on a logarithmic scale between vmin and vmax.
65 norm = colors.LogNorm(vmin=vmin, vmax=vmax)
# np.squeeze drops length-1 axes (presumably a trailing product axis left by the
# caller's slicing) so imshow receives a 2-D array; confirm.
66 img = ax.imshow(np.squeeze(product), norm=norm, cmap=
'turbo')
# Colorbar for the product layer, attached to this axis.
67 plt.colorbar(img, ax=ax)
72 if __name__ ==
'__main__':
# Script entry point. For every entry in '{sensor}-imagery' it:
#   - reads '<entry>/l2gen.nc',
#   - plots the estimated products,
#   - saves the figure as '{sensor}_{entry stem}.png'.
# `sensor`, the estimation call, `slices`, `bounds`, `products` and `rgb` are
# defined on lines omitted from this listing.
# Options dict (its opening line is omitted); only two entries are visible here.
76 'product' :
'chl,tss,cdom,pc',
79 'use_excl_Rrs' :
True,
# Wavelengths used for the RGB background, in red/green/blue order.
# Units are presumably nm, matching the 'Rrs_<wavelength>' variable names (confirm).
84 rgb_bands = [660, 550, 440]
86 for location
in Path(f
'{sensor}-imagery').glob(
'*'):
87 time_start = time.time()
# NOTE(review): no close() of this Dataset is visible, and one file is opened per
# location. Consider `with Dataset(...) as ds:` so each file handle is released.
90 image = Dataset(location.joinpath(
'l2gen.nc'))[
'geophysical_data']
# Available Rrs wavelengths, parsed from variable names like 'Rrs_443' and sorted
# ascending. NOTE(review): the filter is a substring test, so a variable whose name
# contains 'Rrs_' but has a non-integer remainder makes int() raise ValueError.
91 bands = sorted([
int(k.replace(
'Rrs_',
''))
for k
in image.variables.keys()
if 'Rrs_' in k])
# One subplot per product slice.
99 f, axes = plt.subplots(1, len(slices), figsize=(4*len(slices), 8))
# When len(slices) == 1, plt.subplots returns a single Axes rather than an array;
# np.atleast_1d keeps the [i] indexing valid in both cases.
# *bounds[key] supplies the (vmin, vmax) arguments of plot_product.
106 for i, (key, idx)
in enumerate(slices.items()):
107 plot_product(np.atleast_1d(axes)[i], key, products[..., idx], rgb, *bounds[key])
# NOTE(review): no plt.close(f) is visible after saving, so figures accumulate across
# loop iterations (matplotlib warns once more than 20 are open). Confirm whether an
# omitted line closes the figure.
109 plt.savefig(f
'{sensor}_{location.stem}.png')
112 print(f
'Generated {sensor}_{location.stem}.png in {time.time()-time_start:.1f} seconds')
def image_estimates(data, sensor=None, function=apply_model, rhos=False, anc=False, **kwargs)
def extract_data(image, avail_bands, req_bands, allow_neg=False, key='Rrs')
def gamma_stretch(data, gamma=2)
def get_args(kwargs={}, use_cmdline=True, **kwargs2)
def closest_wavelength(k, waves, validate=True, tol=5, squeeze=False)
def plot_product(ax, title, product, rgb, vmin, vmax)