
Testing the model

To test the generalization performance of the model, we calculate the mean squared error on the test data:

In [11]: y_pred = linreg.predict(X_test)
In [12]: metrics.mean_squared_error(y_test, y_pred)
Out[12]: 15.010997321630166
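
For reference, the mean squared error is the average of the squared differences between the true and the predicted target values, which is what metrics.mean_squared_error computes by default. The following is a minimal sketch using NumPy only; the arrays y_true_demo and y_pred_demo are made-up toy values for illustration, not the housing data:

```python
import numpy as np

# Hypothetical toy values, NOT taken from the housing dataset
y_true_demo = np.array([3.0, -0.5, 2.0, 7.0])
y_pred_demo = np.array([2.5, 0.0, 2.0, 8.0])

# Residuals: how far each prediction is from the true value
residuals = y_true_demo - y_pred_demo

# Mean squared error: average of the squared residuals
mse_manual = np.mean(residuals ** 2)
print(mse_manual)
```

Because the errors are squared, a few large misses dominate the score, and the result is expressed in squared units of the target value.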

We note that the mean squared error is a little lower on the test set than on the training set. This is good news, as we care mostly about the test error. However, from this number alone it is hard to judge how good the model really is. It may be more informative to plot the data:

In [13]: plt.figure(figsize=(10, 6))
... plt.plot(y_test, linewidth=3, label='ground truth')
... plt.plot(y_pred, linewidth=3, label='predicted')
... plt.legend(loc='best')
... plt.xlabel('test data points')
... plt.ylabel('target value')
Out[13]: <matplotlib.text.Text at 0x7ff46783c7b8>

This produces the following figure:

Linear regression model

This makes more sense! Here we see the ground truth housing prices for all test samples in blue and our predicted housing prices in red. Pretty close, if you ask me. It is interesting to note, though, that the model tends to be off the most for really high or really low housing prices, such as the peak values of data points 12, 18, and 42. We can quantify how much of the variance in the data the model was able to explain by calculating R squared. To put that number in context, we first plot the predictions against the ground truth:

In [14]: plt.plot(y_test, y_pred, 'o')
... plt.plot([-10, 60], [-10, 60], 'k--')
... plt.axis([-10, 60, -10, 60])
... plt.xlabel('ground truth')
... plt.ylabel('predicted')

This will plot the ground truth prices, y_test, on the x axis, and our predictions, y_pred, on the y axis. We also plot the diagonal as a black dashed line ('k--') for reference; its meaning will become clear shortly. In addition, we want to display the R² score and the mean squared error in a text box:

... scorestr = r'R$^2$ = %.3f' % linreg.score(X_test, y_test)
... errstr = 'MSE = %.3f' % metrics.mean_squared_error(y_test, y_pred)
... plt.text(-5, 50, scorestr, fontsize=12)
... plt.text(-5, 45, errstr, fontsize=12)
Out[14]: <matplotlib.text.Text at 0x7ff4642d0400>

This produces the following figure, which is a common and informative way of visualizing a model fit:

Model predictions versus ground truth

If our model were perfect, all data points would lie on the dashed diagonal, since y_pred would always be equal to y_test. Deviations from the diagonal indicate that the model made some errors, or that there is some variance in the data that the model was not able to explain. Indeed, R² indicates that we were able to explain 76 percent of the variance in the data, with a mean squared error of 15.011. These are hard numbers we can use to compare the linear regression model to more complicated ones.
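
To make the R² score less of a black box, here is a rough sketch of how it can be computed by hand as one minus the ratio of the residual sum of squares to the total sum of squares. For regressors such as linreg, the score method returns this same quantity. The helper name r2_by_hand and the toy arrays are assumptions made for illustration only; they are not part of the housing example:

```python
import numpy as np

def r2_by_hand(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot (illustrative helper)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Variation the model failed to explain (sum of squared residuals)
    ss_res = np.sum((y_true - y_pred) ** 2)
    # Total variation of the targets around their mean
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

# Hypothetical toy values, NOT taken from the housing dataset
y_true_demo = np.array([3.0, -0.5, 2.0, 7.0])
y_pred_demo = np.array([2.5, 0.0, 2.0, 8.0])

r2_demo = r2_by_hand(y_true_demo, y_pred_demo)
r2_perfect = r2_by_hand(y_true_demo, y_true_demo)  # predictions on the diagonal
print(r2_demo, r2_perfect)
```

A model whose predictions all lie on the diagonal has a residual sum of squares of zero and therefore scores exactly 1, whereas a model that always predicts the mean of the targets scores 0.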
