OB.DAAC Logo
NASA Logo
Ocean Color Science Software

ocssw V2022
create_images.py
Go to the documentation of this file.
1 '''
2 Mixture Density Network (MDN) ocean color water quality product retrieval
3 
4 MDN is a machine learning algorithm, trained to use remote sensing reflectance (Rrs)
5 to estimate various water quality products. This package includes the model which
6 retrieves Chlorophyll-a (chl), Total Suspended Solids (tss), Colored Dissolved Organic
7 Matter at 440nm (cdom), and Phycocyanin Concentration (pc) from HICO imagery.
8 
9 To utilize this package, activate the provided virtual environment and call the script:
10  $ source venv/Scripts/activate
11  $ python create_images.py
12 
13 Code base can additionally be found at:
14  https://github.com/BrandonSmithJ/MDN
15  https://github.com/STREAM-RS/STREAM-RS
16 
17 Brandon Smith, NASA Goddard Space Flight Center, October 2021
18 '''
19 
20 from netCDF4 import Dataset
21 from pathlib import Path
22 
23 from MDN.parameters import get_args
24 from MDN.utils import get_sensor_bands, closest_wavelength
25 from MDN import image_estimates
26 
27 import matplotlib.colors as colors
28 import matplotlib.pyplot as plt
29 import numpy as np
30 import time
31 
32 
def gamma_stretch(data, gamma=2):
    ''' Apply gamma stretching to brighten imagery.

    Parameters
    ----------
    data  : numpy array of floats; the 255 scaling below assumes the
            values are reflectance-like, i.e. within [0, 1].
    gamma : stretch factor; every value is raised to the power 1/gamma,
            so the default of 2 takes the square root.

    Returns
    -------
    Array with the same shape as `data`, as uint8 within [0, 255].
    '''
    # Honor the requested gamma instead of a hard-coded square root
    stretched = 255. * data ** (1. / gamma)

    # NaN (e.g. invalid pixels) and out-of-range values have no defined
    # uint8 representation; map NaN to 0 and clip so the cast cannot wrap.
    stretched = np.clip(np.nan_to_num(stretched, nan=0.), 0, 255)
    return stretched.astype(np.uint8)
36 
37 
38 
def extract_data(image, avail_bands, req_bands, allow_neg=False, key='Rrs'):
    ''' Pull the requested bands out of a NetCDF object.

    Each entry of `req_bands` is matched against `avail_bands` through
    closest_wavelength, the variable named '<key>_<band>' is read from
    `image`, and the results are stacked along a new trailing axis.
    Unless `allow_neg` is set, values <= 0 are replaced with nan; masked
    values are always returned as nan.
    '''
    # Resolve every requested wavelength to a band present in the image
    matched = [closest_wavelength(wavelength, avail_bands) for wavelength in req_bands]

    # Read each band in full and stack them along the last axis
    layers = []
    for wavelength in matched:
        layers.append(image[f'{key}_{wavelength}'][:])
    stacked = np.ma.stack(layers, axis=-1)

    # Non-positive values become nan when negatives are disallowed
    if not allow_neg:
        nonpositive = stacked <= 0
        stacked[nonpositive] = np.nan

    # Masked entries are filled with nan on the way out
    return stacked.filled(fill_value=np.nan)
55 
56 
57 
def plot_product(ax, title, product, rgb, vmin, vmax):
    ''' Plot a given product on the axis using vmin/vmax as the
        colorbar min/max, and rgb as the visible background.

    Parameters
    ----------
    ax      : matplotlib axis to draw on.
    title   : product name; shown upper-cased as the axis title.
    product : product estimates; singleton dimensions are squeezed
              away before plotting.
    rgb     : background imagery, gamma stretched before display.
    vmin    : lower bound of the log-scaled colorbar.
    vmax    : upper bound of the log-scaled colorbar.
    '''
    # Background layer
    ax.imshow(gamma_stretch(rgb))
    ax.axis('off')

    # Use the `title` argument; relying on a module-level variable named
    # `key` raises NameError whenever this function is called from
    # anywhere that does not happen to define one.
    ax.set_title(title.upper())

    # Product layer, log-normalized between the requested bounds
    norm = colors.LogNorm(vmin=vmin, vmax=vmax)
    img = ax.imshow(np.squeeze(product), norm=norm, cmap='turbo')
    plt.colorbar(img, ax=ax)
68 
69 
70 
71 
if __name__ == '__main__':
    sensor = 'HICO'
    kwargs = {
        'sensor'       : sensor,
        'product'      : 'chl,tss,cdom,pc',
        'sat_bands'    : True,
        'use_ratio'    : True,
        'use_excl_Rrs' : True,
    }

    # Load the bands required for the given sensor
    req_bands = get_sensor_bands(sensor, get_args(**kwargs))
    rgb_bands = [660, 550, 440]

    # Colorbar (min, max) per product; identical for every image, so built once
    bounds = {
        'chl' : (1,   100),
        'tss' : (1,   100),
        'pc'  : (1,   100),
        'cdom': (0.1, 10),
    }

    for location in Path(f'{sensor}-imagery').glob('*'):
        time_start = time.time()

        # Skip stray entries (loose files, incomplete folders) rather than
        # aborting the whole run on the first one without an l2gen.nc
        filename = location.joinpath('l2gen.nc')
        if not filename.is_file():
            print(f'Skipping {location}: no l2gen.nc found')
            continue

        # Load HICO data, using rhos as the visible background. The context
        # manager closes the NetCDF file once the arrays have been read.
        with Dataset(filename) as dataset:
            image = dataset['geophysical_data']
            bands = sorted(int(k.replace('Rrs_', '')) for k in image.variables if 'Rrs_' in k)
            Rrs   = extract_data(image, bands, req_bands)
            rgb   = extract_data(image, bands, rgb_bands, key='rhos')

        # Generate product estimates - 'slices' contains the index of each product within 'products'
        products, slices = image_estimates(Rrs, **kwargs)

        # Create plot for each product, bounding the colorbar per product
        f, axes = plt.subplots(1, len(slices), figsize=(4*len(slices), 8))
        axes    = np.atleast_1d(axes)
        for i, (key, idx) in enumerate(slices.items()):
            plot_product(axes[i], key, products[..., idx], rgb, *bounds[key])
        plt.tight_layout()
        plt.savefig(f'{sensor}_{location.stem}.png')

        # Close (not just clear) the figure; plt.subplots opens a new one
        # every iteration, so cleared figures would otherwise pile up
        plt.close(f)

        print(f'Generated {sensor}_{location.stem}.png in {time.time()-time_start:.1f} seconds')
113 
def get_sensor_bands(sensor, args=None)
Definition: meta.py:114
def image_estimates(data, sensor=None, function=apply_model, rhos=False, anc=False, **kwargs)
def extract_data(image, avail_bands, req_bands, allow_neg=False, key='Rrs')
def gamma_stretch(data, gamma=2)
def get_args(kwargs={}, use_cmdline=True, **kwargs2)
Definition: parameters.py:100
def closest_wavelength(k, waves, validate=True, tol=5, squeeze=False)
Definition: utils.py:24
def plot_product(ax, title, product, rgb, vmin, vmax)