from scipy.optimize import minimize
from functools import partial, update_wrapper, wraps
from importlib import import_module
from pathlib import Path

import numpy as np
import warnings, pkgutil

def loadtxt(filename, delimiter=','):
    ''' Load a text file located relative to this directory '''
    root_dir = Path(__file__).parent
    filename = root_dir.joinpath(filename)
    return np.loadtxt(filename, delimiter=delimiter)

def find_wavelength(k, waves, validate=True, tol=5, squeeze=False):
    ''' Index of closest wavelength '''
    waves = np.array(waves)
    w = np.atleast_1d(k)
    i = np.abs(waves - w[:, None]).argmin(-1)
    assert(not validate or (np.abs(w - waves[i]).max() <= tol)), \
        f'Needed {k}, but closest was {waves[i]} in {waves} ({np.abs(w - waves[i]).max()} > {tol})'
    return i.reshape(np.array(k).shape) if squeeze else i

def closest_wavelength(k, waves, validate=True, tol=5, squeeze=False):
    ''' Value of closest wavelength '''
    waves = np.array(waves)
    return waves[find_wavelength(k, waves, validate=validate, tol=tol, squeeze=squeeze)]

def has_band(w, waves, tol=5):
    ''' Ensure band exists within <tol> nm '''
    return np.abs(w - closest_wavelength(w, waves, validate=False, tol=tol)) <= tol
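
# Example (sketch): locating bands by wavelength with the helpers above; the
# band centers used here are purely illustrative.
def _example_band_lookup():
    waves = [412, 443, 490, 560, 665]
    assert has_band(440, waves, tol=5)        # 443nm lies within 5nm of 440nm
    index = find_wavelength(440, waves)       # -> array([1])
    value = closest_wavelength(440, waves)    # -> array([443])
    return index, value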

def to_rrs(Rrs):
    ''' Conversion to subsurface reflectance (Lee et al. 2002) '''
    return Rrs / (0.52 + 1.7 * Rrs)

def to_Rrs(rrs):
    ''' Inverse of to_rrs - conversion from subsurface to remote sensing reflectance '''
    return (rrs * 0.52) / (1 - rrs * 1.7)
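
# Example (sketch): a quick numerical check that to_rrs and to_Rrs invert one
# another for typical above-surface reflectance magnitudes.
def _example_rrs_roundtrip():
    Rrs = np.array([0.001, 0.005, 0.02])     # remote sensing reflectance (1/sr)
    rrs = to_rrs(Rrs)                        # above- to below-surface conversion
    assert np.allclose(to_Rrs(rrs), Rrs)     # round trip recovers the input
    return rrs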

def get_required(Rrs, waves, required=[], tol=5, squeeze=False):
    '''
    Checks that all required wavelengths are available in the given data.
    Returns an object which acts as a functional interface into the Rrs data,
    allowing a wavelength or set of wavelengths to be returned:
        Rrs = get_required(Rrs, ...)
        Rrs(443)        # Returns a matrix containing the band data closest to 443nm (shape [N, 1])
        Rrs([440, 740]) # Returns a matrix containing the band data closest to 440nm, and to 740nm (shape [N, 2])
    '''
    waves = np.array(waves)
    Rrs = np.atleast_2d(Rrs)
    assert(Rrs.shape[-1] == len(waves)), \
        f'Shape mismatch: Rrs={Rrs.shape}, wavelengths={len(waves)}'
    assert(all([has_band(w, waves, tol) for w in required])), \
        f'At least one of {required} is missing from {waves}'
    return lambda w, validate=True: Rrs[..., find_wavelength(w, waves, tol=tol, validate=validate, squeeze=squeeze)] if w is not None else Rrs
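
# Example (sketch): wrapping a small Rrs matrix so bands can be selected by
# wavelength; the band centers and random reflectance values are illustrative.
def _example_band_selection():
    waves = [412, 443, 490, 560, 665]
    Rrs = np.random.rand(10, len(waves)) * 0.01   # 10 samples, 5 bands
    band = get_required(Rrs, waves, required=[443, 560])
    blue = band(443)           # data closest to 443nm, shape [10, 1]
    pair = band([443, 665])    # data closest to 443nm and 665nm, shape [10, 2]
    return blue, pair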

def get_benchmark_models(products, allow_opt=False, method=None):
    '''
    Return all benchmark models within the product directory, as well as those
    within the 'multiple' directory. Note that this means some models returned
    will not be applicable to the given product(s), and will need to be filtered.
    If allow_opt=True, models requiring optimization will also be returned.
    '''
    products = list(np.atleast_1d(products))
    models = {}

    for product in products + ['multiple']:
        benchmark_dir = Path(__file__).parent.resolve()
        product_dir = benchmark_dir.joinpath(Path(product).stem)
        assert(product_dir.exists()), \
            f'No directory exists for the product "{product}" within {benchmark_dir}'

        for (_, name, is_folder) in pkgutil.iter_modules([product_dir]):
            if is_folder and name[0] != '_':

                # Import the model module for this benchmark algorithm
                module = Path(__file__).parent.parent.stem
                imported = import_module(f'{module}.{benchmark_dir.stem}.{product_dir.stem}.{name}.model')

                for function in dir(imported):
                    if 'model' in function:
                        model = getattr(imported, function)

                        # Keep only models with default parameters, unless optimization is allowed
                        if getattr(model, 'has_default', False) or allow_opt:

                            # Multi-product models must provide at least one of the requested products
                            if product != 'multiple' or any(p in model._output_keys for p in products):
                                model.__name__ = model.__dict__['__name__'] = name = getattr(model, 'model_name', name)
                                models[name] = model
                        else:
                            print(f'{name} requires optimization')

    assert(method is None or method in models), \
        f'Unknown algorithm "{method}". Options are: \n{list(models.keys())}'
    return models if method is None else {method: models[method]}
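
# Example (sketch): gathering benchmark models for a single product and running
# them; assumes a 'chl' product directory exists alongside this module, and that
# models accept (Rrs, wavelengths) as inputs.
def _example_run_benchmarks(Rrs, wavelengths):
    estimates = {}
    for name, model in get_benchmark_models('chl').items():
        output = model(Rrs, wavelengths)
        # Multi-product models return a dict; keep only the product of interest
        estimates[name] = output['chl'] if isinstance(output, dict) else output
    return estimates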

class GlobalRandomManager:
    ''' Context manager to temporarily set the global random state for
        any methods which aren't using a seed or local random state. '''

    def __init__(self, seed):
        self.seed = seed

    def __enter__(self):
        self.state = np.random.get_state()
        np.random.seed(self.seed)

    def __exit__(self, *args, **kwargs):
        np.random.set_state(self.state)
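
# Example (sketch): temporarily seed the global numpy RNG around code that does
# not accept its own seed; the previous state is restored on exit.
def _example_reproducible_shuffle(values, seed=42):
    values = np.array(values)
    with GlobalRandomManager(seed):
        np.random.shuffle(values)
    return values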

class Optimizer:
    ''' Allow benchmark function parameters to be optimized via a set of training data '''

    def __init__(self, function, opt_vars, has_default):
        self.function = function
        self.opt_vars = opt_vars
        self.has_default = has_default

    def __call__(self, *args, **kwargs):
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore')
            return self.function(*args, **kwargs)

    def fit(self, X, Y, wavelengths):
        def cost_func(guess):
            assert(np.all(np.isfinite(guess))), guess
            guess = dict(zip(self.opt_vars, guess))

            # Median absolute percentage error; the alternative costs below are unreachable
            return np.nanmedian(np.abs((self(X, wavelengths, **guess) - Y) / Y))
            return np.abs(np.nanmean(self(X, wavelengths, **guess) - Y))
            return ((self(X, wavelengths, **guess) - Y) ** 2).sum() ** 0.5

        from skopt import gbrt_minimize
        init = [(1e-2, 100)] * len(self.opt_vars)
        res = gbrt_minimize(cost_func, init, n_random_starts=10000, n_calls=10000)
        print(self.__name__, res.x, res.fun)

def optimize(opt_vars, has_default=True):
    ''' Can automatically optimize a function
        with a given set of variables, using the
        first set of data given. Then, return the
        optimized function as partially defined, using
        the optimal parameters
    '''
    def function_wrapper(function):
        return update_wrapper(Optimizer(function, opt_vars, has_default), function)
    return function_wrapper
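
# Example (sketch): a two-parameter band-ratio model decorated with @optimize so
# that Optimizer.fit can tune its coefficients; the band choices, coefficient
# names, and default values are illustrative.
@optimize(['a', 'b'])
def _example_ratio_model(Rrs, wavelengths, a=0.3, b=2.0):
    band = get_required(Rrs, wavelengths, required=[443, 560])
    return a * (band(560) / band(443)) ** b

# Fitting would then look like (requires scikit-optimize):
#   _example_ratio_model.fit(Rrs_train, target_train, wavelengths)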

def set_outputs(output_keys):
    ''' All models within the 'multiple' folder should be decorated with this,
        as the model output should be a dictionary. This decorator takes as input
        a list of products (the keys within the output dict) and makes them
        available to check, without needing to run the model beforehand. As well,
        models can take 'product' as a keyword argument, and will then return only
        that product's output.
    '''
    def function_wrapper(function):

        @wraps(function)
        def select_output(*args, **kwargs):
            output = function(*args, **kwargs)
            product = kwargs.get('product', None)
            return output.get(product, output)

        setattr(select_output, '_output_keys', output_keys)
        return select_output
    return function_wrapper
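
# Example (sketch): a multi-product model decorated with set_outputs, so its
# available products can be checked via _output_keys without running it; the
# product names and scaling factors are placeholders.
@set_outputs(['chl', 'tss'])
def _example_multi_model(Rrs, wavelengths, *args, **kwargs):
    band = get_required(Rrs, wavelengths)
    return {
        'chl': band(443) * 10,    # placeholder chlorophyll estimate
        'tss': band(665) * 100,   # placeholder suspended sediment estimate
    }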