- Python: Data Analytics and Visualization
- Phuong Vo.T.H, Martin Czygan, Ashish Kumar, Kirthi Raman
Measuring prediction performance
We have already seen that the machine learning process consists of the following steps:
- Model selection: We first select a suitable model for our data. Do we have labels? How many samples are available? Is the data separable? How many dimensions do we have? As this step is nontrivial, the choice will depend on the actual problem. As of Fall 2015, the scikit-learn documentation contains a much-appreciated flowchart called "choosing the right estimator". It is short, but very informative and worth taking a closer look at.
- Training: We have to bring the model and data together, and this usually happens in the fit method of the models in scikit-learn.
- Application: Once we have trained our model, we are able to make predictions about the unseen data.
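The three steps above can be sketched on the Iris dataset; the choice of SVC here is purely illustrative, not a recommendation:

```python
# A minimal sketch of selection, training, and application,
# using SVC as an (assumed) model choice for the Iris data.
from sklearn.datasets import load_iris
from sklearn.svm import SVC

iris = load_iris()

# Model selection: a support vector classifier, chosen for illustration.
clf = SVC()

# Training: fit the model to the data.
clf.fit(iris.data, iris.target)

# Application: predict labels for some samples (here, already seen ones).
pred = clf.predict(iris.data[:3])
print(pred)
```

Note that predicting on samples the model was trained on says nothing about generalization, which is exactly the gap the next step addresses.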
So far, we omitted an important step that takes place between the training and application: the model testing and validation. In this step, we want to evaluate how well our model has learned.
One goal of learning, and machine learning in particular, is generalization. The question of whether a limited set of observations is enough to make statements about any possible observation is a deeper theoretical question, which is answered in dedicated resources on machine learning.
Whether or not a model generalizes well can also be tested. However, it is important that the training and the test input are separate. The situation where a model performs well on a training input but fails on an unseen test input is called overfitting, and this is not uncommon.
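As an illustration of our own (not from the text), a decision tree can memorize its training input perfectly, which makes the gap between training and test performance visible:

```python
# Sketch of overfitting: a fully grown decision tree achieves perfect
# accuracy on its training input, so only a held-out test set reveals
# how well it generalizes. The 40 percent test split is an assumption.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.4, random_state=0)

tree = DecisionTreeClassifier(random_state=0)
tree.fit(X_train, y_train)

# Perfect accuracy on the training input...
print(tree.score(X_train, y_train))
# ...is no guarantee for unseen test input.
print(tree.score(X_test, y_test))
```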
The basic approach is to split the available data into a training and test set, and scikit-learn helps to create this split with the train_test_split function.
We go back to the Iris dataset and perform SVC again. This time we will evaluate the performance of the algorithm on a held-out test set. We set aside 40 percent of the data for testing:
>>> from sklearn.cross_validation import train_test_split
>>> X_train, X_test, y_train, y_test = train_test_split(
...     iris.data, iris.target, test_size=0.4, random_state=0)
>>> clf = SVC()
>>> clf.fit(X_train, y_train)
The score function returns the mean accuracy of the given data and labels. We pass the test set for evaluation:
>>> clf.score(X_test, y_test)
0.94999999999999996
The model seems to perform well, with about 95 percent accuracy on unseen data. We can now start to tweak the model parameters (also called hyperparameters) to increase prediction performance, but repeatedly evaluating against the test set brings back the problem of overfitting. One solution is to split the input data into three sets: one each for training, validation, and testing. Hyperparameter tuning then iterates between the training and validation sets, while the final evaluation is done once on the test set. Splitting the dataset three ways, however, also reduces the number of samples we can learn from.
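A three-way split can be obtained with two calls to train_test_split. This is a sketch under two assumptions: the 60/20/20 ratios are our own choice, and the import path uses sklearn.model_selection, where the function lives in scikit-learn 0.18 and later (the text's sklearn.cross_validation is the older location):

```python
# Hedged sketch of a train / validation / test split via two calls
# to train_test_split. Ratios (60/20/20) are illustrative assumptions.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()

# First set aside 20 percent of the 150 samples as the final test set.
X_rest, X_test, y_rest, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=0)

# Then split the remainder: 0.25 of the remaining 80 percent
# gives 20 percent of the original data as the validation set.
X_train, X_val, y_train, y_val = train_test_split(
    X_rest, y_rest, test_size=0.25, random_state=0)

print(len(X_train), len(X_val), len(X_test))  # 90 30 30
```

Hyperparameters would be tuned by comparing scores on X_val, and X_test would be touched only once, for the final number.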
Cross-validation (CV) is a technique that does not need a separate validation set, but still counteracts overfitting. The dataset is split into k parts (called folds). For each fold, the model is trained on the other k-1 folds and tested on the remaining fold. The accuracy is taken as the average over the folds.
We will show a five-fold cross-validation on the Iris dataset, using SVC again:
>>> from sklearn.cross_validation import cross_val_score
>>> clf = SVC()
>>> scores = cross_val_score(clf, iris.data, iris.target, cv=5)
>>> scores
array([ 0.96666667,  1.        ,  0.96666667,  0.96666667,  1.        ])
>>> scores.mean()
0.98000000000000009
There are various strategies, implemented by different classes, to split the dataset for cross-validation: KFold, StratifiedKFold, LeaveOneOut, LeavePOut, LeaveOneLabelOut, LeavePLabelOut, ShuffleSplit, StratifiedShuffleSplit, and PredefinedSplit.
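These splitter classes can also be used directly, which makes the fold-by-fold mechanics explicit. A sketch, assuming a modern scikit-learn where the classes live in sklearn.model_selection (in the 2015-era releases the text describes, they sat in sklearn.cross_validation):

```python
# Manual five-fold cross-validation with StratifiedKFold, mirroring
# what cross_val_score does for a classifier: train on k-1 folds,
# score on the held-out fold, average the results.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold
from sklearn.svm import SVC

iris = load_iris()
skf = StratifiedKFold(n_splits=5)

scores = []
for train_idx, test_idx in skf.split(iris.data, iris.target):
    clf = SVC()
    clf.fit(iris.data[train_idx], iris.target[train_idx])
    scores.append(clf.score(iris.data[test_idx], iris.target[test_idx]))

print(np.mean(scores))
```

StratifiedKFold keeps the class proportions roughly equal in each fold, which matters for classification tasks like Iris.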
Model verification is an important step, necessary for the development of robust machine learning solutions.