Skip to content

Commit e59a460

Browse files
authored
Merge pull request #548 from lmfit/html_repr
alternative html_repr
2 parents d1ed894 + 785fe08 commit e59a460

File tree

5 files changed

+194
-3
lines changed

5 files changed

+194
-3
lines changed

lmfit/minimizer.py

+6
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
from ._ampgo import ampgo
3737
from .parameter import Parameter, Parameters
38+
from .printfuncs import fitreport_html_table
3839

3940
# check for EMCEE
4041
try:
@@ -349,6 +350,11 @@ def _calculate_statistics(self):
349350
self.aic = _neg2_log_likel + 2 * self.nvarys
350351
self.bic = _neg2_log_likel + np.log(self.ndata) * self.nvarys
351352

353+
def _repr_html_(self, show_correl=True, min_correl=0.1):
354+
"""Returns a HTML representation of parameters data."""
355+
return fitreport_html_table(self, show_correl=show_correl,
356+
min_correl=min_correl)
357+
352358

353359
class Minimizer(object):
354360
"""A general minimizer for curve fitting and optimization."""

lmfit/model.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from .confidence import conf_interval
1616
from .jsonutils import HAS_DILL, decode4js, encode4js
1717
from .minimizer import validate_nan_policy, MinimizerResult
18-
from .printfuncs import ci_report, fit_report
18+
from .printfuncs import ci_report, fit_report, fitreport_html_table
1919

