towards_data_science's People

Contributors

j-adamczyk

Forkers

opoyraz lkafle

towards_data_science's Issues

ValueError: Cannot index with multidimensional key

Hi Jakub,

Thanks for providing your KNN implementation using faiss. I'm working with a large dataset (566,602 rows × 20 columns) and KNeighborsClassifier took way too long, so I was hoping your implementation would help.

The problem is, I'm applying one-hot encoding to my categorical features and this seems to be causing the following error in the predict method of the classifier:
ValueError: Cannot index with multidimensional key

To demonstrate this, here's a code sample that results in the same error on a different dataset:

import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.datasets import fetch_openml
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split

X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True)

numeric_features = ['age', 'fare']
numeric_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler())])

categorical_features = ['embarked', 'sex', 'pclass']
categorical_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
    ('onehot', OneHotEncoder(handle_unknown='ignore'))])

preprocessor = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, numeric_features),
        ('cat', categorical_transformer, categorical_features)])

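# FaissKNeighbors is the faiss-based classifier from this repository,
# defined in an earlier notebook cell (see the traceback below).
clf = Pipeline(steps=[('preprocessor', preprocessor),
                      ('classifier', FaissKNeighbors())])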

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

clf.fit(X_train, y_train)
clf.predict(X_test)

Here is the full stack trace:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-35-373ecdcebcf0> in <module>()
     29 
     30 clf.fit(X_train, y_train)
---> 31 clf.predict(X_test)

/usr/local/lib/python3.6/dist-packages/sklearn/utils/metaestimators.py in <lambda>(*args, **kwargs)
    114 
    115         # lambda, but not partial, allows help() to work with update_wrapper
--> 116         out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs)
    117         # update the docstring of the returned function
    118         update_wrapper(out, self.fn)

/usr/local/lib/python3.6/dist-packages/sklearn/pipeline.py in predict(self, X, **predict_params)
    418         for _, name, transform in self._iter(with_final=False):
    419             Xt = transform.transform(Xt)
--> 420         return self.steps[-1][-1].predict(Xt, **predict_params)
    421 
    422     @if_delegate_has_method(delegate='_final_estimator')

<ipython-input-24-c1e7d3808324> in predict(self, X)
     15     def predict(self, X):
     16         distances, indices = self.index.search(X.astype(np.float32), k=self.k)
---> 17         votes = self.y[indices]
     18         predictions = np.array([np.argmax(np.bincount(x)) for x in votes])
     19         return predictions

/usr/local/lib/python3.6/dist-packages/pandas/core/series.py in __getitem__(self, key)
    908             key = check_bool_indexer(self.index, key)
    909 
--> 910         return self._get_with(key)
    911 
    912     def _get_with(self, key):

/usr/local/lib/python3.6/dist-packages/pandas/core/series.py in _get_with(self, key)
    941         if key_type == "integer":
    942             if self.index.is_integer() or self.index.is_floating():
--> 943                 return self.loc[key]
    944             else:
    945                 return self._get_values(key)

/usr/local/lib/python3.6/dist-packages/pandas/core/indexing.py in __getitem__(self, key)
   1766 
   1767             maybe_callable = com.apply_if_callable(key, self.obj)
-> 1768             return self._getitem_axis(maybe_callable, axis=axis)
   1769 
   1770     def _is_scalar_access(self, key: Tuple):

/usr/local/lib/python3.6/dist-packages/pandas/core/indexing.py in _getitem_axis(self, key, axis)
   1950 
   1951                 if hasattr(key, "ndim") and key.ndim > 1:
-> 1952                     raise ValueError("Cannot index with multidimensional key")
   1953 
   1954                 return self._getitem_iterable(key, axis=axis)

ValueError: Cannot index with multidimensional key

As you can see, the error is happening on line 17 at votes = self.y[indices]. I have tested the implementation on the Iris dataset without any preprocessing and it works fine, so I believe it's related to the one-hot encoding. Please let me know if you have a fix or if one-hot encoding is not necessary for this implementation. Thanks again!
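The traceback suggests a possible cause. The call self.y[indices] is dispatched to pandas/core/series.py, so self.y appears to be the pandas Series passed to fit. With as_frame=True, fetch_openml returns y as a Series. Series.loc then rejects the 2-D indices array returned by index.search. If that reading is right, the problem lies in the labels rather than the one-hot-encoded features. It would also be consistent with Iris working, assuming its labels were a plain NumPy array.

Below is a minimal sketch of the voting step under that assumption. majority_vote is a hypothetical helper, not part of the repository. It converts the labels with np.asarray. It also maps them to integer codes with np.unique before calling np.bincount, because the Titanic target may be stored as strings and bincount only accepts non-negative integers.

```python
import numpy as np


def majority_vote(y_train, indices):
    """Hypothetical helper: majority vote over faiss neighbour indices.

    y_train: 1-D labels (list, NumPy array or pandas Series; any sortable type).
    indices: 2-D integer array of shape (n_queries, k), e.g. the second value
             returned by index.search(X, k).
    """
    # np.asarray drops the pandas index, so the 2-D fancy indexing below is
    # plain NumPy indexing instead of Series.loc (which rejects 2-D keys).
    # np.unique maps arbitrary labels (e.g. the strings "0"/"1") to codes 0..C-1.
    classes, codes = np.unique(np.asarray(y_train), return_inverse=True)
    codes = codes.ravel()  # keep the inverse 1-D across NumPy versions
    votes = codes[indices]  # shape (n_queries, k)
    winners = np.array(
        [np.argmax(np.bincount(row, minlength=len(classes))) for row in votes]
    )
    # Map winning codes back to the original labels; np.argmax returns the
    # first maximum, so ties go to the smallest class.
    return classes[winners]
```

To apply the same idea inside FaissKNeighbors, fit could store the encoded labels once, for example self.classes, self.y = np.unique(np.asarray(y), return_inverse=True). Here self.classes is a hypothetical new attribute. predict would then index that NumPy array with the search results and return self.classes indexed by the winning codes.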
