OB.DAAC Logo
NASA Logo
Ocean Color Science Software

ocssw V2022
utils.py
Go to the documentation of this file.
1 from scipy.optimize import minimize
2 from functools import partial, update_wrapper, wraps
3 from importlib import import_module
4 from pathlib import Path
5 import numpy as np
6 import warnings, pkgutil
7 
8 
def loadtxt(filename, delimiter=','):
    ''' Load a delimited numeric text file, resolving `filename` relative to
    the directory containing this module (an absolute `filename` is used as-is,
    per pathlib join semantics). Returns the array produced by np.loadtxt. '''
    path = Path(__file__).parent / filename
    return np.loadtxt(path, delimiter=delimiter)
13 
14 
def find_wavelength(k, waves, validate=True, tol=5, squeeze=False):
    ''' Index of closest wavelength.

    Parameters:
        k        - a wavelength, or a 1-D sequence of wavelengths, to look up
        waves    - the available wavelengths
        validate - if True, raise AssertionError when any requested wavelength
                   is not within `tol` of its closest available wavelength
        tol      - maximum allowed distance used when validating
        squeeze  - if True, reshape the result to the shape of `k` (a scalar `k`
                   then yields a 0-d index array); otherwise a 1-D array is returned

    Returns an integer array of indices into `waves`.
    '''
    waves = np.array(waves)
    w = np.atleast_1d(k)
    # Distance from every requested wavelength (rows) to every available one (columns)
    i = np.abs(waves - w[:, None]).argmin(-1)

    if validate:
        diff = np.abs(w - waves[i])
        # An empty request has nothing to validate (max() of an empty array raises).
        # `not (max <= tol)` rather than `max > tol` so NaN distances are still rejected.
        # Raised explicitly instead of via `assert` so the check survives `python -O`,
        # while keeping the AssertionError type existing callers may rely on.
        if diff.size and not (diff.max() <= tol):
            raise AssertionError(f'Needed {k}, but closest was {waves[i]} in {waves} ({diff.max()} > {tol})')
    return i.reshape(np.array(k).shape) if squeeze else i
22 
23 
def closest_wavelength(k, waves, validate=True, tol=5, squeeze=False):
    ''' Value of closest wavelength: the entries of `waves` selected by
    find_wavelength for the requested wavelength(s) `k`, using the same
    `validate`, `tol` and `squeeze` options. '''
    available = np.array(waves)
    index = find_wavelength(k, available, validate, tol, squeeze)
    return available[index]
28 
29 
def has_band(w, waves, tol=5):
    ''' Ensure band exists within <tol> nm: compare `w` against its closest
    entry in `waves` (looked up without validation) and return the boolean
    result of distance <= tol. '''
    nearest = closest_wavelength(w, np.array(waves), validate=False)
    distance = np.abs(w - nearest)
    return distance <= tol
33 
34 
def to_rrs(Rrs):
    ''' Conversion to subsurface reflectance (Lee et al. 2002):
    rrs = Rrs / (0.52 + 1.7 * Rrs). Works element-wise on arrays. '''
    denominator = 0.52 + 1.7 * Rrs
    return Rrs / denominator
38 
39 
def to_Rrs(rrs):
    ''' Inverse of to_rrs - conversion from subsurface to remote sensing
    reflectance: Rrs = 0.52 * rrs / (1 - 1.7 * rrs). Works element-wise on arrays. '''
    numerator = rrs * 0.52
    denominator = 1 - rrs * 1.7
    return numerator / denominator
43 
44 
def get_required(Rrs, waves, required=(), tol=5, squeeze=False):
    '''
    Checks that all required wavelengths are available in the given data.
    Returns an object which acts as a functional interface into the Rrs data,
    allowing a wavelength or set of wavelengths to be returned:
        Rrs = get_required(Rrs, ...)
        Rrs(443)        # Returns a matrix containing the band data closest to 443nm (shape [N, 1])
        Rrs([440, 740]) # Returns a matrix containing the band data closest to 440nm, and to 740nm (shape [N, 2])
        Rrs(None)       # Returns the full (at least 2-D) Rrs matrix

    Parameters:
        Rrs      - reflectance data; its last axis must have one entry per wavelength
        waves    - wavelengths corresponding to the last axis of Rrs
        required - wavelengths which must each be within `tol` of some entry in `waves`
        tol      - tolerance used both for the required-band check and for lookups
        squeeze  - forwarded to find_wavelength for every lookup

    Raises AssertionError on a shape mismatch or a missing required band.
    '''
    waves = np.array(waves)
    Rrs = np.atleast_2d(Rrs)

    # Explicit raises (instead of `assert`) keep these checks active under `python -O`
    # while preserving the AssertionError type and messages.
    if Rrs.shape[-1] != len(waves):
        raise AssertionError(f'Shape mismatch: Rrs={Rrs.shape}, wavelengths={len(waves)}')
    if not all(has_band(w, waves, tol) for w in required):
        raise AssertionError(f'At least one of {required} is missing from {waves}')

    def select(w, validate=True):
        # None returns the whole matrix; otherwise index the bands closest to w
        if w is None:
            return Rrs
        return Rrs[..., find_wavelength(w, waves, tol=tol, validate=validate, squeeze=squeeze)]
    return select
