今回は劣勾配法を使って,L1正則化最小二乗学習(Lasso)を実装してみたいと思う.最適化問題を,以下のように定義する.
第一項はで微分可能なので,その勾配を使い,第二項は列勾配を使う.これで更新式を出すと,
となる.まとめると,
となる.これを実装したクラスが以下.
class Lasso(object):
def __init__(self, eta=1e-4, reg=1, max_iter=1000):
self.eta = eta
self.reg = reg
self.max_iter = max_iter
def fit(self, X, y):
w = np.zeros(X.shape[1])
for _ in range(self.max_iter):
g = - X.T.dot(y - X.dot(w)) + self.reg * np.sign(w)
w = w - self.eta * g
self.w = w
return self
def predict(self, X):
return X.dot(self.w)
1-d example
1次元で実験してみる.
import numpy as np
import matplotlib.pyplot as plt
random_state = 0
rnd = np.random.RandomState(random_state)
n, d = 100, 2
sigma = 0.5
x = np.linspace(-2, 2, n)
y = x + rnd.normal(0, sigma, size=n)
model = Lasso()
model.fit(np.c_[x], y)
y_pred = model.predict(np.c_[x])
plt.scatter(x, y)
plt.plot(x, y_pred, "r-", linewidth=3)
結果は以下.きちんと学習されているっぽい.
100-d example
次にL1正則化を施したことにより,解がスパースになることを確認するために,100次元のデータで実験する.コードは例えば以下.
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
X, y = make_regression(n_samples=500, n_features=100, n_informative=10, random_state=random_state)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=random_state)
model = Lasso()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
print(np.sqrt(mean_squared_error(y_test, y_pred))) # 0.008980784567524594
きちんと解がスパースになっていることを確認してみる.例えば,
np.around(model.w, 3)
とすると,
array([ 0. , -0. , -0. , 0. , -0. , 0. , 0. , -0. ,
-0. , 0. , 0. , -0. , 0. , 0. , -0. , -0. ,
0. , 0. , -0. , 0. , 0. , -0. , 99.673, 0. ,
0. , -0. , 0. , 0. , -0. , 0. , 41.788, 0. ,
-0. , -0. , 0. , -0. , 0. , -0. , 0. , -0. ,
-0. , 74.23 , -0. , 0. , 0. , 0. , 0. , 0. ,
-0. , -0. , 0. , -0. , 0. , 0. , 0. , 0. ,
-0. , -0. , 0. , -0. , 0. , 37.469, -0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , -0. , -0. ,
30.374, 44.89 , 0. , 50.938, 0. , 0. , -0. , -0. ,
-0. , 0. , 0. , 98.669, -0. , 0. , 27.71 , 0. ,
0. , -0. , -0. , -0. , -0. , 86.562, -0. , 0. ,
-0. , -0. , 0. , 0. ])
となって,n_informativeの数だけ非ゼロになっている.
Written with StackEdit.
0 件のコメント:
コメントを投稿