juliaai / MLJDecisionTreeInterface.jl
License: MIT License
In response to: https://github.com/bensadeghi/DecisionTree.jl/issues/147
julia> fitted_params(mach).encoding
Dict{CategoricalArrays.CategoricalValue{String, UInt32}, UInt32} with 3 entries:
"virginica" => 0x00000003
"setosa" => 0x00000001
"versicolor" => 0x00000002
This is really backwards. To interpret the DecisionTree.jl tree object, one wants the label given the reference integer.
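Until that changes, inverting the dictionary by hand is straightforward. A minimal sketch, assuming mach is the trained machine from the snippet above:
using MLJ  # provides fitted_params
encoding = fitted_params(mach).encoding
decoding = Dict(ref => label for (label, ref) in encoding)  # reference integer => label
decoding[0x00000002]  # the CategoricalValue "versicolor"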
This issue is used to trigger TagBot; feel free to unsubscribe. If you haven't already, you should update your TagBot.yml to include issue comment triggers. Please see this post on Discourse for instructions and more details.
The DecisionTree.jl package allows passing the state of a random number generator to its build_forest function (see here). This makes sense, because building a forest of decision trees needs random numbers to select variables. To make results reproducible, it would be nice to allow passing such an rng argument to the fit method in MLJDecisionTreeInterface.jl as well.
Or is there another "workaround" to achieve this?
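For reference, a sketch of the requested workflow; the rng hyperparameter shown here is the proposal, not necessarily the current API:
using MLJ, StableRNGs
Forest = @load RandomForestClassifier pkg=DecisionTree
forest = Forest(rng = StableRNG(123))   # hypothetical rng hyperparameter
mach = machine(forest, X, y) |> fit!    # same forest on every run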
It's now possible to plot a tree using the TreeRecipe.jl package, but the workflow is not very user-friendly, as trees first need to be wrapped.
Currently the raw decision tree is exposed as fitted_params(mach).tree. I propose we pre-wrap this object (with the feature names already embedded) and add an extra field fitted_params(mach).raw_tree for the original unwrapped object.
Then plotting a tree would be as simple as
using TreeRecipe, Plots
tree = fitted_params(mach).tree
plot(tree)
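For comparison, the current workflow wraps the tree by hand, roughly as in the TreeRecipe.jl examples; a sketch, assuming DecisionTree.wrap and illustrative feature names:
using TreeRecipe, Plots
import DecisionTree
raw = fitted_params(mach).tree  # the unwrapped object
features = [:sepal_length, :sepal_width, :petal_length, :petal_width]  # hypothetical
plot(DecisionTree.wrap(raw, (featurenames = features,)))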
Thoughts anyone?
Related: #23
Version 0.3.1 (to be yanked) and version 0.4 introduced a breaking change to MLJModelInterface.fit and MLJModelInterface.predict for all 5 models. This change only affects developers who directly call those functions. Regular MLJ users who interact through the usual "machine" interface are not affected.
What changed is that the models mentioned now implement the MLJModelInterface data front end. The most likely reason for breakage is that fit and predict are not being called with the model-specific form of data generated by the reformat method. This post describes a backwards-compatible fix.
Adding reformat to your fit/predict calls
If you are not already using reformat in your fit and predict calls, then where you previously made a call
MMI.fit(model, verbosity, data...) # MMI = MLJModelInterface
you instead want
MMI.fit(model, verbosity, MMI.reformat(model, data...)...)
And instead of
MMI.predict(model, fitresult, Xnew)
you want
MMI.predict(model, fitresult, MMI.reformat(model, Xnew)...)
You have backwards compatibility because the fallback for reformat just slurps the data. This also means you can change these calls for all models (not just the DecisionTree ones).
If you subsample reformatted data before passing it to fit or predict, you should always use subsampled_data = selectrows(model, I, reformatted_data...), where I is the indices for subsampling.
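Putting the pieces together, a minimal sketch (verbosity 0; the row indices I are illustrative):
import MLJModelInterface as MMI
data = MMI.reformat(model, X, y)                  # model-specific representation
I = 1:100                                         # hypothetical subsampling indices
subsampled_data = MMI.selectrows(model, I, data...)
fitresult, cache, report = MMI.fit(model, 0, subsampled_data...)
yhat = MMI.predict(model, fitresult, MMI.reformat(model, Xnew)...)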
This suggestion was made in another post.
I see the scikit-learn default is also 100.
Any objections?
This is actually possible, because DecisionTree.print_tree() has an option to pass the feature names: https://github.com/bensadeghi/DecisionTree.jl/blob/3fcb5b083e9abf45773ad1f22945473a7cc4ef89/src/DecisionTree.jl#L86
cc @roland-KA
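A sketch of the usage this enables, assuming the feature_names keyword at the linked line (the depth argument and names are illustrative):
import DecisionTree
tree = fitted_params(mach).tree
DecisionTree.print_tree(tree, 5; feature_names = ["sepal length", "sepal width", "petal length", "petal width"])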
Also, if not already there, add MMI.iteration_parameter.
From JuliaAI/DecisionTree.jl#182
@OkonSamuel you said you were working on it? Mostly opening this issue so I can track it; I'm eager to put it to use!
It seems the recent changes to the implementation of fit in 0.3.1, in particular the additional arguments that are now required, broke downstream tests in MCMCChains: TuringLang/MCMCChains.jl#400 and TuringLang/MCMCChains.jl#402.
My main question before making any additional changes downstream is: Is that a bug in MLJDecisionTreeInterface or a problem with the implementation in MCMCChains (or rather MCMCDiagnosticTools)?
The following example shows how to manually plot the trees learned in DecisionTree.jl:
https://github.com/JuliaAI/TreeRecipe.jl/blob/master/examples/DecisionTree_iris.jl
Currently, the way to integrate a plot recipe in MLJ.jl is not documented, but is sketched in this comment.
So, can we somehow put this together to arrange that a workflow like this generates a plot of a decision tree?
edited again (x2):
using MLJBase
using Plots # <---- added in edit
import MLJDecisionTreeInterface
tree = MLJDecisionTreeInterface.DecisionTreeClassifier()
X, y = @load_iris
mach = machine(tree, X, y) |> fit!
plot(mach, 0.8, 0.7; size = (1400,600)) # <---- added in edit
Note: It used to be that you made RecipesBase.jl your dependency, to avoid a full Plots.jl dependency. But now the recipes live in Plots.jl and you are expected to make Plots.jl a weak dependency. You can see an example of this here.
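For concreteness, the weak-dependency arrangement in Project.toml looks roughly like this (the extension name is illustrative; check the UUID against Plots.jl's own Project.toml):
[weakdeps]
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
[extensions]
PlotsExt = "Plots"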
julia> using MLJDecisionTreeInterface
[ Info: Precompiling MLJDecisionTreeInterface [c6f25543-311c-4c74-83dc-3ea6d1015661]
ERROR: LoadError: MethodError: no method matching metadata_model(::Type{MLJDecisionTreeInterface.DecisionTreeClassifier}; input_scitype=ScientificTypesBase.Table{<:Union{AbstractVector{<:ScientificTypesBase.Count}, AbstractVector{<:ScientificTypesBase.OrderedFactor}, AbstractVector{<:ScientificTypesBase.Continuous}}}, target_scitype=AbstractVector{<:ScientificTypesBase.Finite}, human_name="CART decision tree classifier", load_path="MLJDecisionTreeInterface.DecisionTreeClassifier")
Closest candidates are:
metadata_model(::Any; input, target, output, weights, descr, path, input_scitype, target_scitype, output_scitype, supports_weights, docstring, load_path) at C:\Users\songroom\.julia\packages\MLJModelInterface\txhfr\src\metadata_utils.jl:101 got unsupported keyword argument "human_name"
Stacktrace:
[1] kwerr(::NamedTuple{(:input_scitype, :target_scitype, :human_name, :load_path), Tuple{UnionAll, UnionAll, String, String}}, ::Function, ::Type)
@ Base .\error.jl:163
[2] top-level scope
@ C:\Users\songroom\.julia\packages\MLJDecisionTreeInterface\WijC5\src\MLJDecisionTreeInterface.jl:286
[3] include
@ .\Base.jl:418 [inlined]
[4] include_package_for_output(pkg::Base.PkgId, input::String, depot_path::Vector{String}, dl_load_path::Vector{String}, load_path::Vector{String}, concrete_deps::Vector{Pair{Base.PkgId, UInt64}}, source::Nothing)
@ Base .\loading.jl:1318
[5] top-level scope
@ none:1
[6] eval
@ .\boot.jl:373 [inlined]
[7] eval(x::Expr)
@ Base.MainInclude .\client.jl:453
[8] top-level scope
@ none:1
in expression starting at C:\Users\songroom\.julia\packages\MLJDecisionTreeInterface\WijC5\src\MLJDecisionTreeInterface.jl:1
ERROR: Failed to precompile MLJDecisionTreeInterface [c6f25543-311c-4c74-83dc-3ea6d1015661] to C:\Users\songroom\.julia\compiled\v1.7\MLJDecisionTreeInterface\jl_DEAB.tmp.
Stacktrace:
[1] error(s::String)
@ Base .\error.jl:33
[2] compilecache(pkg::Base.PkgId, path::String, internal_stderr::IO, internal_stdout::IO, ignore_loaded_modules::Bool)
@ Base .\loading.jl:1466
[3] compilecache(pkg::Base.PkgId, path::String)
@ Base .\loading.jl:1410
[4] _require(pkg::Base.PkgId)
@ Base .\loading.jl:1120
[5] require(uuidkey::Base.PkgId)
@ Base .\loading.jl:1013
[6] require(into::Module, mod::Symbol)
@ Base .\loading.jl:997
julia> versioninfo()
Julia Version 1.7.0
Commit 3bf9d17731 (2021-11-30 12:12 UTC)
Platform Info:
OS: Windows (x86_64-w64-mingw32)
CPU: Intel(R) Core(TM) i7-6700K CPU @ 4.00GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-12.0.1 (ORCJIT, skylake)
Environment:
JULIA_PKG_SERVER = https://pkg.julialang.org
The docs for the DecisionTreeClassifier hyperparameter :n_subfeatures state:
n_subfeatures=0: number of features to select at random (0 for all, -1 for square root of number of features)
However, the value -1 leads to the error "number of features -1 must be >= zero". I'm not sure if the issue should be handled at this interface or in DecisionTree.jl itself, as the function build_forest() has
if n_subfeatures == -1
n_features = size(features, 2)
n_subfeatures = round(Int, sqrt(n_features))
end
while the function build_tree() does not. See this link for the classifier and this other link for the regressor.
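For illustration, the fix would amount to applying the same guard on the build_tree path; a sketch only, not the package's code:
# hypothetical pre-processing, mirroring build_forest:
if n_subfeatures == -1
    n_subfeatures = round(Int, sqrt(size(features, 2)))
end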
Properly speaking, I think smoothing is something that should be implemented using a transformer, or maybe a wrapper. Laplace smoothing is pretty standard, but it is not what is implemented here.
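For reference, standard add-α (Laplace) smoothing maps class counts n_k to p_k = (n_k + α) / (N + αK), with N the total count and K the number of classes; a sketch, not what this package implements:
# α = 1 gives classic Laplace smoothing of leaf class counts
laplace_smooth(counts; α = 1.0) = (counts .+ α) ./ (sum(counts) + α * length(counts))
laplace_smooth([3, 0, 1])  # ≈ [0.571, 0.143, 0.286] instead of [0.75, 0.0, 0.25]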
When training with Multiclass inputs, I get warnings like
┌ Warning: The scitype of `X`, in `machine(model, X, ...)` is incompatible with `model=DeterministicTunedModel{Grid,…}`:
│ scitype(X) = Table{Union{AbstractVector{Continuous}, AbstractVector{Count}, AbstractVector{Multiclass{12}}, AbstractVector{Multiclass{1}}, AbstractVector{Multiclass{6}}}}
│ input_scitype(model) = Table{<:Union{AbstractVector{<:Continuous}, AbstractVector{<:Count}, AbstractVector{<:OrderedFactor}}}.
└ @ MLJBase ~/.julia/packages/MLJBase/QXObv/src/machines.jl:133
Is this intended? I don't see why Multiclass would be included here; in fact, if I do models(matching(Xtrain, ytrain)) on my inputs, the models I'm attempting to use indeed show up:
◖◗ models(matching(Xtrain, ytrain))
4-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :supports_class_weights, :supports_online, :supports_training_losses, :supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:
(name = ConstantRegressor, package_name = MLJModels, ... )
(name = DecisionTreeRegressor, package_name = BetaML, ... )
(name = DeterministicConstantRegressor, package_name = MLJModels, ... )
(name = RandomForestRegressor, package_name = BetaML, ... )
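In the meantime, a possible workaround is to one-hot encode the Multiclass columns before they reach the tree; a sketch, assuming MLJ's ContinuousEncoder and linear pipeline syntax:
using MLJ
Tree = @load DecisionTreeRegressor pkg=DecisionTree
pipe = ContinuousEncoder() |> Tree()
mach = machine(pipe, Xtrain, ytrain) |> fit!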
So, it's become clear to me that the prediction labels can be wrong.
Say I have the following data
x | y |
---|---|
1 | a |
1 | a |
1 | a |
2 | b |
2 | b |
2 | b |
A decision tree trained to predict y from x will have a very easy time.
So what happens when I get the predicted probabilities?
x | y | p(y=a) | p(y=b) |
---|---|---|---|
1 | a | 0 | 1 |
1 | a | 0 | 1 |
1 | a | 0 | 1 |
2 | b | 1 | 0 |
2 | b | 1 | 0 |
2 | b | 1 | 0 |
Errr..., this isn't right.
julia> using MLJ, MLJDecisionTreeInterface, DataFrames
julia> Xind, yind = DataFrame(i=[1, 1, 1, 2, 2, 2]), ["a", "a", "a", "b", "b", "b"];
julia> tmach = machine(DecisionTreeClassifier(), Xind, coerce(yind, Multiclass))
untrained Machine; caches model-specific representations of data
model: DecisionTreeClassifier(…)
args:
1: Source @508 ⏎ Table{AbstractVector{Count}}
2: Source @186 ⏎ AbstractVector{Multiclass{2}}
julia> fit!(tmach)
[ Info: Training machine(DecisionTreeClassifier(max_depth = 4, …), …).
trained Machine; caches model-specific representations of data
model: DecisionTreeClassifier(max_depth = 4, …)
args:
1: Source @508 ⏎ Table{AbstractVector{Count}}
2: Source @186 ⏎ AbstractVector{Multiclass{2}}
julia> MLJ.predict(tmach, Xind)
6-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{2}, String, UInt32, Float64}:
UnivariateFinite{Multiclass{2}}(a=>0.0, b=>1.0)
UnivariateFinite{Multiclass{2}}(a=>0.0, b=>1.0)
UnivariateFinite{Multiclass{2}}(a=>0.0, b=>1.0)
UnivariateFinite{Multiclass{2}}(a=>1.0, b=>0.0)
UnivariateFinite{Multiclass{2}}(a=>1.0, b=>0.0)
UnivariateFinite{Multiclass{2}}(a=>1.0, b=>0.0)
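To quantify the swap, one can broadcast pdf over the predictions and the true labels; a sketch continuing the same session, where a correct fit on this data would give 1.0:
julia> using Statistics
julia> mean(pdf.(MLJ.predict(tmach, Xind), coerce(yind, Multiclass)))
0.0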