- Python計算機視覺與深度學習實戰
- 郭卡 戴亮編著
- 1492字
- 2021-08-27 20:19:04
1.2 數據集
人工智能的核心在于數據支持,近幾年人工智能技術的快速發展與大數據技術的發展密切相關,大數據技術可以通過數據采集、分析及挖掘等方式,從海量復雜數據中快速提取出有價值的信息,為機器學習算法提供牢固的基礎。
在機器學習任務中,數據集有三大功能:訓練、驗證和測試。
- 訓練最好理解,是擬合模型的過程,模型會通過分析數據、調節內部參數從而得到最優的模型效果。
- 驗證即驗證模型效果,效果可以指導我們調整模型中的超參數(在開始訓練之前設置參數,而不是通過訓練得到參數),通常會使用少量未參與訓練的數據對模型進行驗證,在訓練的間隙中進行。
- 測試的作用是檢查模型是否具有泛化能力(泛化能力是指模型對訓練集之外的數據集是否也有很好的擬合能力)。通常會在模型訓練完畢之后,選用較多訓練集以外的數據進行測試。
所以在機器學習(尤其是深度學習)任務開始前,需要收集大量高質量的數據,對于個人開發者來說,數據只能來源于開源的數據集和自己編寫爬蟲程序采集到的數據集,收集數據是一個費時費力的過程。
為了方便初學者學習以及進行小規模的算法測試,sklearn提供了不少小型的標準數據集和一些規模略大的真實數據集。除這些數據集之外,sklearn還能夠按照一定規則自己生成數據集。3種類型的數據集分別通過load***
、fetch***
和make***
這3種函數形式獲取,下面將對這幾個接口做簡單介紹。
1.2.1 自帶的小型數據集
sklearn中最常用的數據集有3個:load_iris
、load_boston
和load_digits
。
直接從sklearn.datasets
中導入load_iris
,得到的數據是字典形式,可以通過字典中的鍵值選擇數據的各項屬性。
load_iris
是加載鳶尾花數據集的函數,該數據集包含了150條鳶尾花數據,其中包含的鳶尾花數據(在機器學習中,這種可以直接用于建模的數據叫作特征)有4種:
- 鳶尾花的花瓣長度(cm);
- 鳶尾花的花瓣寬度(cm);
- 鳶尾花的花萼長度(cm);
- 鳶尾花的花萼寬度(cm)。
標簽是鳶尾花的種類,3個種類分別用0
、1
和2
表示。下面是load_iris
的使用方法:
>>> d = load_iris()
>>> d.keys()
dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])
>>> # 鳶尾花的類別名
>>> d['target_names']
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
>>> # 特征名稱
>>> d['feature_names']
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
>>> d['data'].shape
(150, 4)
>>> set(list(d['target']))
{0, 1, 2}
在上述代碼中,通過load_iris
函數取出了鳶尾花數據并將其賦值給d
,通過keys
方法查看數據集中各個項目的名稱,如鳶尾花的類別名(target_names
)、特征名(feature_names
)、數據(data
)與標簽(target
)等。
load_boston
是關于波士頓房屋特征與房價之間關系的數據集,包含13個房屋特征,是一個進行入門回歸訓練的好例子。下面是load_boston
的使用方法:
>>> data = load_boston()
>>> # 房屋特征名稱
>>> data['feature_names']
array(['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD',
'TAX', 'PTRATIO', 'B', 'LSTAT'], dtype='<U7')
>>> data['data'].shape
(506, 13)
從上述代碼中可以看到,load_boston
中共有506個樣本,每條數據中包含了房屋和房屋周邊的13個重要信息,如城市犯罪率、環保指標、周邊老房子的比例、是否臨河等。
load_digits
是一個比MNIST更小的手寫數字圖片數據集,里面的圖片尺寸是8像素×8像素(后面將省略單位),通過如下代碼可以查看手寫數字圖片:
>>> g = sklearn.datasets.load_digits()
>>> plt.imshow(g['data'][0].reshape(8,8),cmap='gray')
<matplotlib.image.AxesImage object at 0x7f07e42ddeb8>
>>> plt.show()
輸出圖片如圖1-5所示,因為是8×8的圖片,所以看起來不是很清晰。

圖 1-5 手寫數字
1.2.2 在線下載的數據集
Fetch
系列函數用于獲取較大規模的數據集,這些數據集會自動從網上下載,得到的數據格式與load***
一樣,是字典形式。我們可以自定義下載目錄,同時可以選擇單獨下載訓練集或者測試集,常用的數據集如下。
- 人臉數據集:
fetch_olivetti_faces
和fetch_lfw_people
。 - 文本分類數據集:
fetch_20newsgroups
。 - 房價回歸數據集:
fetch_california_housing
。
1.2.3 計算機生成的數據集
用sklearn生成的數據集可以用來測試一些基礎的模型功能,比如多分類數據集、聚類數據集以及高斯分布數據集等。還有一些特殊形狀的數據集,比如make_circles
和make_moons
等,示例如下:
>>> circle = make_circles()[0]
>>> # 創建子圖
>>> plt.subplot(121)
<matplotlib.axes._subplots.AxesSubplot object at 0x000000001719BE80>
>>> # 繪制散點圖
>>> plt.scatter(circle[:,0],circle[:,1])
<matplotlib.collections.PathCollection object at 0x000000002081D828>
>>> moon = make_moons()[0]
>>> plt.subplot(122)
<matplotlib.axes._subplots.AxesSubplot object at 0x000000002081D048>
>>> plt.scatter(moon[:,0],moon[:,1])
<matplotlib.collections.PathCollection object at 0x0000000017171D30>
>>> plt.show()
上述代碼的作用是通過make_circles
和make_moons
函數生成兩組坐標點數據,并使用plt.scatter
函數將生成的坐標點繪制成散點圖。生成的散點圖如圖1-6所示,其他數據集詳情請參考sklearn官網。

圖 1-6 生成的散點圖