import functools

import numpy as np
from scipy import stats

from .utils import ignore_warnings
def validate_shape(func):
    ''' Decorator to flatten all function input arrays, and ensure shapes are the same '''
    # NOTE(review): function name reconstructed from mangled source — confirm against callers
    @functools.wraps(func)
    def helper(*args, **kwargs):
        # Flatten array-like positional args; non-arrays (e.g. scalars) pass through untouched
        flat     = [a.flatten() if hasattr(a, 'flatten') else a for a in args]
        flat_shp = [a.shape for a in flat if hasattr(a, 'shape')]
        orig_shp = [a.shape for a in args if hasattr(a, 'shape')]
        # Every flattened array must share one shape; an empty list passes vacuously
        # (the generator never evaluates flat_shp[0] when flat_shp is empty)
        assert all(flat_shp[0] == s for s in flat_shp), f'Shapes mismatch in {func.__name__}: {orig_shp}'
        return func(*flat, **kwargs)
    return helper
def only_finite(func):
    ''' Decorator to remove samples which are nan in any input array '''
    @functools.wraps(func)
    def helper(*args, **kwargs):
        # Stack inputs row-wise so each column is one sample across all inputs
        stacked = np.vstack(args)
        # Keep only samples (columns) that are finite in every input
        valid = np.all(np.isfinite(stacked), 0)
        assert valid.sum(), f'No valid samples exist for {func.__name__} metric'
        return func(*stacked[:, valid], **kwargs)
    return helper
def only_positive(func):
    ''' Decorator to remove samples which are zero/negative in any input array '''
    @functools.wraps(func)
    def helper(*args, **kwargs):
        # Stack inputs row-wise so each column is one sample across all inputs
        stacked = np.vstack(args)
        # Keep only samples (columns) strictly positive in every input
        valid = np.all(stacked > 0, 0)
        assert valid.sum(), f'No valid samples exist for {func.__name__} metric'
        return func(*stacked[:, valid], **kwargs)
    return helper
def label(name):
    ''' Label a function to aid in printing '''
    # NOTE(review): body reconstructed from mangled source. The comment block
    # below states label() assigns the function __name__ used by the other
    # decorators' messages; presumably it also wraps with ignore_warnings
    # (the module's only use of that import) — confirm against the original.
    def wrapper(func):
        func.__name__ = name
        return ignore_warnings(func)
    return wrapper
53 When executing a function, decorator order starts with the
54 outermost decorator and works its way down the stack; e.g.
59 And then foo == dec1(dec2(bar)). So, foo will execute dec1,
60 then dec2, then the original function.
62 Below, in rmsle (for example), we have:
63 rmsle = only_finite( only_positive( label(rmsle) ) )
64 This means only_positive() will get the input arrays only
65 after only_finite() removes any nan samples. As well, both
66 only_positive() and only_finite() will have access to the
67 function __name__ assigned by label().
69 For all functions below, y=true and y_hat=estimate
def rmse(y, y_hat):
    ''' Root Mean Squared Error '''
    return np.mean((y - y_hat) ** 2) ** .5
def rmsle(y, y_hat):
    ''' Root Mean Squared Logarithmic Error '''
    # Requires strictly positive inputs (np.log); callers filter via only_positive
    return np.mean(np.abs(np.log(y) - np.log(y_hat)) ** 2) ** 0.5
def nrmse(y, y_hat):
    ''' Normalized Root Mean Squared Error '''
    # RMSE scaled by the mean of the true values
    return ((y - y_hat) ** 2).mean() ** .5 / y.mean()
def mae(y, y_hat):
    ''' Mean Absolute Error '''
    return np.mean(np.abs(y - y_hat))
def mape(y, y_hat):
    ''' Mean Absolute Percentage Error '''
    # Undefined for y == 0; callers filter via only_positive
    return 100 * np.mean(np.abs((y - y_hat) / y))
def leqz(y, y_hat=None):
    ''' Less than or equal to zero (y_hat) '''
    # Single-argument form counts non-positive entries of y itself
    if y_hat is None: y_hat = y
    return (y_hat <= 0).sum()