61 
62 
def get_benchmark_models(products, allow_opt=False, debug=False, method=None):
    '''
    Return all benchmark models within the product directory, as well as those
    within the 'multiple' directory. Note that this means some models returned
    will not be applicable to the given product(s), and will need to be filtered.
    If allow_opt=True, models requiring optimization will also be returned.

    Parameters:
        products  - product name, or list of product names; each must have a
                    sub-directory next to this module
        allow_opt - also return models without a truthy `has_default` attribute
        debug     - print the models skipped because they require optimization
        method    - if given, return only the model registered under this name

    Returns a dict mapping model name (the model's `model_name` attribute, else
    its folder name) to the model callable.

    Raises AssertionError if a product directory is missing or `method` is unknown.
    '''
    products = list(np.atleast_1d(products))
    # Loop-invariant locations: the directory holding this module, and its parent package name
    benchmark_dir = Path(__file__).parent.resolve()
    package = Path(__file__).parent.parent.stem
    models = {}

    for product in products + ['multiple']:
        product_dir = benchmark_dir.joinpath(Path(product).stem)
        if not product_dir.exists():
            raise AssertionError(f'No directory exists for the product "{product}" within {benchmark_dir}')

        # Iterate over all benchmark algorithm folders in the appropriate product directory
        for (_, folder, is_folder) in pkgutil.iter_modules([product_dir]):
            if not is_folder or folder[0] == '_':
                continue

            imported = import_module(f'{package}.{benchmark_dir.stem}.{product_dir.stem}.{folder}.model')
            for function in dir(imported):

                # Check all functions which have "model" in their name
                if 'model' not in function:
                    continue
                model = getattr(imported, function)
                # Separate variable, so one model's `model_name` can't leak into the
                # fallback name used for the next model found in the same folder
                model_name = getattr(model, 'model_name', folder)

                # Return models which have default parameters, or all if allowing optimization
                if getattr(model, 'has_default', False) or allow_opt:

                    # Within the 'multiple' directory, ensure model outputs contain a requested product;
                    # attributes without `_output_keys` (not decorated by set_outputs) are skipped
                    output_keys = getattr(model, '_output_keys', ())
                    if product != 'multiple' or any(p in output_keys for p in products):
                        model.__name__ = model.__dict__['__name__'] = model_name
                        models[model_name] = model

                elif debug:
                    print(f'{model_name} requires optimization')

    if not (method is None or method in models):
        raise AssertionError(f'Unknown algorithm "{method}". Options are: \n{list(models.keys())}')
    return models if method is None else {method: models[method]}