2020
# Use pandas.isnull for aligning missing data if pandas is available.
2121
# otherwise use numpy.isnan
@@ -1565,6 +1565,13 @@ def fit_report(self, modelpars=None, show_correl=True,
15651565
modname = self.model._reprstring(long=True)
15661566
return '[[Model]]\n %s\n%s\n' % (modname, report)
15671567

1568+
def _repr_html_(self, show_correl=True, min_correl=0.1):
1569+
"""Returns a HTML representation of parameters data."""
1570+
report = fitreport_html_table(self, show_correl=show_correl,
1571+
min_correl=min_correl)
1572+
modname = self.model._reprstring(long=True)
1573+
return "<h2> Model</h2> %s %s" % (modname, report)
1574+
15681575
def dumps(self, **kws):
15691576
"""Represent ModelResult as a JSON string.
15701577

lmfit/parameter.py

+5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import uncertainties
1212

1313
from .jsonutils import decode4js, encode4js
14+
from .printfuncs import params_html_table
1415

1516
SCIPY_FUNCTIONS = {'gamfcn': scipy.special.gamma}
1617
for name in ('erf', 'erfc', 'wofz'):
@@ -284,6 +285,10 @@ def pretty_print(self, oneline=False, colwidth=8, precision=4, fmt='g',
284285
print(line.format(name_len=name_len, n=colwidth, p=precision,
285286
f=fmt, **pvalues))
286287

288+
def _repr_html_(self):
289+
"""Returns a HTML representation of parameters data."""
290+
return params_html_table(self)
291+
287292
def add(self, name, value=None, vary=True, min=-inf, max=inf, expr=None,
288293
brute_step=None):
289294
"""Add a Parameter.

lmfit/printfuncs.py

+100-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
except ImportError:
1111
HAS_NUMDIFFTOOLS = False
1212

13-
from .parameter import Parameters
14-
1513

1614
def alphanumeric_sort(s, _nsre=re.compile('([0-9]+)')):
1715
"""Sort alphanumeric string."""
@@ -111,6 +109,7 @@ def fit_report(inpars, modelpars=None, show_correl=True, min_correl=0.1,
111109
Multi-line text of fit report.
112110
113111
"""
112+
from .parameter import Parameters
114113
if isinstance(inpars, Parameters):
115114
result, params = None, inpars
116115
if hasattr(inpars, 'params'):
@@ -210,6 +209,105 @@ def fit_report(inpars, modelpars=None, show_correl=True, min_correl=0.1,
210209
return '\n'.join(buff)
211210

212211

212+
def fitreport_html_table(result, show_correl=True, min_correl=0.1):
213+
"""Report minimizer result as an html table"""
214+
html = []
215+
add = html.append
216+
217+
def stat_row(label, val, val2=''):
218+
add('<tr><td>%s</td><td>%s</td><td>%s</td></tr>' % (label, val, val2))
219+
add('<h2>Fit Statistics</h2>')
220+
add('<table>')
221+
stat_row('fitting method', result.method)
222+
stat_row('# function evals', result.nfev)
223+
stat_row('# data points', result.ndata)
224+
stat_row('# variables', result.nvarys)
225+
stat_row('chi-square', gformat(result.chisqr))
226+
stat_row('reduced chi-square', gformat(result.redchi))
227+
stat_row('Akaike info crit.', gformat(result.aic))
228+
stat_row('Bayesian info crit.', gformat(result.bic))
229+
add('</table>')
230+
add('<h2>Variables</h2>')
231+
add(result.params._repr_html_())
232+
if show_correl:
233+
correls = []
234+
parnames = list(result.params.keys())
235+
for i, name in enumerate(result.params):
236+
par = result.params[name]
237+
if not par.vary:
238+
continue
239+
if hasattr(par, 'correl') and par.correl is not None:
240+
for name2 in parnames[i+1:]:
241+
if (name != name2 and name2 in par.correl and
242+
abs(par.correl[name2]) > min_correl):
243+
correls.append((name, name2, par.correl[name2]))
244+
if len(correls) > 0:
245+
sort_correls = sorted(correls, key=lambda val: abs(val[2]))
246+
sort_correls.reverse()
247+
extra = '(unreported correlations are < %.3f)' % (min_correl)
248+
add('<h2>Correlations %s</h2>' % extra)
249+
add('<table>')
250+
for name1, name2, val in sort_correls:
251+
stat_row(name1, name2, "%.4f" % val)
252+
add('</table>')
253+
return ''.join(html)
254+
255+
256+
def params_html_table(params):
257+
"""Returns a HTML representation of parameters data."""
258+
has_err = any([p.stderr is not None for p in params.values()])
259+
has_expr = any([p.expr is not None for p in params.values()])
260+
has_brute = any([p.brute_step is not None for p in params.values()])
261+
262+
html = []
263+
add = html.append
264+
265+
def cell(x, cat='td'):
266+
return add('<%s> %s </%s>' % (cat, x, cat))
267+
268+
add('<table><tr>')
269+
headers = ['name', 'value']
270+
if has_err:
271+
headers.extend(['standard error', 'relative error'])
272+
headers.extend(['initial value', 'min', 'max', 'vary'])
273+
if has_expr:
274+
headers.append('expression')
275+
if has_brute:
276+
headers.append('brute step')
277+
for h in headers:
278+
cell(h, cat='th')
279+
add('</tr>')
280+
281+
for p in params.values():
282+
rows = [p.name, gformat(p.value)]
283+
if has_err:
284+
serr, perr = '', ''
285+
if p.stderr is not None:
286+
serr = gformat(p.stderr)
287+
perr = '%.2f%%' % (100 * p.stderr / p.value)
288+
rows.extend([serr, perr])
289+
rows.extend((p.init_value, gformat(p.min), gformat(p.max),
290+
'%s' % p.vary))
291+
if has_expr:
292+
expr = ''
293+
if p.expr is not None:
294+
expr = p.expr
295+
rows.append(expr)
296+
297+
if has_brute:
298+
brute_step = 'None'
299+
if p.brute_step is not None:
300+
brute_step = gformat(p.brute_step)
301+
rows.append(brute_step)
302+
303+
add('<tr>')
304+
for r in rows:
305+
cell(r)
306+
add('</tr>')
307+
add('</table>')
308+
return ''.join(html)
309+
310+
213311
def report_errors(params, **kws):
214312
"""Print a report for fitted params: see error_report()."""
215313
print(fit_report(params, **kws))

tests/test_fitreports.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import numpy as np
2+
3+
from lmfit import Minimizer, Parameters, conf_interval, ci_report, fit_report
4+
from lmfit.models import GaussianModel
5+
from lmfit.lineshapes import gaussian
6+
7+
np.random.seed(0)
8+
9+
10+
def test_reports_created():
11+
"""do a simple Model fit but with all the bells-and-whistles
12+
and verify that the reports are created
13+
"""
14+
x = np.linspace(0, 12, 601)
15+
data = gaussian(x, amplitude=36.4, center=6.70, sigma=0.88)
16+
data = data + np.random.normal(size=len(x), scale=3.2)
17+
model = GaussianModel()
18+
params = model.make_params(amplitude=50, center=5, sigma=2)
19+
20+
params['amplitude'].min = 0
21+
params['sigma'].min = 0
22+
params['sigma'].brute_step = 0.001
23+
24+
result = model.fit(data, params, x=x)
25+
26+
report = result.fit_report()
27+
assert(len(report) > 500)
28+
29+
html_params = result.params._repr_html_()
30+
assert(len(html_params) > 500)
31+
32+
html_report = result._repr_html_()
33+
assert(len(html_report) > 1000)
34+
35+
36+
def test_ci_report():
37+
"""test confidence interval report"""
38+
39+
def residual(pars, x, data=None):
40+
argu = (x*pars['decay'])**2
41+
shift = pars['shift']
42+
if abs(shift) > np.pi/2:
43+
shift = shift - np.sign(shift)*np.pi
44+
model = pars['amp']*np.sin(shift + x/pars['period']) * np.exp(-argu)
45+
if data is None:
46+
return model
47+
return model - data
48+
49+
p_true = Parameters()
50+
p_true.add('amp', value=14.0)
51+
p_true.add('period', value=5.33)
52+
p_true.add('shift', value=0.123)
53+
p_true.add('decay', value=0.010)
54+
55+
n = 2500
56+
xmin = 0.
57+
xmax = 250.0
58+
x = np.linspace(xmin, xmax, n)
59+
data = residual(p_true, x) + np.random.normal(scale=0.7215, size=n)
60+
61+
fit_params = Parameters()
62+
fit_params.add('amp', value=13.0)
63+
fit_params.add('period', value=2)
64+
fit_params.add('shift', value=0.0)
65+
fit_params.add('decay', value=0.02)
66+
67+
mini = Minimizer(residual, fit_params, fcn_args=(x,),
68+
fcn_kws={'data': data})
69+
out = mini.leastsq()
70+
report = fit_report(out)
71+
assert(len(report) > 500)
72+
73+
ci, tr = conf_interval(mini, out, trace=True)
74+
report = ci_report(ci)
75+
assert(len(report) > 250)

0 commit comments

Comments
 (0)