* wpylib.math.fitting: Imported fit_func_base from Cr2 project.

* Added documentation on fitting methods.
master
Wirawan Purwanto 9 years ago
parent 4111dc2da7
commit 82a55b940a
  1. 192
      math/fitting/__init__.py

@ -7,10 +7,58 @@
# Imported 20100120 from $PWQMC77/expt/Hybrid-proj/analyze-Eh.py
# (dated 20090323).
#
# fit_func_base was imported 20150520 from Cr2_analysis_cbs.py
# (dated 20141017, CVS rev 1.143).
#
# Some references on fitting:
# * http://stackoverflow.com/questions/529184/simple-multidimensional-curve-fitting
# * http://www.scipy.org/Cookbook/OptimizationDemo1 (not as thorough, but maybe useful)
"""
wpylib.math.fitting
Basic tools for two-dimensional curve fitting
ABOUT THE FITTING METHODS
We depend on module scipy.optimize and (optionally) lmfit to provide the
minimization routines for us.
The following methods are currently supported for scipy.optimize:
* `fmin`
The Nelder-Mead Simplex algorithm.
* `fmin_bfgs` or `bfgs`
The Broyden-Fletcher-Goldfarb-Shanno (BFGS) algorithm
* `anneal`
Similated annealing algorithm
* `leastsq`
The Levenberg-Marquardt nonlinear least square (NLLS) method
See the documentation of `scipy.optimize` for more details.
The `fmin` algorithm is the slowest although it is fairly foor proof to
converge it (it may take many iterations).
The leastsq` algorithm is the best but it requires parameter guess that is
reasonable.
I don't have much success with `anneal`--it seems to behave erratically in
my limited experience. YMMV.
The lmfit package is supported if it can be found at runtime.
This package provides richer set of features, including constraints on
parameters and parameter interdependency.
Various minimization methods under this package are available.
To use lmfit, use keyword `lmfit:<method>` as the fit method name.
Example: `lmfit:leastsq`.
See the documentation here:
http://cars9.uchicago.edu/software/python/lmfit/fitting.html#fit-methods-label
"""
import numpy
import scipy.optimize
from wpylib.db.result_base import result_base
@ -379,3 +427,147 @@ def fit_func(Funct, Data=None, Guess=None, Params=None,
raise ValueError, "Invalid `outfmt' argument = " + x
# Imported 20150520 from Cr2_analysis_cbs.py .
class fit_func_base(object):
"""Base class for function 2-D fitting object.
This is an enhanced OO interface to fit_func.
In the derived class, a __call__ method must be implemented with
this prototype:
def __call__(self, C, x)
where
- `C' is the parameters which we sought through the fitting
procedure, and
- `x' is the x values of the data samples against which we want
to do the curve fitting.
A few user-adjustable parameters need to be attached as attributes
to this object:
- fit_method
- fit_opts (a dict or multi_fit_opts object)
- debug
- dbg_params
- Params
`fit_method' is a string containing the name of the fitting method to use,
see this module document.
Additional attributes are required to support lmfit-based fitting:
- param_names: a list/tuple of parameter names, in the same order as in
the legacy 'C' __call__ argument above.
The input-data-based automatic parameter guess is specified via Guess parameter.
See wpylib.math.fitting.fit_func for detail.
- if Guess==None (default), then it attempts to use self.Guess_xy() method
(better, new default) or old self.Guess() method.
- if Guess==False (only for lmfit case), existing values from Params object
will be used.
- TODO: dict-like Guess should be made possible.
- otherwise, the guess values will be used as the initial values.
"""
class multi_fit_opts(dict):
"""A class for defining default control parameters for different fit methods.
The fit method name is the dict key, and the value, which is also a dict,
is the default set of fitting control parameters for that particular fit method.
"""
pass
# Some reasonable parameters are set:
fit_default_opts = multi_fit_opts(
fmin=dict(xtol=1e-5, maxfun=100000, maxiter=10000, disp=0),
fmin_bfgs=dict(gtol=1e-6, disp=0),
leastsq=dict(xtol=1e-8, epsfcn=1e-6),
)
fit_default_opts["lmfit:leastsq"] = dict(xtol=1e-8, epsfcn=1e-6)
debug = 1
dbg_params = 1
fit_method = 'fmin'
fit_opts = fit_default_opts
#fit_opts = dict(xtol=1e-5, maxfun=100000, maxiter=10000, disp=0)
def fit(self, x, y, dy=None, fit_opts=None, Funct_hook=None, Guess=None):
"""Main entry function for fitting."""
x = numpy.asarray(x)
if len(x.shape) == 1:
# fix common "mistake" for 1-D domain: make it 2-D
x = x.reshape((1, x.shape[0]))
if fit_opts == None:
# Use class default if it is available
fit_opts = getattr(self, "fit_opts", {})
if isinstance(fit_opts, self.multi_fit_opts): # multiple choice :-)
fit_opts = fit_opts.get(self.fit_method, {})
if Guess == None:
Guess = getattr(self, "Guess", None)
if self.dbg_params:
self.dbg_params_log = []
if self.debug >= 5:
print "fit: Input Params = ", getattr(self, "Params", None)
self.last_fit = fit_func(
Funct=self,
Funct_hook=Funct_hook,
x=x, y=y, dy=dy,
Guess=Guess,
Params=getattr(self, "Params", None),
method=self.fit_method,
opts=fit_opts,
debug=self.debug,
outfmt=0, # yield full result
)
if self.use_lmfit_method:
if not hasattr(self, "Params"):
self.Params = self.last_fit.params
return self.last_fit['xopt']
def func_call_hook(self, C, x, y):
"""Common hook function called when calling 'THE'
function, e.g. for debugging purposes."""
from copy import copy
if self.dbg_params:
if not hasattr(self, "dbg_params_log"):
self.dbg_params_log = []
self.dbg_params_log.append(copy(C))
#print "Call morse2_fit_func(%s, %s) -> %s" % (C, x, y)
def get_params(self, C, *names):
"""Special support function to extract the values (or
representative objects) of the parameters contained in 'C',
the list of parameters.
In the legacy case, C is simply a tuple/list of numbers.
In the lmfit case, C is a Parameters object.
"""
try:
from lmfit import Parameters
# new way: using lmfit.Parameters:
if isinstance(C, Parameters):
return tuple(C[k].value for k in names)
except:
pass
# old way: using positional parameters
return tuple(C)
@property
def use_lmfit_method(self):
return self.fit_method.startswith("lmfit:")
@staticmethod
def domain_array(x):
"""Creates a domain array (x) for nonlinear fitting.
Also accomodates a common "mistake" for 1-D domain by making it
correctly 2-D in shape.
"""
x = numpy.asarray(x)
if len(x.shape) == 1:
# fix common "mistake" for 1-D domain: make it 2-D
x = x.reshape((1, x.shape[0]))
return x

Loading…
Cancel
Save