helmholtzai-consultants-munich / pysddr
A Python package for semi-structured deep distributional regression
License: MIT License
From my understanding, there is a discrepancy between the orthogonalization as presented in the SDDR paper and the present implementation. The proof of Lemma 1 seems to require orthogonalization with respect to all structured components, whether or not they explicitly enter the network trunk. The present implementation, however, only orthogonalizes with respect to the structured effects that explicitly enter the deep neural network trunk: if x1 is a structured covariate that also serves as input to the network trunk, then U_hat is orthogonalized with respect to x1; if no structured covariate enters the network trunk, no orthogonalization takes place.
Am I missing anything in the code, or is there a reason for this behavior? If so, sorry for bothering! Otherwise, I believe the fix only requires a slight adjustment of the forward function in the sddrnetwork.py file (see pull request).
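To make the discrepancy concrete, here is a minimal NumPy sketch, assuming the orthogonalization in Lemma 1 is the projection of the latent features onto the orthogonal complement of the column space of the full structured design matrix X, i.e. U_tilde = (I - X (X^T X)^{-1} X^T) U_hat. The function `orthogonalize` and the simulated X and U_hat are hypothetical illustrations, not pysddr's `forward`. Orthogonalizing only against a subset of columns (here the intercept and x1, standing in for the effects that enter the trunk) leaves U_hat correlated with x2.

```python
import numpy as np


def orthogonalize(U_hat: np.ndarray, X: np.ndarray) -> np.ndarray:
    """Return (I - X (X^T X)^{-1} X^T) U_hat, the part of U_hat orthogonal to col(X).

    Computed via least squares instead of an explicit inverse for numerical stability.
    """
    # B solves min ||X B - U_hat||, column by column
    B, *_ = np.linalg.lstsq(X, U_hat, rcond=None)
    # Residuals U_hat - X B are orthogonal to every column of X
    return U_hat - X @ B


rng = np.random.default_rng(0)
n = 100
x1, x2 = rng.normal(size=n), rng.normal(size=n)
# Structured design matrix: intercept, x1 (also fed to the trunk), x2 (structured only)
X = np.column_stack([np.ones(n), x1, x2])
# Simulated latent DNN features; the first column is correlated with x2
U_hat = np.column_stack([x2 + rng.normal(size=n), rng.normal(size=n)])

U_full = orthogonalize(U_hat, X)           # w.r.t. all structured components
U_trunk = orthogonalize(U_hat, X[:, :2])   # only intercept and x1

print(np.abs(X.T @ U_full).max())    # numerically zero
print(abs(x2 @ U_trunk[:, 0]))       # clearly nonzero: the x2 effect leaks into U
```

With the full design matrix, every structured column is orthogonal to the adjusted features, so the structured coefficients stay identifiable; with the partial one, the network can still absorb part of the x2 effect.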
Dear all, I have a question about using embeddings for categorical variables in the deep neural network (DNN) part. When I include categorical variables in the DNN part with nn.Embedding, I always get errors like
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)
When using the R version of the package, however, this problem does not occur. Could you please help me? Thanks a lot in advance.
Here is the code:
For Python
# source of data: https://archive.ics.uci.edu/ml/datasets/breast+cancer
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import torch.nn as nn
import torch.optim as optim
from sddr import Sddr
import warnings
# problem: predict whether tumor recurs or not
data = pd.read_csv(
    'breast_cancer.data',
    names=['class', 'age', 'menopause', 'tumor_size', 'inv_nodes', 'node-caps', 'deg_malig', 'breast', 'breast_quad', 'irradiat']
)
X = data[['age', 'tumor_size']] # both `age` and `tumor_size` are discretized continuous variables
y = (data[['class']] == 'recurrence-events').astype(np.int64) # 0: tumor doesn't recur; 1: tumor recurs
# encode `age` and `tumor_size` as integers
le = LabelEncoder()
X = X.apply(lambda x: le.fit_transform(x))
# split the data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=2022, stratify=y)
# calculate the number of levels
num_emb = [X[col].nunique() for col in X.columns] # `age` has 6 levels, `tumor_size` has 11 levels
# define the function to get the embedding dimension = min(#levels // 2, 50)
def get_emb_dim(num_emb):
    return min(50, num_emb // 2)
get_emb_dim = np.vectorize(get_emb_dim)
# calculate the dimension of embedding for each categorical variable
emb_dim = get_emb_dim(num_emb) # use a 1x3 vector to represent `age` and 1x5 vector to represent `tumor_size`
# define the distribution
distribution = 'Bernoulli_prob'
# define the formulas
formulas = {'probs': 'd1(age) + d2(tumor_size)'}
# define the architecture of deep neural networks
deep_models_dict = {
    'd1': {
        'model': nn.Sequential(nn.Embedding(num_emb[0], emb_dim[0]), nn.ReLU(), nn.Linear(emb_dim[0], 3)),
        'output_shape': 3
    },
    'd2': {
        'model': nn.Sequential(nn.Embedding(num_emb[1], emb_dim[1]), nn.ReLU(), nn.Linear(emb_dim[1], 3)),
        'output_shape': 3
    }
}
# define output directory
output_dir = '.'
# define the hyperparameters
train_params = {
    'batch_size': X_train.shape[0],
    'epochs': 50,
    'degrees_of_freedom': {'probs': 2},
    'optimizer': optim.RMSprop,
    'val_split': 0.2
}
# initiate the instance
sddr = Sddr(
    distribution=distribution,
    formulas=formulas,
    deep_models_dict=deep_models_dict,
    train_parameters=train_params,
    output_dir=output_dir
)
# ignore warnings
warnings.filterwarnings('ignore')
# train the model
sddr.train(
    structured_data=X_train,
    target=y_train,
    plot=True
)
For R (code written by Victor Medina-Olivares)
library(deepregression)
library(tidyverse)
data <- read_csv("../data/dataset_13_breast-cancer.csv")
data <- data %>%
  select(c("age", "tumor-size", "Class")) %>%
  mutate(y = ifelse(Class == "recurrence-events", 1, 0),
         age_enc = as.numeric(as.factor(age)) - 1,
         tumor_enc = as.numeric(as.factor(`tumor-size`)) - 1)
y <- data$y
input_age <- length(unique(data$age_enc))
output_age <- min(50, input_age%/%2)
input_tumor <- length(unique(data$tumor_enc))
output_tumor <- min(50, input_tumor%/%2)
d1 <- function(x){
  x %>%
    layer_embedding(input_age, output_age) %>%
    layer_dense(units = 3, activation = "linear")
}
d2 <- function(x){
  x %>%
    layer_embedding(input_tumor, output_tumor) %>%
    layer_dense(units = 3, activation = "linear")
}
mod <- deepregression(y = y,
                      data = data,
                      list_of_formulas = list(
                        prob = ~ d1(age_enc) + d2(tumor_enc)),
                      family = "bernoulli_prob",
                      list_of_deep_models = list(d1 = d1, d2 = d2)
)
mod %>%
  fit(epochs = 5,
      verbose = T,
      view_metrics = T,
      validation_split = 0.2)
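A minimal sketch of one possible workaround, assuming pysddr hands each deep covariate to the network as a float tensor of shape (batch, 1), which is what the error message suggests. The hypothetical wrapper `CategoricalEmbedding` (not part of pysddr's API) casts the codes back to integers before the lookup and flattens the result so the following `nn.Linear` receives a 2-D tensor. Keras' embedding layer, which the R package builds on, casts non-integer inputs to integers internally, which is likely why the R version runs without this step.

```python
import torch
import torch.nn as nn


class CategoricalEmbedding(nn.Module):
    """Hypothetical wrapper: nn.Embedding that accepts integer codes stored as floats."""

    def __init__(self, num_levels: int, emb_dim: int):
        super().__init__()
        self.emb = nn.Embedding(num_levels, emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Embedding only accepts Long/Int indices, so cast the float codes
        # (e.g. 0.0, 1.0, ...) produced upstream back to int64.
        idx = x.long()
        # Assumed input shape (batch, 1) gives (batch, 1, emb_dim); flatten all
        # dimensions after the batch one so nn.Linear receives (batch, emb_dim).
        return self.emb(idx).flatten(start_dim=1)


# `age` has 6 levels embedded in 3 dimensions, as in the script above
d1 = nn.Sequential(CategoricalEmbedding(6, 3), nn.ReLU(), nn.Linear(3, 3))

codes = torch.tensor([[0.0], [2.0], [5.0]])  # float codes, as in the error message
out = d1(codes)
print(out.shape)  # torch.Size([3, 3])
```

In the Python script, `nn.Embedding(num_emb[0], emb_dim[0])` would become `CategoricalEmbedding(num_emb[0], emb_dim[0])`, and likewise for `d2`. The cast is lossless here because `LabelEncoder` produces exact integer codes.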