102 
103 
105  ''' Context manager to temporarily set the global random state for
106  any methods which aren't using a seed or local random state. '''
107 
108  def __init__(self, seed=None):
109  self.seed = seed
110  self.state = None
111 
112  def __enter__(self):
113  self.state = np.random.get_state()
114  np.random.seed(self.seed)
115 
116  def __exit__(self, *args, **kwargs):
117  np.random.set_state(self.state)
118 
119 
class Optimizer:
    ''' Allow benchmark function parameters to be optimized via a set of training data.

    Calling the instance forwards to the wrapped function with all warnings
    suppressed. `fit` searches for values of the `opt_vars` keyword parameters,
    after which `predict` calls the function with those values (and the
    training wavelengths) bound. Before `fit`, `predict` calls the function as-is.
    '''

    def __init__(self, function, opt_vars, has_default):
        self.function = self.trained_function = function
        self.opt_vars = opt_vars        # names of the keyword parameters to optimize
        self.has_default = has_default  # read by get_benchmark_models to decide inclusion

    def __call__(self, *args, **kwargs):
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore')
            return self.function(*args, **kwargs)

    def fit(self, X, Y, wavelengths, *, bounds=(1e-2, 100), n_calls=10000, n_random_starts=10000):
        ''' Optimize `opt_vars` against training inputs X and targets Y.

        bounds          - (low, high) search range applied to every optimized variable
        n_calls         - total number of cost evaluations passed to gbrt_minimize
        n_random_starts - number of random initial evaluations passed to gbrt_minimize
        '''
        def cost_func(guess):
            assert(np.all(np.isfinite(guess))), guess
            params = dict(zip(self.opt_vars, guess))
            # Median absolute relative error, ignoring NaNs
            return np.nanmedian(np.abs((self(X, wavelengths, **params) - Y) / Y))

        # Third-party dependency, only needed when optimizing
        from skopt import gbrt_minimize
        dimensions = [tuple(bounds)] * len(self.opt_vars)
        res = gbrt_minimize(cost_func, dimensions, n_random_starts=n_random_starts, n_calls=n_calls)

        # __name__ is only present when wrapped via update_wrapper (see optimize()),
        # so fall back to the wrapped function's name
        label = getattr(self, '__name__', getattr(self.function, '__name__', repr(self.function)))
        print(label, res.x, res.fun)
        self.trained_function = partial(self.function, wavelengths=wavelengths, **dict(zip(self.opt_vars, res.x)))

    def predict(self, *args, **kwargs):
        ''' Call the fitted function (or the original one if fit was never called). '''
        return self.trained_function(*args, **kwargs)
149 
150 
def optimize(opt_vars, has_default=True):
    ''' Decorator factory: wrap a model function in an Optimizer that can later
    fit the keyword parameters named in `opt_vars` to training data. The wrapper
    copies the wrapped function's metadata (name, docstring, ...) via
    update_wrapper, and `has_default` records whether the function is usable
    without optimization.
    '''
    def decorator(function):
        wrapper = Optimizer(function, opt_vars, has_default)
        return update_wrapper(wrapper, function)
    return decorator
162 
163 
def set_outputs(output_keys):
    ''' Decorator for models within the 'multiple' folder, whose output is a
    dictionary. `output_keys` (the products available in that dict) is stored
    on the decorated function as `_output_keys`, so it can be checked without
    running the model. If the model is called with a `product` keyword argument
    naming a key of the output, only that entry is returned; otherwise the whole
    dictionary is returned. The `product` keyword is still forwarded to the model.
    '''
    def decorator(function):

        @wraps(function)
        def wrapped(*args, **kwargs):
            result = function(*args, **kwargs)
            requested = kwargs.get('product')
            # Unknown or absent product -> the full output dictionary
            return result.get(requested, result)

        wrapped._output_keys = output_keys
        return wrapped
    return decorator
186 
187 
188 # def set_name(name, extra_kws={}):
189 # ''' Set the model name to be different than the containing folder '''
190 # def function_wrapper(function):
191 # function.model_name = name
192 # if hasattr(function, 'function'):
193 # new_function = partial(function.function, **extra_kws)
194 # function.function = update_wrapper(new_function, function.function)
195 # else:
196 # new_function = partial(function, **extra_kws)
197 # function = update_wrapper(new_function, function)
198 # return function
199 # return function_wrapper
def loadtxt(filename, delimiter=',')
Definition: utils.py:9
def has_band(w, waves, tol=5)
Definition: utils.py:30
list(APPEND LIBS ${PGSTK_LIBRARIES}) add_executable(atteph_info_modis atteph_info_modis.c) target_link_libraries(atteph_info_modis $
Definition: CMakeLists.txt:7
def get_benchmark_models(products, allow_opt=False, debug=False, method=None)
Definition: utils.py:63
def __exit__(self, *args, **kwargs)
Definition: utils.py:116
def __init__(self, seed=None)
Definition: utils.py:108
def __init__(self, function, opt_vars, has_default)
Definition: utils.py:123
def optimize(opt_vars, has_default=True)
Definition: utils.py:151
def get_required(Rrs, waves, required=[], tol=5, squeeze=False)
Definition: utils.py:45
def predict(self, *args, **kwargs)
Definition: utils.py:147
def closest_wavelength(k, waves, validate=True, tol=5, squeeze=False)
Definition: utils.py:24
def to_rrs(Rrs)
Definition: utils.py:35
def __call__(self, *args, **kwargs)
Definition: utils.py:128
def fit(self, X, Y, wavelengths)
Definition: utils.py:133
def set_outputs(output_keys)
Definition: utils.py:164
def find_wavelength(k, waves, validate=True, tol=5, squeeze=False)
Definition: utils.py:15
def to_Rrs(rrs)
Definition: utils.py:40