def leqznan(y, y_hat=None):
    ''' Less than or equal to zero, or NaN (y_hat) '''
    # Single-argument form counts non-positive/NaN entries of y itself
    if y_hat is None: y_hat = y
    return np.logical_or(np.isnan(y_hat), y_hat <= 0).sum()
def mdsa(y, y_hat):
    ''' Median Symmetric Accuracy '''
    # https://doi.org/10.1002/2017SW001669 — symmetric under y <-> y_hat swap
    return 100 * (np.exp(np.median(np.abs(np.log(y_hat / y)))) - 1)
def msa(y, y_hat):
    ''' Mean Symmetric Accuracy '''
    # Mean (rather than median) of absolute log-ratio errors
    return 100 * (np.exp(np.mean(np.abs(np.log(y_hat / y)))) - 1)
def sspb(y, y_hat):
    ''' Symmetric Signed Percentage Bias '''
    # Sign of the median log-ratio indicates over- (+) or under- (-) estimation
    M = np.median(np.log(y_hat / y))
    return 100 * np.sign(M) * (np.exp(np.abs(M)) - 1)
def bias(y, y_hat):
    ''' Mean Error (bias) '''
    # Positive when the estimate overshoots the truth on average
    return np.mean(y_hat - y)
def r_squared(y, y_hat):
    ''' Logarithmic R^2 '''
    # NOTE(review): return statement reconstructed from the docstring — confirm.
    # Regression is performed in log10 space.
    slope_, intercept_, r_value, p_value, std_err = stats.linregress(np.log10(y), np.log10(y_hat))
    return r_value ** 2
def slope(y, y_hat):
    ''' Logarithmic slope '''
    # Slope of the log10(y_hat) vs log10(y) regression line
    slope_, intercept_, r_value, p_value, std_err = stats.linregress(np.log10(y), np.log10(y_hat))
    return slope_
def intercept(y, y_hat):
    ''' Logarithmic intercept '''
    # Intercept of the log10(y_hat) vs log10(y) regression line
    slope_, intercept_, r_value, p_value, std_err = stats.linregress(np.log10(y), np.log10(y_hat))
    return intercept_
def mwr(y, y_hat, y_bench):
    '''
    Model Win Rate - Percent of samples in which model has a closer
    estimate than the benchmark.
    y: true, y_hat: model, y_bench: benchmark
    '''
    # Copy the estimates: the negative-value masking below must not
    # mutate the caller's arrays (the original wrote NaNs in place).
    y_hat   = np.array(y_hat, dtype=float, copy=True)
    y_bench = np.array(y_bench, dtype=float, copy=True)
    y_bench[y_bench < 0] = np.nan
    y_hat[y_hat < 0]     = np.nan
    # Samples where both estimates are usable are compared head-to-head
    valid = np.logical_and(np.isfinite(y_hat), np.isfinite(y_bench))
    diff1 = np.abs(y[valid] - y_hat[valid])
    diff2 = np.abs(y[valid] - y_bench[valid])
    # `wins` (renamed from `stats`, which shadowed the scipy.stats import):
    # model wins head-to-head, wins by default if only the benchmark is
    # invalid, and loses if the model estimate itself is invalid.
    wins = np.zeros(len(y))
    wins[valid] = diff1 < diff2
    wins[~np.isfinite(y_bench)] = 1
    wins[~np.isfinite(y_hat)]   = 0
    return wins.sum() / np.isfinite(y).sum()
def performance(key, y, y_hat, metrics=[mdsa, sspb, slope, msa, rmsle, mae, leqznan], csv=False):
    ''' Return a string containing performance using various metrics.
        y should be the true value, y_hat the estimated value. '''
    # NOTE: the mutable default `metrics` is never modified, so sharing one
    # list object across calls is safe; kept for interface compatibility.
    y_hat = y_hat.flatten()
    try:
        if csv:
            return f'{key},' + ','.join([f'{f.__name__}:{f(y, y_hat)}' for f in metrics])
        else:
            return f'{key:>12} | ' + ' '.join([f'{f.__name__}: {f(y, y_hat):>6.3f}' for f in metrics])
    except Exception as e:
        # Broad catch is deliberate: a failing metric must not abort the report,
        # it is rendered into the returned string instead.
        return f'{key:>12} | Exception: {e}'