gam
- 一般化加法モデル
- モデルを作成する中で説明変数が目的変数にどのような寄与をしたのか非線形に解釈できるライブラリ
- 性能とかではなくて解釈に対して使える
install
$ python -m pip install pygam
usage
from pygam import GAM, s
...
model = GAM()
model.fit(input_df, y)
def plot_splines(gam):
fig, axes = plt.subplots(ncols=len(input_df.columns), figsize=(14, 5), sharey=True)
axes = np.array(axes).flatten()
for i, (ax, title, p_value) in enumerate(zip(axes, input_df.columns, gam.statistics_['p_values'])):
XX = gam.generate_X_grid(term=i)
ax.plot(XX[:, i], gam.partial_dependence(term=i, X=XX))
ax.plot(XX[:, i], gam.partial_dependence(term=i, X=XX, width=.95)[1], c='r', ls='--')
ax.axhline(0, c='#cccccc')
ax.set_title("{0:} (p={1:.2})".format(title, p_value))
ax.set_yticks([])
ax.grid()
fig.tight_layout()
return fig, axes
plot_splines(model)