Optimizer
DataDrivenDiffEq comes with several implementations for sparse regression. All of them are stored inside the DataDrivenDiffEq.Optimize submodule and extend either the AbstractOptimizer, if an explicit optimization is needed, or the AbstractSubspaceOptimizer for an implicit problem (where the solution lies within the nullspace).
Functions
DataDrivenDiffEq.Optimize.STRRidge (Type)
STRRidge(λ = 0.1)
STRRidge is taken from the original paper on SINDy and implements a sequentially thresholded least-squares iteration. λ is the threshold of the iteration. It is based upon this MATLAB implementation.
Example
opt = STRRidge()
opt = STRRidge(1e-1)
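To make the iteration concrete, here is a minimal sketch of sequentially thresholded least squares in plain Julia, using only the LinearAlgebra standard library. It is not the STRRidge source: the helper name `stlsq` and the fixed `maxiter` stopping rule are assumptions made for illustration.

```julia
using LinearAlgebra

# Minimal sketch of sequentially thresholded least squares (hypothetical helper,
# not the STRRidge source). Solves A * X ≈ Y column by column: coefficients with
# |x| < λ are set to zero and the remaining ones are refitted by least squares.
function stlsq(A::AbstractMatrix, Y::AbstractMatrix, λ::Real; maxiter::Int = 10)
    X = A \ Y                              # ordinary least-squares initial guess
    for _ in 1:maxiter
        small = abs.(X) .< λ               # coefficients to prune
        X[small] .= 0
        for j in axes(Y, 2)
            active = .!small[:, j]         # surviving terms for column j
            any(active) || continue        # nothing left to refit
            X[active, j] = A[:, active] \ Y[:, j]
        end
    end
    return X
end

# Usage: recover y = 2x - x^3 from the polynomial library [1, x, x^2, x^3].
x = collect(range(-1, 1; length = 50))
A = hcat(ones(50), x, x .^ 2, x .^ 3)
Y = reshape(2 .* x .- x .^ 3, :, 1)
X = stlsq(A, Y, 0.1)
```

Pruned coefficients are exactly zero, so the sparsity pattern of the result can be read off directly.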
DataDrivenDiffEq.Optimize.ADMM (Type)
ADMM()
ADMM(λ, ρ)
ADMM is an implementation of Lasso using the alternating direction method of multipliers and is loosely based on this implementation. λ is the sparsification parameter, ρ the augmented Lagrangian parameter.
Example
opt = ADMM()
opt = ADMM(1e-1, 2.0)
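The scaled-form ADMM iteration for the Lasso problem $\min_x \frac{1}{2}\|Ax - b\|_2^2 + \lambda \|x\|_1$ can be sketched as follows. This is the textbook splitting $x = z$, not the package's implementation; the helper names `admm_lasso` and `soft_threshold` and the fixed iteration count are assumptions of the sketch.

```julia
using LinearAlgebra

# Soft-thresholding operator S_κ(v) = sign(v) * max(|v| - κ, 0), elementwise.
soft_threshold(v, κ) = sign.(v) .* max.(abs.(v) .- κ, 0)

# Minimal scaled-form ADMM sketch (hypothetical helper, not the package code) for
#     minimize 0.5 * ‖A x - b‖² + λ ‖x‖₁
# using the splitting x = z with scaled dual variable u.
function admm_lasso(A::AbstractMatrix, b::AbstractVector, λ::Real, ρ::Real; maxiter::Int = 500)
    n = size(A, 2)
    x = zeros(n); z = zeros(n); u = zeros(n)
    F = cholesky(Symmetric(A' * A + ρ * I))  # factor once, reuse every iteration
    Atb = A' * b
    for _ in 1:maxiter
        x = F \ (Atb .+ ρ .* (z .- u))       # quadratic (least-squares) step
        z = soft_threshold(x .+ u, λ / ρ)    # proximal step for the ℓ1 term
        u = u .+ x .- z                      # scaled dual update
    end
    return z
end

# Usage: with A = I the Lasso solution is soft_threshold(b, λ).
A = Matrix(1.0I, 4, 4)
b = [3.0, 0.05, -2.0, 0.0]
z = admm_lasso(A, b, 0.5, 1.0)
```

Factoring $A^\top A + \rho I$ once outside the loop is the usual design choice, since only the right-hand side changes between iterations.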
DataDrivenDiffEq.Optimize.SR3 (Type)
SR3(λ, ν, R)
SR3(λ = 1e-1, ν = 1.0)
SR3 is an optimizer framework introduced by Zheng et al., 2018, and used by Champion et al., 2019. SR3 contains a sparsification parameter λ, a relaxation ν, and a corresponding penalty function R, which should be taken from ProximalOperators.jl.
Examples
opt = SR3()
opt = SR3(1e-2)
opt = SR3(1e-3, 1.0)
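As a hedged illustration of the relaxation, the sketch below alternates between the two blocks of the relaxed objective $\frac{1}{2}\|AX - Y\|^2 + \lambda R(W) + \frac{1}{2\nu}\|X - W\|^2$. It assumes $R = \|\cdot\|_1$, so the proximal step is plain soft-thresholding written by hand instead of a ProximalOperators.jl call; the helper name `sr3_l1` is hypothetical and this is not the package's implementation.

```julia
using LinearAlgebra

# Soft-thresholding: the proximal operator of κ‖⋅‖₁.
soft_threshold(v, κ) = sign.(v) .* max.(abs.(v) .- κ, 0)

# Minimal SR3 sketch (hypothetical helper), assuming R = ‖⋅‖₁:
#     minimize 0.5‖A X - Y‖² + λ R(W) + (1 / 2ν) ‖X - W‖²
# alternating an exact least-squares update of X with a proximal update of W.
function sr3_l1(A::AbstractMatrix, Y::AbstractMatrix, λ::Real, ν::Real; maxiter::Int = 1000)
    F = cholesky(Symmetric(A' * A + I / ν))  # system matrix of the X-update
    AtY = A' * Y
    W = zeros(size(A, 2), size(Y, 2))
    for _ in 1:maxiter
        X = F \ (AtY .+ W ./ ν)       # relaxed least-squares step
        W = soft_threshold(X, λ * ν)  # prox of λν‖⋅‖₁
    end
    return W                          # the sparse (relaxed) coefficients
end

# Usage: for A = I the fixed point is W = soft_threshold(Y, λ * (1 + ν)).
A = Matrix(1.0I, 3, 3)
Y = reshape([3.0, 0.05, -2.0], :, 1)
W = sr3_l1(A, Y, 0.5, 1.0)
```

Swapping in a different penalty only changes the `W` update, which is why SR3 pairs naturally with a library of proximal operators.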
DataDrivenDiffEq.Optimize.ADM (Type)
ADM()
ADM(λ = 0.1)
Optimizer for finding a sparse basis vector in a subspace, based on this paper. λ is the weight for the soft-thresholding operation.
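One way to search for a sparse vector in a subspace is an alternating-directions scheme on the relaxed problem $\min_{q, x} \frac{1}{2}\|Yq - x\|_2^2 + \lambda\|x\|_1$ subject to $\|q\|_2 = 1$, where the columns of $Y$ form an orthonormal basis of the subspace. The sketch below is written from that formulation, not from the package source; the helper name `adm_sparse_vector`, the orthonormality of `Y`, and the user-supplied starting point `q0` are assumptions.

```julia
using LinearAlgebra

# Soft-thresholding operator, weighted by λ.
soft_threshold(v, κ) = sign.(v) .* max.(abs.(v) .- κ, 0)

# Minimal alternating-directions sketch (hypothetical helper). Assumes Y has
# orthonormal columns spanning the subspace; returns a unit vector in span(Y).
function adm_sparse_vector(Y::AbstractMatrix, q0::AbstractVector, λ::Real; maxiter::Int = 1000)
    q = q0 / norm(q0)
    for _ in 1:maxiter
        x = soft_threshold(Y * q, λ)  # sparse estimate of the vector in the subspace
        v = Y' * x
        norm(v) == 0 && break         # everything was thresholded away; keep the last q
        q = v / norm(v)               # back to the unit sphere in coefficient space
    end
    return Y * q
end

# Usage: span{s, d} with s sparse and d dense; start from a mixed direction.
s = [1.0; zeros(7)]
d = [0.0; fill(1 / sqrt(7), 7)]
Y = hcat(s, d)
r = adm_sparse_vector(Y, [1.0, 1.0], 0.3)
```

Because the problem is nonconvex, the result depends on `q0`; starting along `d` instead would stay on the dense vector, so in practice it is worth trying several starting points.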
Implementing a New Optimizer
Similarly to Algorithms for Estimation, extending the optimizers is fairly straightforward. Suppose you want to define a new optimizer MyOpt, which should solve $A X = Y$ for a sparse $X$.
mutable struct MyOpt <: DataDrivenDiffEq.Optimize.AbstractOptimizer
    threshold # sparsity threshold used when fitting
end
To use MyOpt within SINDy, an init! function has to be implemented.
function init!(X::AbstractArray, o::MyOpt, A::AbstractArray, Y::AbstractArray)
    # Initial guess: ordinary least-squares solution of A X = Y, written into X
    X .= A \ Y
end
To perform thresholding, and possibly to search for a suitable threshold, a setter and a getter are required:
set_threshold!(opt::MyOpt, threshold) = (opt.threshold = threshold)
get_threshold(opt::MyOpt) = opt.threshold
Finally, the method that fits the data and returns the number of iterations needed:
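function fit!(X::AbstractArray, A::AbstractArray, Y::AbstractArray, opt::MyOpt; maxiter, convergence_error)
    iterations = 0
    # Compute awesome stuff here: update X in place and count iterations,
    # stopping once the change in X drops below convergence_error or maxiter is reached
    return iterations
end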
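Putting the pieces together, the following self-contained sketch fills in fit! with a simple hard-thresholding loop so the interface can be exercised without the package. `MyAbstractOptimizer` is a hypothetical stand-in for DataDrivenDiffEq.Optimize.AbstractOptimizer, and the thresholding rule and the default keyword values are assumptions chosen for illustration.

```julia
using LinearAlgebra

# Hypothetical stand-in for DataDrivenDiffEq.Optimize.AbstractOptimizer.
abstract type MyAbstractOptimizer end

mutable struct MyOpt <: MyAbstractOptimizer
    threshold::Float64  # coefficients with |x| below this are set to zero
end

# Least-squares initial guess for A X = Y.
init!(X::AbstractArray, o::MyOpt, A::AbstractArray, Y::AbstractArray) = (X .= A \ Y)

set_threshold!(opt::MyOpt, threshold) = (opt.threshold = threshold)
get_threshold(opt::MyOpt) = opt.threshold

# Assumed fitting rule: hard-threshold small coefficients, refit the rest,
# and stop once X changes by less than convergence_error.
function fit!(X::AbstractArray, A::AbstractArray, Y::AbstractArray, opt::MyOpt;
              maxiter::Int = 100, convergence_error::Real = eps())
    iterations = 0
    for i in 1:maxiter
        iterations = i
        X_prev = copy(X)
        small = abs.(X) .< get_threshold(opt)
        X[small] .= 0
        for j in axes(Y, 2)
            active = .!small[:, j]
            any(active) && (X[active, j] = A[:, active] \ Y[:, j])
        end
        norm(X .- X_prev) < convergence_error && break
    end
    return iterations
end

# Usage: sparse fit of y = 2x - x^3 against [1, x, x^2, x^3].
x = collect(range(-1, 1; length = 50))
A = hcat(ones(50), x, x .^ 2, x .^ 3)
Y = reshape(2 .* x .- x .^ 3, :, 1)
opt = MyOpt(0.1)
X = zeros(4, 1)
init!(X, opt, A, Y)
iters = fit!(X, A, Y, opt; maxiter = 10, convergence_error = 1e-10)
```

Keeping the threshold behind set_threshold!/get_threshold lets an outer routine sweep over thresholds without knowing the optimizer's field names.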