Comments (10)
By referencing the scikit-learn documentation (e.g., see here) and experimenting with the parameters, I was able to match the results of the two models:
from sksurv.linear_model import CoxPHSurvivalAnalysis, CoxnetSurvivalAnalysis
from sksurv.datasets import load_veterans_lung_cancer
from sksurv.preprocessing import OneHotEncoder
X_, y_ = load_veterans_lung_cancer()
X = OneHotEncoder().fit_transform(X_)
f = CoxPHSurvivalAnalysis(alpha=1, n_iter=100000)
g = CoxnetSurvivalAnalysis(alphas=[1/X.shape[0]], alpha_min_ratio=1, n_alphas=1, l1_ratio=1e-16, tol=1e-09, normalize=False)
print(f)
print(g)
f.fit(X, y_)
g.fit(X, y_)
print(f.coef_)
print(g.coef_[:,0])
print('Relative errors:\n', (f.coef_ - g.coef_[:,0]) / f.coef_)
Output:
CoxPHSurvivalAnalysis(alpha=1, n_iter=100000, tol=1e-09, verbose=0)
CoxnetSurvivalAnalysis(alpha_min_ratio=1, alphas=[0.0072992700729927005],
copy_X=True, l1_ratio=1e-16, max_iter=100000, n_alphas=1,
normalize=False, penalty_factor=None, tol=1e-09, verbose=False)
[-8.15647317e-03 -6.63327564e-01 -2.36950223e-01 -1.04657588e+00
-3.26009359e-02 -2.71347095e-04 5.39815287e-02 2.89290932e-01]
[-8.15323389e-03 -6.63228135e-01 -2.36836564e-01 -1.04645919e+00
-3.25993975e-02 -2.71575816e-04 5.39863418e-02 2.89279963e-01]
Relative errors:
[ 3.97142155e-04 1.49894366e-04 4.79676516e-04 1.11497071e-04
4.71911560e-05 -8.42911801e-04 -8.91616158e-05 3.79195919e-05]
This is reassuring.
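The match above hinges on passing alphas=[1/X.shape[0]] rather than alphas=[1]. A minimal sketch of that rescaling follows, assuming the two penalty strengths differ only by a factor of the sample count. The thread demonstrates this for alpha=1 only; the general rule and the helper name coxnet_alpha_from_coxph are assumptions made for illustration.

```python
def coxnet_alpha_from_coxph(coxph_alpha: float, n_samples: int) -> float:
    """Hypothetical helper: rescale a CoxPHSurvivalAnalysis-style alpha to
    the scale used for the `alphas` list of CoxnetSurvivalAnalysis.

    Assumption (shown in the thread only for alpha=1): the matching
    Coxnet alpha is the CoxPH alpha divided by the number of samples.
    """
    if n_samples <= 0:
        raise ValueError("n_samples must be positive")
    return coxph_alpha / n_samples


# The veterans' lung cancer data has 137 rows, so alpha=1 maps to 1/137,
# the value that appears in the printed CoxnetSurvivalAnalysis parameters.
matched_alpha = coxnet_alpha_from_coxph(1.0, 137)
print(matched_alpha)
```

Under the same assumption, comparing alpha=0.5 against alphas=[0.5] would compare two very different penalty strengths, since the rescaled value would be 0.5/137.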
from scikit-survival.
Now that the two models can give nearly the same results, we can benchmark their performance. I find that CoxnetSurvivalAnalysis is about 35x faster than CoxPHSurvivalAnalysis. Given that CoxnetSurvivalAnalysis also implements a more general model (L1 + L2 regularization), it makes sense to use only it for now, until one needs to dig into the details of the algorithm's execution (where CoxPHSurvivalAnalysis provides better transparency).
CoxPHSurvivalAnalysis and CoxnetSurvivalAnalysis are related but do not implement the same model. As you summarized, the difference is in the penalty terms. Thus, the coefficients found by CoxPHSurvivalAnalysis and CoxnetSurvivalAnalysis will differ. In addition, CoxnetSurvivalAnalysis supports fitting a path of coefficients for varying penalization strength alpha without much computational overhead. CoxPHSurvivalAnalysis will return the result for a single penalization strength alpha only. This is the main reason why the constructors differ.
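To make "a path of coefficients for varying alpha" concrete, here is a sketch of a log-spaced grid of penalty strengths in the style of glmnet. The function name log_spaced_alphas and the exact spacing are assumptions for illustration; how scikit-survival actually builds its grid is not stated in this thread. Each alpha in such a grid would yield one column of coefficients.

```python
import math


def log_spaced_alphas(alpha_max: float, alpha_min_ratio: float, n_alphas: int) -> list:
    """Hypothetical glmnet-style grid: n_alphas values running from
    alpha_max down to alpha_max * alpha_min_ratio, evenly spaced on a
    logarithmic scale."""
    if alpha_max <= 0 or not (0 < alpha_min_ratio <= 1):
        raise ValueError("need alpha_max > 0 and 0 < alpha_min_ratio <= 1")
    if n_alphas < 1:
        raise ValueError("n_alphas must be at least 1")
    if n_alphas == 1:
        # A single-point "path" is just the largest alpha.
        return [alpha_max]
    log_max = math.log(alpha_max)
    log_min = math.log(alpha_max * alpha_min_ratio)
    step = (log_min - log_max) / (n_alphas - 1)
    return [math.exp(log_max + k * step) for k in range(n_alphas)]


# Five penalty strengths spanning two orders of magnitude, strongest first.
path = log_spaced_alphas(alpha_max=1.0, alpha_min_ratio=0.01, n_alphas=5)
print(path)
```

With n_alphas=1 and alpha_min_ratio=1, as in the matching experiment above, such a grid collapses to a single value, which mirrors the single-alpha behaviour of CoxPHSurvivalAnalysis.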
The fact that predict_survival_function is not available for CoxnetSurvivalAnalysis is an oversight that should be corrected.
Since CoxnetSurvivalAnalysis supports L1 and L2 regularization while CoxPHSurvivalAnalysis supports only L2 regularization, carefully matching the model parameters should make them return the same results, and this is what the code above tries to do. So why do CoxPHSurvivalAnalysis(alpha=0.5, n_iter=100000) and CoxnetSurvivalAnalysis(alphas=[0.5], alpha_min_ratio=1, n_alphas=1, l1_ratio=1e-16, tol=1e-09, normalize=False) not give the same results?
You cannot completely eliminate the L1 penalty in CoxnetSurvivalAnalysis, so you cannot exactly match the result of CoxPHSurvivalAnalysis.
I think it makes sense to devote this issue to the different results of the two models, so I just changed the title back and will close this thread now that it's resolved. Given that CoxnetSurvivalAnalysis seems the more desirable implementation, as I concluded above, it is important that the methods so far available only for CoxPHSurvivalAnalysis also become available for CoxnetSurvivalAnalysis. I will open a separate issue for that.
Great work, @leihuang!
So which coefficients are recommended for use in the model, the first column or the last one: g.coef_[:,0] or g.coef_[:,-1]?
I see. Thank you very much!