Comments (11)

ilia10000 commented on May 27, 2024

The first error happens when the dimensions of xtx_in and v don't match.
Can you try printing the dimensions of both to see if that's the case?

The second problem can happen because the optimization problem isn't very stable in its current form.
Possible solutions include relaxing some of the constraints (though this may lead to solutions that don't actually separate the classes as desired), changing the data, or using a different optimization method (e.g. fitting the soft-label prototypes with an iterative method such as gradient descent).

from lo-shot.

avyavkumar commented on May 27, 2024

As per the discussion here, I modified the code to use ginv rather than solve. These are the blocks I modified:

xwx <- function(xtx_in, x, w) {
  v <- x[w == 0, ]
  
  if (sum(w == 0) >= ncol(x) | sum(w == 0) == 0) {
    #print(sum(w == 1))
    tryCatch({solve(t(x[w == 1, ]) %*% x[w == 1, ])},
             error = function(e) {ginv(t(x[w == 1, ]) %*% x[w == 1, ])},
             finally = {})
  } else {
    xtx_in + xtx_in %*% t(v) %*% ginv(diag(nrow(v)) - v %*% xtx_in %*% t(v)) %*% v %*% xtx_in
  }
  
}
add_classes <- function(data, label, classes, max_diff = 1.5) {
  
  # two furthest-apart classes in the group
  # we initially fit a line that pierces through both of their centroids
  furthest <- furthest_classes(data, label, classes)
  rest <- classes[!(classes %in% furthest)]
  
  x <- cbind(1, data[, -ncol(data)])
  y <- data[, ncol(data)]
  
  xtx_in <- ginv(t(x) %*% x, tol = sqrt(.Machine$double.eps))
  w <- ifelse(label %in% furthest, 1, 0)
  beta <- beta_w(xtx_in, x, y, w)

  while (length(rest) > 0) {
    
    # for the remaining classes, fit a regression line with it and
    # only classes currently in 'furthest' list
    beta_list <- lapply(rest, function(ii) {
      w <- ifelse(label %in% c(ii, furthest), 1, 0)
      return(beta_w(xtx_in, x, y, w))
    })
    
    # compare the distance between the regression line with the initial two furthest classes
    # and the newly fitted regression line
    distance <- sapply(beta_list, function(a) two_norm(a, beta))
    
    # stop if the smallest difference between the two regression lines is 
    # greater than the max tolerance
    if (all(distance > max_diff)) {
      rest <- integer(0)
    } else {
      
      # otherwise, include the class whose addition resulted in the smallest change
      # in the original regression line
      add <- which.min(distance)[1]
      furthest <- c(furthest, rest[add])
      rest <- rest[-add]
    }
  }
  
  return(list(group = unique(furthest), line = beta))
}

It looks like the "non-conformable" issue arises from the change in the second block, where xtx_in <- ginv(t(x) %*% x, tol = sqrt(.Machine$double.eps)) is used instead of xtx_in <- solve(t(x) %*% x). The xtx_in computed via ginv is passed to beta <- beta_w(xtx_in, x, y, w), which calls xwx(xtx_in, x, w) internally.

Note that this behaviour seems to occur randomly: it happens more often when more lines need to be generated, but it can occur even when only 3 lines are generated. If there is a workaround for this dimensionality issue, please let me know. Thanks!
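A quick way to test the solve-versus-ginv hypothesis, as a sketch on synthetic data (n, p and x below are made up for illustration): for a square, full-rank t(x) %*% x, MASS::ginv() returns a matrix of the same shape as solve(), with numerically almost the same values.

```r
# Sketch on synthetic data: compare the inverse from solve() with the
# Moore-Penrose pseudo-inverse from MASS::ginv().
library(MASS)

set.seed(1)
n <- 50
p <- 6
# intercept column plus p - 1 random features, mirroring cbind(1, ...) in add_classes()
x <- cbind(1, matrix(rnorm(n * (p - 1)), nrow = n, ncol = p - 1))

xtx <- t(x) %*% x
inv_solve <- solve(xtx)
inv_ginv  <- ginv(xtx, tol = sqrt(.Machine$double.eps))

print(dim(inv_solve))                    # [1] 6 6
print(dim(inv_ginv))                     # [1] 6 6
# for a full-rank xtx the two should agree up to rounding error
print(max(abs(inv_solve - inv_ginv)))
```

If this holds on real data too, the switch to ginv() on its own should not change the shape of xtx_in, and a shape mismatch would more likely come from the other operand, t(v).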

ilia10000 commented on May 27, 2024

Can you check the dimensions of xtx_in and t(v) and see if they are different in the cases where the non-conformable error comes up?

avyavkumar commented on May 27, 2024

I captured some of the output below:

Running 1 / 336 with 4 classes
Fitting 1 lines...
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 16 and dim of xtx_in is 768" 
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 5 and dim of xtx_in is 768"  
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 11 and dim of xtx_in is 768" 
Fitting 2 lines...
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 16 and dim of xtx_in is 768" 
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 5 and dim of xtx_in is 768"  
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 11 and dim of xtx_in is 768" 
Fitting 3 lines...
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 17 and dim of xtx_in is 768" 
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 12 and dim of xtx_in is 768" 
[1] "Dim of t(v): 1 and dim of xtx_in is 768"  
[2] "Dim of t(v): 768 and dim of xtx_in is 768"
R[write to console]: Error in xtx_in %*% t(v) : non-conformable arguments


Error in xtx_in %*% t(v) : non-conformable arguments
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 268, in eval
    value, visible = ro.r("withVisible({%s\n})" % code)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/__init__.py", line 438, in __call__
    res = self.eval(p)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/functions.py", line 199, in __call__
    .__call__(*args, **kwargs))
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/functions.py", line 125, in __call__
    res = super(Function, self).__call__(*new_args, **new_kwargs)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/rinterface_lib/conversion.py", line 45, in _
    cdata = function(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/rinterface.py", line 680, in __call__
    raise embedded.RRuntimeError(_rinterface._geterrmessage())
rpy2.rinterface_lib.embedded.RRuntimeError: Error in xtx_in %*% t(v) : non-conformable arguments


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-49-7991660fada5>", line 69, in <module>
    lines = [line_order_no_endpoints(centroids=labeled_centroids[0], active_classes=np.array(line)) for line in find_lines_R_multiD(dat=labeled_training_data_np, labels=labeled_training_data[1] , dims=dimensions, centroids=labeled_centroids[0], k=required_lines)]
  File "<ipython-input-9-115a5284a88a>", line 478, in find_lines_R_multiD
    get_ipython().magic('R -i df -i k -i max_diff -i dims -o result1 result1 <- recursive_reg(as.matrix(df[,-(dims+1)]), df[,dims+1]+1, k = k, max_diff = max_diff)')
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2160, in magic
    return self.run_line_magic(magic_name, magic_arg_s)
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2081, in run_line_magic
    result = fn(*args,**kwargs)
  File "<decorator-gen-119>", line 2, in R
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/magic.py", line 188, in <lambda>
    call = lambda f, *a, **k: f(*a, **k)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 783, in R
    raise e
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 756, in R
    text_result, result, visible = self.eval(line)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 273, in eval
    warning_or_other_msg)
rpy2.ipython.rmagic.RInterpreterError: Failed to parse and evaluate line 'result1 <- recursive_reg(as.matrix(df[,-(dims+1)]), df[,dims+1]+1, k = k, max_diff = max_diff)'.
R error message: 'Error in xtx_in %*% t(v) : non-conformable arguments'

Running 2 / 336 with 9 classes
Fitting 1 lines...
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 14 and dim of xtx_in is 768" 
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 12 and dim of xtx_in is 768" 
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 12 and dim of xtx_in is 768" 
[1] "Dim of t(v): 768 and dim of xtx_in is 768"
[2] "Dim of t(v): 12 and dim of xtx_in is 768" 

It looks like in all cases except the erroneous ones, t(v) has dimensions 768 x n, while in the erroneous cases the dimensions seem to be reversed, i.e. t(v) is 1 x 768.
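The reversed shape is consistent with R's default drop = TRUE when subsetting: if exactly one row has w == 0, x[w == 0, ] returns a plain vector rather than a 1 x p matrix, and t() of a vector is a 1 x p row matrix. A minimal reproduction on made-up data (the sizes are arbitrary):

```r
# Minimal reproduction on made-up data: when exactly one row matches,
# x[w == 0, ] drops the result to a plain vector instead of a 1 x p matrix.
p <- 5
xtx_in <- diag(p)                            # stand-in for a p x p matrix like xtx_in
x <- matrix(1:(3 * p), nrow = 3, ncol = p)   # 3 rows, p columns
w <- c(1, 1, 0)                              # only the last row has w == 0

v <- x[w == 0, ]
print(is.matrix(v))                          # [1] FALSE
print(dim(t(v)))                             # [1] 1 5  (t() of a vector is a row matrix)

# p x p %*% 1 x p is not conformable, which is the error in the log above
res <- tryCatch(xtx_in %*% t(v), error = function(e) conditionMessage(e))
print(res)                                   # [1] "non-conformable arguments"
```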

The complete R block I am using looks like this:

%%R
library(MASS)

xwx <- function(xtx_in, x, w) {
  v <- x[w == 0, ]
  print(paste0("Dim of t(v): ", dim(t(v)), " and dim of xtx_in is ", dim(xtx_in)))
  
  if (sum(w == 0) >= ncol(x) | sum(w == 0) == 0) {
    #print(sum(w == 1))
    tryCatch({solve(t(x[w == 1, ]) %*% x[w == 1, ])},
             error = function(e) {ginv(t(x[w == 1, ]) %*% x[w == 1, ])},
             finally = {})
  } else {
    tryCatch({xtx_in + xtx_in %*% t(v) %*% solve(diag(nrow(v)) - v %*% xtx_in %*% t(v)) %*% v %*% xtx_in},
             error = function(e) {xtx_in + xtx_in %*% t(v) %*% ginv(diag(nrow(v)) - v %*% xtx_in %*% t(v)) %*% v %*% xtx_in})
  } 
}

xwy <- function(x, y, w) {
  if (sum(w == 0) == 0) {
    t(x) %*% y
  } else {
    t(x) %*% y - t(x[w == 0, ]) %*% y[w == 0]
  }
}

beta_w <- function(xtx_in, x, y, w) {
  xwx(xtx_in, x, w) %*% xwy(x, y, w)
}

two_norm <- function(a, b) {
  sqrt(sum((a - b)^2))
}


group_classes <- function(data, label, k) {
  mu <- t(sapply(unique(label), function(ii) {
    colMeans(data[label == ii, , drop = F])
    }))
  
  mu_dist <- dist(mu)
  cluster <- cutree(hclust(mu_dist, method = "complete"), k = k)
  
  mu2 <- t(sapply(unique(cluster), function(ii) {
    colMeans(mu[cluster == ii, , drop = F])
  }))
  
  dist2 <- as.matrix(dist(mu2))
  
  jj <- 1
  while (jj <= length(unique(cluster))) {
    #print(length(unique(cluster)))
    #print(jj)
    if (table(cluster)[jj] == 1) {
      new_cluster <- which(rank(dist2[jj, ]) == 2)
      cluster[cluster == jj] <- new_cluster
    }
    jj <- jj + 1
  }
  # print(cluster)
  return(cluster)
}

furthest_classes <- function(data, label, classes) {
  mu <- t(sapply(classes, function(ii) {
    colMeans(data[label == ii,  , drop = F])
  }))
  mu_dist <- as.matrix(dist(mu))
  furthest <- which(mu_dist == max(mu_dist), arr.ind = T)[1, ]
  return(classes[furthest])
}

add_classes <- function(data, label, classes, max_diff = 1.5) {
  
  # two furthest-apart classes in the group
  # we initially fit a line that pierces through both of their centroids
  furthest <- furthest_classes(data, label, classes)
  rest <- classes[!(classes %in% furthest)]
  
  x <- cbind(1, data[, -ncol(data)])
  y <- data[, ncol(data)]
  
  xtx_in <- ginv(t(x) %*% x)
  w <- ifelse(label %in% furthest, 1, 0)
  beta <- beta_w(xtx_in, x, y, w)

  while (length(rest) > 0) {
    
    # for the remaining classes, fit a regression line with it and
    # only classes currently in 'furthest' list
    beta_list <- lapply(rest, function(ii) {
      w <- ifelse(label %in% c(ii, furthest), 1, 0)
      return(beta_w(xtx_in, x, y, w))
    })
    
    # compare the distance between the regression line with the initial two furthest classes
    # and the newly fitted regression line
    distance <- sapply(beta_list, function(a) two_norm(a, beta))
    
    # stop if the smallest difference between the two regression lines is 
    # greater than the max tolerance
    if (all(distance > max_diff)) {
      rest <- integer(0)
    } else {
      
      # otherwise, include the class whose addition resulted in the smallest change
      # in the original regression line
      add <- which.min(distance)[1]
      furthest <- c(furthest, rest[add])
      rest <- rest[-add]
    }
  }
  
  return(list(group = unique(furthest), line = beta))
}


order_classes <- function(data, label, group) {
  # first two elements in group must be the furthest away.
  # this will be the case if group comes from recursive regression
  
  if (length(group) == 1) {
    return(group)
  } else {
    temp <- sapply(group[-1], function(ii) {
      a <- colMeans(data[label == group[1], , drop = F])
      b <- colMeans(data[label == ii, , drop = F])
      return(sum((a - b)^2))
    })
    return(c(group[1], group[-1][order(temp)]))
  }
}



recursive_reg <- function(data, label, k, max_diff = 1.5, keep_all = T) {
  
  # group the class-wise centroids into k groups
  init_group <- group_classes(data, label, k)
  k_new <- length(unique(init_group))
  
  #if (k_new == 1) {
  #  val <- list(group = order_classes(data, label, 1),
  #              line = lm())
  #}
  # for each of the k groups, find a line that incorporates
  # as many of the classes in that group as possible
  val <- lapply(sort(unique(init_group)), function(ii) {
    classes <- which(init_group == ii)
    # print(classes)
    if (keep_all) {
      temp <- add_classes(data, label, classes, max_diff)
      temp$group <- order_classes(data, label, temp$group)
      return(temp$group)
    } else {
      if (length(unique(classes)) == 1) {
        return(NULL)
      } else {
        temp <- add_classes(data, label, classes, max_diff)
        temp$group <- order_classes(data, label, temp$group)
        return(temp$group)
      }
    }
    #add_classes(data, label, classes, max_diff)
    #if (length(unique(classes)) == 1) {
    #  return(NULL)
    #} else {
    #  add_classes(data, label, classes, max_diff)
    #}
  })
  
  if (keep_all) {
    # if keep_all = T, keep lines from  single classes
    return(val)
  } else {
    # If keep_all = F, filter out groups with only one class
    return(val[lengths(val) != 0])
  }
}

ilia10000 commented on May 27, 2024

In the xwx function, the first line is
v <- x[w == 0, ]

Nam suggests this should be changed to
v <- x[w == 0, , drop = F]

Can you try that out and see if it solves this?
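A sketch of what this change does, on synthetic data (x, w and the sizes are made up for illustration): with drop = FALSE a single excluded row stays a 1 x p matrix, so t(v) is p x 1 and the Woodbury-style update in the second branch of xwx() becomes conformable. The last lines compare the update with inverting the cross-product of the kept rows directly.

```r
# Sketch on synthetic data (x, w and the sizes are made up for illustration).
set.seed(42)
n <- 10
p <- 5
x <- matrix(rnorm(n * p), nrow = n, ncol = p)
xtx_in <- solve(t(x) %*% x)
w <- c(rep(1, n - 1), 0)                     # exactly one row has w == 0

v_dropped <- x[w == 0, ]                     # default drop = TRUE: a plain vector
v <- x[w == 0, , drop = FALSE]               # suggested change: stays a 1 x p matrix
print(dim(t(v_dropped)))                     # [1] 1 5
print(dim(t(v)))                             # [1] 5 1

# the Woodbury branch of xwx() is now conformable; the inner matrix is 1 x 1
out <- xtx_in + xtx_in %*% t(v) %*% solve(diag(nrow(v)) - v %*% xtx_in %*% t(v)) %*% v %*% xtx_in
print(dim(out))                              # [1] 5 5

# compare with inverting the cross-product of the kept rows directly
xw <- x[w == 1, , drop = FALSE]
direct <- solve(t(xw) %*% xw)
print(all.equal(out, direct))                # expected TRUE, up to rounding
```

Note that this only covers v. The other row subsets in xwx() and xwy() (x[w == 1, ] and x[w == 0, ]) can collapse to vectors in the same way when a single row matches, so they may need drop = FALSE as well.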

avyavkumar commented on May 27, 2024

Hi, thanks for the suggestion. Unfortunately, I am still seeing these issues. Let me know if there is a workaround or a direction I can investigate: although the failing cases are few in number, I would still like to get to the bottom of this, as it will help provide more accurate metrics.

avyavkumar commented on May 27, 2024

Hi, is there any workaround for this? A good number of runs are failing with the following exception:

R[write to console]: Error in xwx(xtx_in, x, w) %*% xwy(x, y, w) : non-conformable arguments


Error in xwx(xtx_in, x, w) %*% xwy(x, y, w) : non-conformable arguments
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 268, in eval
    value, visible = ro.r("withVisible({%s\n})" % code)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/__init__.py", line 438, in __call__
    res = self.eval(p)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/functions.py", line 199, in __call__
    .__call__(*args, **kwargs))
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/functions.py", line 125, in __call__
    res = super(Function, self).__call__(*new_args, **new_kwargs)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/rinterface_lib/conversion.py", line 45, in _
    cdata = function(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/rinterface.py", line 680, in __call__
    raise embedded.RRuntimeError(_rinterface._geterrmessage())
rpy2.rinterface_lib.embedded.RRuntimeError: Error in xwx(xtx_in, x, w) %*% xwy(x, y, w) : non-conformable arguments


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-12-2017da8319ba>", line 14, in <module>
    lines = [line_order_no_endpoints(centroids=labeled_centroids_np, active_classes=np.array(line)) for line in find_lines_R_multiD(dat=labeled_training_data_np, labels=labeled_training_data[1] , dims=dimensions, centroids=labeled_centroids_np, k=total_lines)]
  File "<ipython-input-4-5bd8ab90ad22>", line 499, in find_lines_R_multiD
    get_ipython().magic('R -i df -i k -i max_diff -i dims -o result1 result1 <- recursive_reg(as.matrix(df[,-(dims+1)]), df[,dims+1]+1, k = k, max_diff = max_diff)')
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2160, in magic
    return self.run_line_magic(magic_name, magic_arg_s)
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2081, in run_line_magic
    result = fn(*args,**kwargs)
  File "<decorator-gen-119>", line 2, in R
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/magic.py", line 188, in <lambda>
    call = lambda f, *a, **k: f(*a, **k)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 783, in R
    raise e
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 756, in R
    text_result, result, visible = self.eval(line)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 273, in eval
    warning_or_other_msg)
rpy2.ipython.rmagic.RInterpreterError: Failed to parse and evaluate line 'result1 <- recursive_reg(as.matrix(df[,-(dims+1)]), df[,dims+1]+1, k = k, max_diff = max_diff)'.
R error message: 'Error in xwx(xtx_in, x, w) %*% xwy(x, y, w) : non-conformable arguments'

The R code is:

%%R
library(MASS)

xwx <- function(xtx_in, x, w) {
  v <- x[w == 0, , drop = F]
  
  if (sum(w == 0) >= ncol(x) | sum(w == 0) == 0) {
    #print(sum(w == 1))
    tryCatch({solve(t(x[w == 1, ]) %*% x[w == 1, ])},
             error = function(e) {ginv(t(x[w == 1, ]) %*% x[w == 1, ])},
             finally = {})
  } else {
    tryCatch({xtx_in + xtx_in %*% t(v) %*% solve(diag(nrow(v)) - v %*% xtx_in %*% t(v)) %*% v %*% xtx_in},
             error = function(e) {xtx_in + xtx_in %*% t(v) %*% ginv(diag(nrow(v)) - v %*% xtx_in %*% t(v)) %*% v %*% xtx_in})
  } 
}

xwy <- function(x, y, w) {
  if (sum(w == 0) == 0) {
    t(x) %*% y
  } else {
    t(x) %*% y - t(x[w == 0, ]) %*% y[w == 0]
  }
}

beta_w <- function(xtx_in, x, y, w) {
  xwx(xtx_in, x, w) %*% xwy(x, y, w)
}

two_norm <- function(a, b) {
  sqrt(sum((a - b)^2))
}


group_classes <- function(data, label, k) {
  mu <- t(sapply(unique(label), function(ii) {
    colMeans(data[label == ii, , drop = F])
    }))
  
  mu_dist <- dist(mu)
  cluster <- cutree(hclust(mu_dist, method = "complete"), k = k)
  
  mu2 <- t(sapply(unique(cluster), function(ii) {
    colMeans(mu[cluster == ii, , drop = F])
  }))
  
  dist2 <- as.matrix(dist(mu2))
  
  jj <- 1
  while (jj <= length(unique(cluster))) {
    #print(length(unique(cluster)))
    #print(jj)
    if (table(cluster)[jj] == 1) {
      new_cluster <- which(rank(dist2[jj, ]) == 2)
      cluster[cluster == jj] <- new_cluster
    }
    jj <- jj + 1
  }
  # print(cluster)
  return(cluster)
}

furthest_classes <- function(data, label, classes) {
  mu <- t(sapply(classes, function(ii) {
    colMeans(data[label == ii,  , drop = F])
  }))
  mu_dist <- as.matrix(dist(mu))
  furthest <- which(mu_dist == max(mu_dist), arr.ind = T)[1, ]
  return(classes[furthest])
}

add_classes <- function(data, label, classes, max_diff = 1.5) {
  
  # two furthest-apart classes in the group
  # we initially fit a line that pierces through both of their centroids
  furthest <- furthest_classes(data, label, classes)
  rest <- classes[!(classes %in% furthest)]
  
  x <- cbind(1, data[, -ncol(data)])
  y <- data[, ncol(data)]
  
  xtx_in <- ginv(t(x) %*% x)
  w <- ifelse(label %in% furthest, 1, 0)
  beta <- beta_w(xtx_in, x, y, w)

  while (length(rest) > 0) {
    
    # for the remaining classes, fit a regression line with it and
    # only classes currently in 'furthest' list
    beta_list <- lapply(rest, function(ii) {
      w <- ifelse(label %in% c(ii, furthest), 1, 0)
      return(beta_w(xtx_in, x, y, w))
    })
    
    # compare the distance between the regression line with the initial two furthest classes
    # and the newly fitted regression line
    distance <- sapply(beta_list, function(a) two_norm(a, beta))
    
    # stop if the smallest difference between the two regression lines is 
    # greater than the max tolerance
    if (all(distance > max_diff)) {
      rest <- integer(0)
    } else {
      
      # otherwise, include the class whose addition resulted in the smallest change
      # in the original regression line
      add <- which.min(distance)[1]
      furthest <- c(furthest, rest[add])
      rest <- rest[-add]
    }
  }
  
  return(list(group = unique(furthest), line = beta))
}


order_classes <- function(data, label, group) {
  # first two elements in group must be the furthest away.
  # this will be the case if group comes from recursive regression
  
  if (length(group) == 1) {
    return(group)
  } else {
    temp <- sapply(group[-1], function(ii) {
      a <- colMeans(data[label == group[1], , drop = F])
      b <- colMeans(data[label == ii, , drop = F])
      return(sum((a - b)^2))
    })
    return(c(group[1], group[-1][order(temp)]))
  }
}



recursive_reg <- function(data, label, k, max_diff = 1e-5, keep_all = T) {
  
  # group the class-wise centroids into k groups
  init_group <- group_classes(data, label, k)
  k_new <- length(unique(init_group))
  
  #if (k_new == 1) {
  #  val <- list(group = order_classes(data, label, 1),
  #              line = lm())
  #}
  # for each of the k groups, find a line that incorporates
  # as many of the classes in that group as possible
  val <- lapply(sort(unique(init_group)), function(ii) {
    classes <- which(init_group == ii)
    # print(classes)
    if (keep_all) {
      temp <- add_classes(data, label, classes, max_diff)
      temp$group <- order_classes(data, label, temp$group)
      return(temp$group)
    } else {
      if (length(unique(classes)) == 1) {
        return(NULL)
      } else {
        temp <- add_classes(data, label, classes, max_diff)
        temp$group <- order_classes(data, label, temp$group)
        return(temp$group)
      }
    }
    #add_classes(data, label, classes, max_diff)
    #if (length(unique(classes)) == 1) {
    #  return(NULL)
    #} else {
    #  add_classes(data, label, classes, max_diff)
    #}
  })
  
  if (keep_all) {
    # if keep_all = T, keep lines from  single classes
    return(val)
  } else {
    # If keep_all = F, filter out groups with only one class
    return(val[lengths(val) != 0])
  }
}

avyavkumar commented on May 27, 2024

It looks like the operands of the failing matrix multiplication have these dimensions:

[1] 1 1
[1] 768   1

Is there a workaround for this? Please let me know if so. I got these dimensions by adding print statements to beta_w:

beta_w <- function(xtx_in, x, y, w) {
  print(dim(xwx(xtx_in, x, w)))
  print(dim(xwy(x, y, w)))
  xwx(xtx_in, x, w) %*% xwy(x, y, w)
}
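These shapes are what one would expect if exactly one row has w == 1: in the first branch of xwx(), x[w == 1, ] then drops to a vector, and t(vec) %*% vec is an inner product, i.e. a 1 x 1 matrix, instead of the p x p cross-product. One way this could arise (an inference from the code in this thread, not verified on the actual data) is a group containing a single class with a single sample, since furthest_classes() then returns that same class twice. A sketch on made-up data:

```r
# Sketch on made-up data: a single selected row turns the cross-product into a scalar.
p <- 4
x <- matrix(1:(3 * p), nrow = 3, ncol = p)
w <- c(0, 1, 0)                               # exactly one row has w == 1

xw <- x[w == 1, ]                             # dropped to a vector of length p
print(dim(t(xw) %*% xw))                      # [1] 1 1  (inner product, not p x p)

xw_fixed <- x[w == 1, , drop = FALSE]         # stays a 1 x p matrix
print(dim(t(xw_fixed) %*% xw_fixed))          # [1] 4 4  (the intended p x p cross-product)
```

With drop = FALSE the cross-product has the right shape but only rank 1, so solve() will typically fail on it and the existing tryCatch falls back to ginv(); the resulting p x p matrix then conforms with the p x 1 output of xwy().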

avyavkumar commented on May 27, 2024

The data frame looks like this:

             0         1         2         3         4         5         6  \
0    -0.285893  0.100989 -0.086276 -0.055642  0.386805 -0.071219  0.539185   
1     0.229483  0.433564 -0.169166 -0.373750  0.058643  0.156620  0.517695   
2     0.507505  0.397592 -0.533612  0.012184  0.141688  0.498529  0.245283   
3     0.108507 -0.343783  0.011003 -0.334807  0.436632 -0.295442 -0.064393   
4    -0.092536  0.101324  0.381015 -0.029415 -0.037573  0.209534  0.321506   
...        ...       ...       ...       ...       ...       ...       ...   
6812  0.081752 -0.076248  0.552696 -0.248589  0.123113 -0.457422 -0.211575   
6813 -0.002597  0.236071  0.211104 -0.253087  0.099704 -0.125046  0.015364   
6814  0.292048 -0.054196  0.459010 -0.343681  0.242175 -0.307340  0.013871   
6815  0.376926 -0.285080  0.137277 -0.225266  0.508968  0.213983  0.056722   
6816  0.009812 -0.133948  0.004090  0.038559  0.166507 -0.004390 -0.031691   

             7         8         9  ...       759       760       761  \
0     0.226831 -0.631589  0.104056  ...  0.060103  0.287790 -0.302842   
1     0.084167 -0.619955 -0.031069  ...  0.312966  0.111479 -0.320775   
2     0.426305  0.059289 -0.223517  ... -0.560820  0.158381 -0.133150   
3     0.121892 -0.102805  0.160829  ...  0.127226 -0.128728 -0.527352   
4     0.134741 -0.072664  0.136189  ... -0.109515  0.112521  0.195295   
...        ...       ...       ...  ...       ...       ...       ...   
6812  0.433101  0.015748  0.256297  ...  0.014548 -0.457582 -0.392458   
6813  0.658662  0.135492 -0.014497  ...  0.060105 -0.035580 -0.613886   
6814  0.368436  0.219049 -0.113884  ...  0.215659  0.008811 -0.413046   
6815  0.212861  0.120595 -0.306513  ... -0.241388  0.224027 -0.231764   
6816  0.698086 -0.115483 -0.004676  ... -0.162783  0.021009 -0.512451   

           762       763       764       765       766       767  \
0    -0.304429 -0.145017 -0.040142 -0.040846  0.040223  0.044410   
1    -0.097893 -0.072137 -0.256350 -0.168176 -0.390138  0.014747   
2     0.284558 -0.244452 -0.211685  0.412036  0.448472  0.136268   
3    -0.558983 -0.240184 -0.069946 -0.078860  0.027040 -0.251058   
4     0.073837 -0.361740  0.242135 -0.225458 -0.043142 -0.317183   
...        ...       ...       ...       ...       ...       ...   
6812 -0.004268 -0.017532  0.337242 -0.226368 -0.079176  0.671449   
6813 -0.273185 -0.061682  0.669058 -0.042069  0.012758  0.943183   
6814 -0.668200  0.115001  0.616219 -0.091257 -0.084448  0.729236   
6815  0.136355 -0.046832  0.212380 -0.251151  0.809596  0.316020   
6816 -0.257796 -0.306105  0.732763 -0.280010 -0.016456  0.494144   

      My Hopes And Dreams  
0                       0  
1                       0  
2                       0  
3                       0  
4                       0  
...                   ...  
6812                 1592  
6813                 1592  
6814                 1592  
6815                 1592  
6816                 1592  

[6817 rows x 769 columns]
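This layout accounts for the 768 in the debug output above. Going by the %R line in the traceback, recursive_reg() receives the frame with the label column removed (768 feature columns), and add_classes() then uses the last remaining column as the response y and the other 767 columns plus an intercept as x, so xtx_in is 768 x 768. A sketch with a small made-up frame of the same layout (dims = 6 here instead of 768):

```r
# Hypothetical small frame with the same layout: dims feature columns, then the label.
set.seed(3)
dims <- 6
df <- as.data.frame(matrix(rnorm(20 * dims), nrow = 20, ncol = dims))
df$label <- rep(0:3, each = 5)

# what the %R line in the traceback passes to recursive_reg()
data  <- as.matrix(df[, -(dims + 1)])   # label column dropped
label <- df[, dims + 1] + 1             # labels shifted to start at 1

# what add_classes() then builds: last feature column is y, the rest plus an intercept is x
x <- cbind(1, data[, -ncol(data)])
y <- data[, ncol(data)]
print(dim(x))                           # 20 rows, dims columns: 1 intercept + (dims - 1) features
print(dim(t(x) %*% x))                  # dims x dims, i.e. 768 x 768 for the real data
```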

avyavkumar commented on May 27, 2024

If possible, some documentation for the different functions, such as xwx and xwy, would help with debugging, since it would reduce reliance on the original authors of the paper.
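As a starting point, here is a sketch of what xwx, xwy and beta_w appear to compute, reconstructed from the code posted in this thread rather than taken from the authors' documentation. It also applies drop = FALSE to every row subset. The comments are an interpretation, and the final check uses made-up data to compare against a direct least-squares fit on the kept rows.

```r
# Reconstructed from the code in this thread; comments are an interpretation,
# not the authors' documentation.
#   x      : n x p design matrix (first column is the intercept)
#   y      : length-n response
#   w      : length-n 0/1 weights; rows with w == 1 are kept in the fit
#   xtx_in : (X'X)^{-1} for the full x, computed once in add_classes()
library(MASS)

# xwx: returns (X_w'X_w)^{-1}, where X_w = x[w == 1, ].
# If only m < p rows are excluded, it downdates xtx_in with the Woodbury identity
#   (X'X - V'V)^{-1} = A + A V' (I_m - V A V')^{-1} V A,  A = xtx_in, V = excluded rows,
# which needs an m x m inverse instead of a p x p one.
xwx <- function(xtx_in, x, w) {
  v <- x[w == 0, , drop = FALSE]
  m <- nrow(v)
  if (m >= ncol(x) || m == 0) {
    xw <- x[w == 1, , drop = FALSE]
    xtx_w <- t(xw) %*% xw
    tryCatch(solve(xtx_w), error = function(e) ginv(xtx_w))
  } else {
    inner <- diag(m) - v %*% xtx_in %*% t(v)
    inner_inv <- tryCatch(solve(inner), error = function(e) ginv(inner))
    xtx_in + xtx_in %*% t(v) %*% inner_inv %*% v %*% xtx_in
  }
}

# xwy: returns X_w'y_w, i.e. X'y minus the contribution of the excluded rows.
xwy <- function(x, y, w) {
  if (sum(w == 0) == 0) {
    t(x) %*% y
  } else {
    t(x) %*% y - t(x[w == 0, , drop = FALSE]) %*% y[w == 0]
  }
}

# beta_w: least-squares coefficients using only the rows with w == 1.
beta_w <- function(xtx_in, x, y, w) {
  xwx(xtx_in, x, w) %*% xwy(x, y, w)
}

# Check on made-up data against a direct fit on the kept rows.
set.seed(7)
x <- cbind(1, matrix(rnorm(40 * 3), nrow = 40, ncol = 3))
y <- rnorm(40)
w <- rep(c(1, 0), times = c(38, 2))
xw <- x[w == 1, , drop = FALSE]
b <- beta_w(solve(t(x) %*% x), x, y, w)
b_direct <- solve(t(xw) %*% xw, t(xw) %*% y[w == 1])
print(all.equal(b, b_direct))   # expected TRUE if the Woodbury update matches
```

The design choice in the original code appears to be speed: xtx_in is inverted once per group, and each candidate fit in add_classes() only needs a small m x m solve when few rows are excluded.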

from lo-shot.

ilia10000 commented on May 27, 2024

Nam sent me a fix with additional documentation that I pushed to the repo as a separate file just now. Let me know if it helps.
