机器学习基础 互动版

在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器

学习与预测

对于数码数据集而言,我们的目标是根据图像矩阵预测其代表的数字,我们给出10个数字的手写图像矩阵,并通过监督学习完成图像识别。

在scikit-learn中,分类主要使用Python对象中的2个方法,fit(X, y)和predict(T)。前者是用于训练,后者用于分类预测。

接下来,将以sklearn.svm.SVC作为分类器为例进行支持向量分类,分类器在创建时需要传入与此模型有关的的参数,这样在训练预测时就可以看作一个黑盒。我们定义一个名为clf的SVC对象,然后进行训练。选取除最后4组外的数据作为训练集,最后4组数据作为测试集。

from sklearn import svm
from sklearn import datasets
digits = datasets.load_digits()
clf = svm.SVC(gamma=0.001, C=100.)                #使用支持向量机进行分类
clf.fit(digits.data[:-4], digits.target[:-4])    #将除最后4组的数据输入进行训练
clf.predict(digits.data[-4:])                    #预测最后4组的数据

如果没错,此时会输出只有4个数字的数组array([0,8,9,8]),说明最后3组图像表示的数字是0,8,9,8。我们将最后4组图像矩阵所表示的图像绘制出来。