If the loss function is non-convex, the Hessian around a chosen parameter value may have negative eigenvalues, i.e. it need not be positive definite. To address this, a small multiple of the identity is added to the Hessian, which amounts to replacing the loss locally with a damped second-order Taylor approximation. The influence is then computed from this damped approximation instead of the actual loss.
See section 4 of Koh & Liang (2017): http://proceedings.mlr.press/v70/koh17a.html.
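For concreteness, here is a minimal numpy sketch of the damping idea (the function and argument names are hypothetical). The damped inverse (H + damping * I)^{-1} simply replaces H^{-1} in the usual influence expression -grad_test^T H^{-1} grad_train:

import numpy as np

def influence_up_loss(hessian: np.ndarray, grad_train: np.ndarray,
                      grad_test: np.ndarray, damping: float = 0.01) -> float:
    """Influence of a training point on a test point, with a damped Hessian."""
    damped = hessian + damping * np.eye(hessian.shape[0])
    # Solve the linear system rather than forming the inverse explicitly.
    h_inv_grad = np.linalg.solve(damped, grad_train)
    return -grad_test @ h_inv_grad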
Alternatively, for non-convex losses one can simply neglect the indefinite part of the Hessian, i.e. the term involving second derivatives of the model outputs. This corresponds to using the Gauss-Newton approximation from https://www.cs.toronto.edu/~jmartens/docs/Deep_HessianFree.pdf.
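As a sketch: for a composite loss L(f(x; theta), y), the full Hessian w.r.t. theta decomposes into J^T H_L J plus a term weighted by second derivatives of f, and only the latter can be indefinite. The Gauss-Newton matrix keeps the first, positive semi-definite part (names below are hypothetical):

import numpy as np

def gauss_newton_matrix(jacobian: np.ndarray, loss_hessian: np.ndarray) -> np.ndarray:
    """G = J^T H_L J, with J the (n_outputs, n_params) Jacobian of the model
    outputs w.r.t. the parameters and H_L the Hessian of the loss w.r.t. the
    model outputs. G is positive semi-definite whenever H_L is."""
    return jacobian.T @ loss_hessian @ jacobian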
The interface is not fully specified yet, and we might modify it along the lines of the following. The same pattern could be achieved in a functional style (see the sketch after the code). Here is an example for non-differentiable models:
from typing import Callable, Protocol

import numpy as np
import torch


class SupervisedModel(Protocol):
    """Pedantic: only here for the type hints."""

    def fit(self, x: np.ndarray, y: np.ndarray):
        ...

    def predict(self, x: np.ndarray) -> np.ndarray:
        ...

    def score(self, x: np.ndarray, y: np.ndarray) -> float:
        ...

    def params(self) -> np.ndarray:
        ...
class PyTorchSurrogateModel:
    r"""Wrap a non-differentiable model with a surrogate objective L(\theta)."""

    def __init__(
        self,
        base_model: SupervisedModel,
        surrogate_objective: Callable[[torch.Tensor], torch.Tensor],
    ):
        self.__base_model = base_model
        self.__surrogate = surrogate_objective

    # ================================================
    # implement grad and mvp using torch and L(\theta)
    # ================================================

    def params(self) -> np.ndarray:
        return self.__base_model.params()

    def fit(self, x: np.ndarray, y: np.ndarray):
        return self.__base_model.fit(x, y)

    def predict(self, x: np.ndarray) -> np.ndarray:
        return self.__base_model.predict(x)

    def score(self, x: np.ndarray, y: np.ndarray) -> float:
        return self.__base_model.score(x, y)