<!DOCTYPE html>
<html>
<head><meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>sklearn_CrossValidation</title><script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.1.10/require.min.js"></script>
<link rel="stylesheet" type="text/css" href="/static/css/md_notebook.css" />
<!-- Load mathjax -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML-full,Safe"> </script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
init_mathjax = function() {
if (window.MathJax) {
// MathJax loaded
MathJax.Hub.Config({
TeX: {
equationNumbers: {
autoNumber: "AMS",
useLabelIds: true
}
},
tex2jax: {
inlineMath: [ ['$','$'], ["\\(","\\)"] ],
displayMath: [ ['$$','$$'], ["\\[","\\]"] ],
processEscapes: true,
processEnvironments: true
},
displayAlign: 'center',
CommonHTML: {
linebreaks: {
automatic: true
}
}
});
MathJax.Hub.Queue(["Typeset", MathJax.Hub]);
}
}
init_mathjax();
</script>
<!-- End of mathjax configuration --></head>
<body class="jp-Notebook" data-jp-theme-light="true" data-jp-theme-name="JupyterLab Light">
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<p>Chapter 11: Model Evaluation and Optimization</p>
</div>
</div>
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<h1 id="%E4%BD%BF%E7%94%A8%E4%BA%A4%E5%8F%89%E9%AA%8C%E8%AF%81%E8%BF%9B%E8%A1%8C%E6%A8%A1%E5%9E%8B%E8%AF%84%E4%BC%B0">Evaluating Models with Cross-Validation<a class="anchor-link" href="#%E4%BD%BF%E7%94%A8%E4%BA%A4%E5%8F%89%E9%AA%8C%E8%AF%81%E8%BF%9B%E8%A1%8C%E6%A8%A1%E5%9E%8B%E8%AF%84%E4%BC%B0">¶</a></h1>
</div>
</div>
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<p>Besides the train_test_split function used earlier, there is a more automated way to split the data and test a model: cross-validation.</p>
</div>
</div>
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
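<p>K-fold cross-validation:</p>
<ul>
<li>Split the data into k parts (folds); train on k-1 of them and test on the remaining one.</li>
<li>Repeat k times, so that every fold serves as the test set exactly once.</li>
<li>Report the mean of the k test scores.</li>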
</ul>
</div>
</div>
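<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<p>The three steps above can be sketched as an explicit loop. This is only an illustration of the procedure, not the code that cross_val_score runs internally. The use of KFold with shuffle=True and random_state=0 is an assumption made here: the wine samples are stored sorted by class, so unshuffled folds would each contain mostly a single class.</p>

```python
import numpy as np
from sklearn.datasets import load_wine
from sklearn.model_selection import KFold
from sklearn.svm import SVC

X, y = load_wine(return_X_y=True)

# Assumption: shuffle before splitting, because the samples are sorted by class.
kfold = KFold(n_splits=5, shuffle=True, random_state=0)

fold_scores = []
tested_indices = []
for train_idx, test_idx in kfold.split(X):
    # Train on k-1 folds, score on the fold that was held out.
    model = SVC(kernel="linear")
    model.fit(X[train_idx], y[train_idx])
    fold_scores.append(model.score(X[test_idx], y[test_idx]))
    tested_indices.extend(test_idx.tolist())

# Report the mean of the k test scores.
mean_score = float(np.mean(fold_scores))
print(fold_scores, "\nMean score: {:0.3f}".format(mean_score))
```

</div>
</div>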
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<h2 id="%E9%9A%8F%E6%9C%BA%E6%8B%86%E5%88%86%E4%BA%A4%E5%8F%89%E9%AA%8C%E8%AF%81%E6%B3%95-shuffle-split-cross-validation">K-fold cross-validation with cross_val_score<a class="anchor-link" href="#%E9%9A%8F%E6%9C%BA%E6%8B%86%E5%88%86%E4%BA%A4%E5%8F%89%E9%AA%8C%E8%AF%81%E6%B3%95-shuffle-split-cross-validation">¶</a></h2>
</div>
</div><div class="jp-Cell jp-CodeCell jp-Notebook-cell ">
<div class="jp-Cell-inputWrapper">
<div class="jp-InputArea jp-Cell-inputArea">
<div class="jp-InputPrompt jp-InputArea-prompt">In [1]:</div>
<div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
<div class="CodeMirror cm-s-jupyter">
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># Import the wine dataset</span>
<span class="kn">from</span> <span class="nn">sklearn.datasets</span> <span class="kn">import</span> <span class="n">load_wine</span>
<span class="c1"># Import the cross-validation helper</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">cross_val_score</span>
<span class="c1"># Import the support vector machine classifier</span>
<span class="kn">from</span> <span class="nn">sklearn.svm</span> <span class="kn">import</span> <span class="n">SVC</span>
<span class="c1"># Load the data</span>
<span class="n">wine</span> <span class="o">=</span> <span class="n">load_wine</span><span class="p">()</span>
<span class="c1"># Use a linear kernel for the SVC</span>
<span class="n">svc</span> <span class="o">=</span> <span class="n">SVC</span><span class="p">(</span><span class="n">kernel</span><span class="o">=</span><span class="s1">'linear'</span><span class="p">)</span>
<span class="c1"># Score the SVC with cross-validation</span>
<span class="n">scores</span><span class="o">=</span><span class="n">cross_val_score</span><span class="p">(</span><span class="n">svc</span><span class="p">,</span> <span class="n">wine</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">wine</span><span class="o">.</span><span class="n">target</span><span class="p">,</span> <span class="n">cv</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span> <span class="c1"># 5 folds (the default since scikit-learn 0.22)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="s2">"</span><span class="se">\n</span><span class="s2">Mean score:</span><span class="si">{:0.3f}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">scores</span><span class="o">.</span><span class="n">mean</span><span class="p">())</span> <span class="p">)</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="jp-Cell-outputWrapper">
<div class="jp-OutputArea jp-Cell-outputArea">
<div class="jp-OutputArea-child">
<div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
<div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain">
<pre>[0.88888889 0.94444444 0.97222222 1. 1. ]
Mean score:0.961
</pre>
</div>
</div>
</div>
</div>
</div><div class="jp-Cell jp-CodeCell jp-Notebook-cell ">
<div class="jp-Cell-inputWrapper">
<div class="jp-InputArea jp-Cell-inputArea">
<div class="jp-InputPrompt jp-InputArea-prompt">In [2]:</div>
<div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
<div class="CodeMirror cm-s-jupyter">
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># For classifiers, scikit-learn uses stratified k-fold cross-validation by default.</span>
<span class="nb">print</span><span class="p">(</span><span class="n">wine</span><span class="o">.</span><span class="n">target</span><span class="p">)</span> <span class="c1"># The wines fall into 3 classes.</span>
<span class="c1"># "Stratified" means that if the target is 80% male and 20% female, every fold keeps the same ratio.</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="jp-Cell-outputWrapper">
<div class="jp-OutputArea jp-Cell-outputArea">
<div class="jp-OutputArea-child">
<div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
<div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain">
<pre>[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
</pre>
</div>
</div>
</div>
</div>
</div>
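<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<p>To see what stratification means on this dataset, the sketch below compares the class proportions of each test fold with those of the whole dataset. It constructs StratifiedKFold explicitly, which the notebook does not do; this is assumed to match what cross_val_score selects for a classifier when cv is an integer.</p>

```python
import numpy as np
from sklearn.datasets import load_wine
from sklearn.model_selection import StratifiedKFold

X, y = load_wine(return_X_y=True)

# Class proportions over the whole dataset
overall = np.bincount(y) / len(y)
print("overall:", np.round(overall, 3))

skf = StratifiedKFold(n_splits=5)
fold_ratios = []
for fold, (_, test_idx) in enumerate(skf.split(X, y)):
    # Class proportions inside this test fold
    ratios = np.bincount(y[test_idx], minlength=3) / len(test_idx)
    fold_ratios.append(ratios)
    print("fold", fold, np.round(ratios, 3))
```

<p>Each fold's proportions should stay close to the overall ones, even though the samples are stored sorted by class.</p>
</div>
</div>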
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<h2 id="%E9%9A%8F%E6%9C%BA%E6%8B%86%E5%88%86%E6%B3%95(shuffle-split-corss-validation)">Shuffle-split cross-validation<a class="anchor-link" href="#%E9%9A%8F%E6%9C%BA%E6%8B%86%E5%88%86%E6%B3%95(shuffle-split-corss-validation)">¶</a></h2>
</div>
</div><div class="jp-Cell jp-CodeCell jp-Notebook-cell ">
<div class="jp-Cell-inputWrapper">
<div class="jp-InputArea jp-Cell-inputArea">
<div class="jp-InputPrompt jp-InputArea-prompt">In [3]:</div>
<div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
<div class="CodeMirror cm-s-jupyter">
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># How it works: randomly draw part of the data as the training set, then randomly draw part of the</span>
<span class="c1"># remainder as the test set; score the model, and repeat until the requested number of iterations is reached.</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">ShuffleSplit</span>
<span class="c1"># 10 random splits: 70% training, 20% test (the remaining 10% is left unused in each split)</span>
<span class="n">shuffle_split</span> <span class="o">=</span> <span class="n">ShuffleSplit</span><span class="p">(</span><span class="n">train_size</span><span class="o">=</span><span class="mf">0.7</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">n_splits</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="c1"># Cross-validate</span>
<span class="n">scores</span><span class="o">=</span><span class="n">cross_val_score</span><span class="p">(</span><span class="n">svc</span><span class="p">,</span> <span class="n">wine</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">wine</span><span class="o">.</span><span class="n">target</span><span class="p">,</span> <span class="n">cv</span><span class="o">=</span><span class="n">shuffle_split</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="s2">"</span><span class="se">\n</span><span class="s2">Mean score:</span><span class="si">{:0.3f}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">scores</span><span class="o">.</span><span class="n">mean</span><span class="p">())</span> <span class="p">)</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="jp-Cell-outputWrapper">
<div class="jp-OutputArea jp-Cell-outputArea">
<div class="jp-OutputArea-child">
<div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
<div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain">
<pre>[0.91666667 0.97222222 0.94444444 0.97222222 0.97222222 0.97222222
0.91666667 0.91666667 0.97222222 1. ]
Mean score:0.956
</pre>
</div>
</div>
</div>
</div>
</div>
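<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<p>The index sets that ShuffleSplit produces can be inspected directly. The sketch below uses 100 toy samples, 3 splits and a fixed random_state; all three are assumptions chosen to keep the output short, while train_size and test_size are the values used above. It prints how many samples each split trains on, tests on, and leaves out.</p>

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit

# Toy data: 100 samples with 2 features each (illustrative only)
X = np.arange(200).reshape(100, 2)

ss = ShuffleSplit(n_splits=3, train_size=0.7, test_size=0.2, random_state=0)
splits = list(ss.split(X))

for train_idx, test_idx in splits:
    # Samples that are in neither set are simply not used in this split.
    unused = len(X) - len(train_idx) - len(test_idx)
    print("train:", len(train_idx), "test:", len(test_idx), "unused:", unused)
```

<p>Unlike k-fold, the test sets of different splits are drawn independently, so a sample may be tested several times or not at all.</p>
</div>
</div>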
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<h2 id="%E6%8C%A8%E4%B8%AA%E5%84%BF%E8%AF%95%E8%AF%95%E6%B3%95(-leave-one-out-)">Leave-one-out cross-validation<a class="anchor-link" href="#%E6%8C%A8%E4%B8%AA%E5%84%BF%E8%AF%95%E8%AF%95%E6%B3%95(-leave-one-out-)">¶</a></h2>
</div>
</div><div class="jp-Cell jp-CodeCell jp-Notebook-cell ">
<div class="jp-Cell-inputWrapper">
<div class="jp-InputArea jp-Cell-inputArea">
<div class="jp-InputPrompt jp-InputArea-prompt">In [4]:</div>
<div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
<div class="CodeMirror cm-s-jupyter">
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># How it works: each iteration holds out a single sample as the test set and trains on all the others.</span>
<span class="c1"># The drawback is the running time (one model fit per sample), especially on large datasets.</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">LeaveOneOut</span>
<span class="n">loo</span> <span class="o">=</span> <span class="n">LeaveOneOut</span><span class="p">()</span>
<span class="c1"># Cross-validate</span>
<span class="n">scores</span><span class="o">=</span><span class="n">cross_val_score</span><span class="p">(</span><span class="n">svc</span><span class="p">,</span> <span class="n">wine</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">wine</span><span class="o">.</span><span class="n">target</span><span class="p">,</span> <span class="n">cv</span><span class="o">=</span><span class="n">loo</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="s2">"</span><span class="se">\n</span><span class="s2">Mean score:</span><span class="si">{:0.3f}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">scores</span><span class="o">.</span><span class="n">mean</span><span class="p">())</span> <span class="p">)</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="jp-Cell-outputWrapper">
<div class="jp-OutputArea jp-Cell-outputArea">
<div class="jp-OutputArea-child">
<div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
<div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain">
<pre>[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.
1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Mean score:0.955
</pre>
</div>
</div>
</div>
</div>
</div>
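<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<p>The cost of leave-one-out follows from the number of splits, which equals the number of samples. A minimal sketch on six toy samples (an assumption made to keep the output readable) shows that each test set holds exactly one index. Because every score is then 0 or 1, the mean score is the fraction of samples classified correctly.</p>

```python
import numpy as np
from sklearn.model_selection import LeaveOneOut

# Toy data: 6 samples with 2 features each (illustrative only)
X = np.arange(12).reshape(6, 2)

loo = LeaveOneOut()
n_fits = loo.get_n_splits(X)  # one model fit per sample
test_sets = [test_idx.tolist() for _, test_idx in loo.split(X)]

print("number of fits:", n_fits)
print("test sets:", test_sets)
```

</div>
</div>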
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<h1 id="%E4%BD%BF%E7%94%A8%E7%BD%91%E7%BB%9C%E6%90%9C%E7%B4%A2%E4%BC%98%E5%8C%96%E6%A8%A1%E5%9E%8B%E5%8F%82%E6%95%B0">Tuning Model Parameters with Grid Search<a class="anchor-link" href="#%E4%BD%BF%E7%94%A8%E7%BD%91%E7%BB%9C%E6%90%9C%E7%B4%A2%E4%BC%98%E5%8C%96%E6%A8%A1%E5%9E%8B%E5%8F%82%E6%95%B0">¶</a></h1>
</div>
</div>
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
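<p>Instead of trying parameter values one at a time and comparing the scores by hand, grid search evaluates every combination of candidate values in a single pass and keeps the best-scoring setting.</p>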
</div>
</div>
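<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<p>The "grid" is the Cartesian product of the candidate values. As a small sketch, scikit-learn's ParameterGrid (not used elsewhere in this notebook) can enumerate it; the alpha and max_iter values are the ones tried in the cells that follow.</p>

```python
from sklearn.model_selection import ParameterGrid

param_grid = {
    "alpha": [0.01, 0.1, 1, 10],
    "max_iter": [100, 1000, 5000, 10000],
}

# Every combination of one alpha with one max_iter
candidates = list(ParameterGrid(param_grid))

print("number of combinations:", len(candidates))
for params in candidates[:3]:
    print(params)
```

<p>Four values for each of two parameters give 4 x 4 = 16 models to fit and score.</p>
</div>
</div>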
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<h2 id="%E7%AE%80%E5%8D%95%E7%BD%91%E7%BB%9C%E6%90%9C%E7%B4%A2">Simple grid search<a class="anchor-link" href="#%E7%AE%80%E5%8D%95%E7%BD%91%E7%BB%9C%E6%90%9C%E7%B4%A2">¶</a></h2>
</div>
</div><div class="jp-Cell jp-CodeCell jp-Notebook-cell ">
<div class="jp-Cell-inputWrapper">
<div class="jp-InputArea jp-Cell-inputArea">
<div class="jp-InputPrompt jp-InputArea-prompt">In [1]:</div>
<div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
<div class="CodeMirror cm-s-jupyter">
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># Goal: try every combination of two parameters: max_iter = 100, 1000, 5000, 10000 and alpha = 10, 1, 0.1, 0.01</span>
<span class="kn">from</span> <span class="nn">sklearn.datasets</span> <span class="kn">import</span> <span class="n">load_wine</span>
<span class="n">wine</span> <span class="o">=</span> <span class="n">load_wine</span><span class="p">()</span>
<span class="c1"># Import the Lasso regression model</span>
<span class="kn">from</span> <span class="nn">sklearn.linear_model</span> <span class="kn">import</span> <span class="n">Lasso</span>
<span class="c1"># Import the dataset splitting helper</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span>
<span class="c1"># Split the data into a training set and a test set</span>
<span class="n">X_train</span><span class="p">,</span> <span class="n">X_test</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span><span class="n">wine</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">wine</span><span class="o">.</span><span class="n">target</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">38</span><span class="p">)</span>
<span class="c1"># Start with a best score of 0</span>
<span class="n">best_score</span><span class="o">=</span><span class="mi">0</span>
<span class="c1"># Loop over the 4 values of alpha</span>
<span class="k">for</span> <span class="n">alpha</span> <span class="ow">in</span> <span class="p">[</span><span class="mf">0.01</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">10</span><span class="p">]:</span>
    <span class="c1"># Loop over the 4 values of the maximum number of iterations</span>
    <span class="k">for</span> <span class="n">max_iter</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">100</span><span class="p">,</span> <span class="mi">1000</span><span class="p">,</span> <span class="mi">5000</span><span class="p">,</span> <span class="mi">10000</span><span class="p">]:</span>
        <span class="n">lasso</span><span class="o">=</span><span class="n">Lasso</span><span class="p">(</span><span class="n">alpha</span><span class="o">=</span><span class="n">alpha</span><span class="p">,</span> <span class="n">max_iter</span><span class="o">=</span><span class="n">max_iter</span><span class="p">)</span>
        <span class="n">lasso</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
        <span class="n">score</span><span class="o">=</span><span class="n">lasso</span><span class="o">.</span><span class="n">score</span><span class="p">(</span><span class="n">X_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span>
        <span class="c1"># print(alpha, max_iter, score)</span>
        <span class="c1"># Keep track of the highest score and the parameters that produced it</span>
        <span class="k">if</span> <span class="n">score</span> <span class="o">></span> <span class="n">best_score</span><span class="p">:</span>
            <span class="n">best_score</span><span class="o">=</span><span class="n">score</span>
            <span class="n">best_parameters</span><span class="o">=</span><span class="p">{</span><span class="s2">"alpha"</span><span class="p">:</span><span class="n">alpha</span><span class="p">,</span> <span class="s2">"max_iter"</span><span class="p">:</span><span class="n">max_iter</span><span class="p">}</span>
<span class="c1"># Print the result</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"best_score:</span><span class="si">{:0.3f}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">best_score</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"best_parameters:</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">best_parameters</span><span class="p">))</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="jp-Cell-outputWrapper">
<div class="jp-OutputArea jp-Cell-outputArea">
<div class="jp-OutputArea-child">
<div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
<div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain">
<pre>best_score:0.889
best_parameters:{'alpha': 0.01, 'max_iter': 100}
</pre>
</div>
</div>
</div>
</div>
</div><div class="jp-Cell jp-CodeCell jp-Notebook-cell ">
<div class="jp-Cell-inputWrapper">
<div class="jp-InputArea jp-Cell-inputArea">
<div class="jp-InputPrompt jp-InputArea-prompt">In [2]:</div>
<div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
<div class="CodeMirror cm-s-jupyter">
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># The approach above is flawed: all 16 scores come from the same training/test split, so they do not show how the model would perform on new data.</span>
<span class="c1"># For example, a different split (random_state=0) changes the scores.</span>
<span class="n">X_train</span><span class="p">,</span> <span class="n">X_test</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span><span class="n">wine</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">wine</span><span class="o">.</span><span class="n">target</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">best_score</span><span class="o">=</span><span class="mi">0</span>
<span class="k">for</span> <span class="n">alpha</span> <span class="ow">in</span> <span class="p">[</span><span class="mf">0.01</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">10</span><span class="p">]:</span>
    <span class="k">for</span> <span class="n">max_iter</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">100</span><span class="p">,</span> <span class="mi">1000</span><span class="p">,</span> <span class="mi">5000</span><span class="p">,</span> <span class="mi">10000</span><span class="p">]:</span>
        <span class="n">lasso</span><span class="o">=</span><span class="n">Lasso</span><span class="p">(</span><span class="n">alpha</span><span class="o">=</span><span class="n">alpha</span><span class="p">,</span> <span class="n">max_iter</span><span class="o">=</span><span class="n">max_iter</span><span class="p">)</span>
        <span class="n">lasso</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
        <span class="n">score</span><span class="o">=</span><span class="n">lasso</span><span class="o">.</span><span class="n">score</span><span class="p">(</span><span class="n">X_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">score</span> <span class="o">></span> <span class="n">best_score</span><span class="p">:</span>
            <span class="n">best_score</span><span class="o">=</span><span class="n">score</span>
            <span class="n">best_parameters</span><span class="o">=</span><span class="p">{</span><span class="s2">"alpha"</span><span class="p">:</span><span class="n">alpha</span><span class="p">,</span> <span class="s2">"max_iter"</span><span class="p">:</span><span class="n">max_iter</span><span class="p">}</span>
<span class="c1"># Print the result</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"best_score:</span><span class="si">{:0.3f}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">best_score</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"best_parameters:</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">best_parameters</span><span class="p">))</span>
<span class="c1"># With the new split, the best alpha becomes 0.1 and the best score drops as well.</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="jp-Cell-outputWrapper">
<div class="jp-OutputArea jp-Cell-outputArea">
<div class="jp-OutputArea-child">
<div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
<div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain">
<pre>best_score:0.830
best_parameters:{'alpha': 0.1, 'max_iter': 100}
</pre>
</div>
</div>
</div>
</div>
</div>
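<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<p>One common way to avoid selecting parameters on the test set is to carve a separate validation set out of the training data. The sketch below illustrates that idea only; it is not the approach this notebook takes. The variable names and the random_state=1 of the second split are assumptions. Parameters are chosen on the validation set, and the test set is touched once at the end.</p>

```python
from sklearn.datasets import load_wine
from sklearn.linear_model import Lasso
from sklearn.model_selection import train_test_split

X, y = load_wine(return_X_y=True)

# First split off the test set, then split the rest into training and validation sets.
X_trainval, X_test, y_trainval, y_test = train_test_split(X, y, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X_trainval, y_trainval, random_state=1)

best_score, best_params = float("-inf"), None
for alpha in [0.01, 0.1, 1, 10]:
    for max_iter in [100, 1000, 5000, 10000]:
        model = Lasso(alpha=alpha, max_iter=max_iter).fit(X_train, y_train)
        score = model.score(X_val, y_val)  # R^2 on the validation set
        if score > best_score:
            best_score, best_params = score, {"alpha": alpha, "max_iter": max_iter}

# Refit on training + validation data, then evaluate once on the untouched test set.
final_model = Lasso(**best_params).fit(X_trainval, y_trainval)
test_score = final_model.score(X_test, y_test)

print("best validation score: {:0.3f}".format(best_score))
print("best parameters: {}".format(best_params))
print("test score: {:0.3f}".format(test_score))
```

<p>Starting best_score at negative infinity rather than 0 keeps best_params defined even if every R^2 score is negative.</p>
</div>
</div>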
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<h2 id="%E4%B8%8E%E4%BA%A4%E5%8F%89%E9%AA%8C%E8%AF%81%E7%BB%93%E5%90%88%E7%9A%84%E7%BD%91%E7%BB%9C%E6%90%9C%E7%B4%A2">Grid search combined with cross-validation<a class="anchor-link" href="#%E4%B8%8E%E4%BA%A4%E5%8F%89%E9%AA%8C%E8%AF%81%E7%BB%93%E5%90%88%E7%9A%84%E7%BD%91%E7%BB%9C%E6%90%9C%E7%B4%A2">¶</a></h2>
</div>
</div><div class="jp-Cell jp-CodeCell jp-Notebook-cell ">
<div class="jp-Cell-inputWrapper">
<div class="jp-InputArea jp-Cell-inputArea">
<div class="jp-InputPrompt jp-InputArea-prompt">In [3]:</div>
<div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
<div class="CodeMirror cm-s-jupyter">
<div class=" highlight hl-ipython3"><pre><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">cross_val_score</span>
<span class="n">best_score</span><span class="o">=</span><span class="mi">0</span>
<span class="k">for</span> <span class="n">alpha</span> <span class="ow">in</span> <span class="p">[</span><span class="mf">0.01</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">10</span><span class="p">]:</span>
    <span class="k">for</span> <span class="n">max_iter</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">100</span><span class="p">,</span> <span class="mi">1000</span><span class="p">,</span> <span class="mi">5000</span><span class="p">,</span> <span class="mi">10000</span><span class="p">]:</span>
        <span class="n">lasso</span><span class="o">=</span><span class="n">Lasso</span><span class="p">(</span><span class="n">alpha</span><span class="o">=</span><span class="n">alpha</span><span class="p">,</span> <span class="n">max_iter</span><span class="o">=</span><span class="n">max_iter</span><span class="p">)</span>
        <span class="n">scores</span> <span class="o">=</span> <span class="n">cross_val_score</span><span class="p">(</span><span class="n">lasso</span><span class="p">,</span> <span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">cv</span><span class="o">=</span><span class="mi">6</span><span class="p">)</span>
        <span class="n">score</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">scores</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">score</span> <span class="o">></span> <span class="n">best_score</span><span class="p">:</span>
            <span class="n">best_score</span><span class="o">=</span><span class="n">score</span>
            <span class="n">best_parameters</span><span class="o">=</span><span class="p">{</span><span class="s2">"alpha"</span><span class="p">:</span><span class="n">alpha</span><span class="p">,</span> <span class="s2">"max_iter"</span><span class="p">:</span><span class="n">max_iter</span><span class="p">}</span>
<span class="c1"># Print the result</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"best_score:</span><span class="si">{:0.3f}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">best_score</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"best_parameters:</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">best_parameters</span><span class="p">))</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="jp-Cell-outputWrapper">
<div class="jp-OutputArea jp-Cell-outputArea">
<div class="jp-OutputArea-child">
<div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
<div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain">
<pre>best_score:0.865
best_parameters:{'alpha': 0.01, 'max_iter': 100}
</pre>
</div>
</div>
</div>
</div>
</div><div class="jp-Cell jp-CodeCell jp-Notebook-cell ">
<div class="jp-Cell-inputWrapper">
<div class="jp-InputArea jp-Cell-inputArea">
<div class="jp-InputPrompt jp-InputArea-prompt">In [4]:</div>
<div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
<div class="CodeMirror cm-s-jupyter">
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># 再试试验证集上,使用最优参数</span>
<span class="n">lasso</span><span class="o">=</span><span class="n">Lasso</span><span class="p">(</span><span class="n">alpha</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">max_iter</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"test score:</span><span class="si">{:0.3f}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">lasso</span><span class="o">.</span><span class="n">score</span><span class="p">(</span><span class="n">X_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)))</span>
<span class="c1"># This score is a bit low. Lasso drops features, and we have few features to begin with, so the dropped columns were probably not all redundant.</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="jp-Cell-outputWrapper">
<div class="jp-OutputArea jp-Cell-outputArea">
<div class="jp-OutputArea-child">
<div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
<div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain">
<pre>test score:0.819
</pre>
</div>
</div>
</div>
</div>
</div><div class="jp-Cell jp-CodeCell jp-Notebook-cell ">
<div class="jp-Cell-inputWrapper">
<div class="jp-InputArea jp-Cell-inputArea">
<div class="jp-InputPrompt jp-InputArea-prompt">In [5]:</div>
<div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
<div class="CodeMirror cm-s-jupyter">
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># Use the built-in GridSearchCV to replace the loops above</span>
<span class="c1"># Import the grid search tool</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">GridSearchCV</span>
<span class="n">params</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'alpha'</span><span class="p">:[</span><span class="mf">0.01</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">10</span><span class="p">],</span> <span class="s1">'max_iter'</span><span class="p">:[</span><span class="mi">100</span><span class="p">,</span><span class="mi">1000</span><span class="p">,</span><span class="mi">5000</span><span class="p">,</span><span class="mi">10000</span><span class="p">]}</span>
<span class="c1"># Define the model and the parameter grid</span>
<span class="n">grid_search</span> <span class="o">=</span> <span class="n">GridSearchCV</span><span class="p">(</span><span class="n">lasso</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">cv</span><span class="o">=</span><span class="mi">6</span><span class="p">)</span>
<span class="n">grid_search</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
<span class="c1"># Score on the test set</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"test score:</span><span class="si">{:0.3f}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">grid_search</span><span class="o">.</span><span class="n">score</span><span class="p">(</span><span class="n">X_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)))</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"best params:"</span><span class="p">,</span> <span class="n">grid_search</span><span class="o">.</span><span class="n">best_params_</span><span class="p">)</span>
<span class="c1"># The best parameters match those found by the for loop above.</span>
<span class="c1"># Note: grid_search.best_score_ stores the best mean cross-validation score, not the score on the test set.</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2">best_score_:</span><span class="si">{:0.3f}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">grid_search</span><span class="o">.</span><span class="n">best_score_</span><span class="p">))</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="jp-Cell-outputWrapper">
<div class="jp-OutputArea jp-Cell-outputArea">
<div class="jp-OutputArea-child">
<div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
<div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain">
<pre>test score:0.819
best params: {'alpha': 0.01, 'max_iter': 100}
best_score_:0.865
</pre>
</div>
</div>
</div>
</div>
</div>
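<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<p>To make the difference between best_score_ and the test score more concrete, here is a minimal sketch. It assumes synthetic data from make_regression as a stand-in for the notebook's dataset, and the parameter grid is illustrative rather than the one used above.</p>

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso
from sklearn.model_selection import GridSearchCV, train_test_split

# Synthetic stand-in data (an assumption; not the dataset used in this notebook)
X, y = make_regression(n_samples=200, n_features=10, noise=10.0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Illustrative grid: 4 alphas x 2 max_iter values = 8 candidates
params = {"alpha": [0.01, 0.1, 1, 10], "max_iter": [1000, 10000]}
grid_search = GridSearchCV(Lasso(), params, cv=6)
grid_search.fit(X_train, y_train)

# Mean cross-validation score of every candidate, computed on the training data only
mean_scores = grid_search.cv_results_["mean_test_score"]
# best_score_ is the mean cross-validation score of the winning candidate
best_cv_score = grid_search.best_score_
# score() evaluates the refitted best estimator on data it has never seen
test_score = grid_search.score(X_test, y_test)

print("mean CV score per candidate:", np.round(mean_scores, 3))
print("best_score_: {:0.3f}".format(best_cv_score))
print("test score: {:0.3f}".format(test_score))
```

<p>Because GridSearchCV refits the best candidate on the full training set by default, grid_search.score(X_test, y_test) is the score of that refitted model, which is why it can differ from best_score_.</p>
</div>
</div>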
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<h1 id="%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E7%9A%84%E5%8F%AF%E4%BF%A1%E5%BA%A6%E8%AF%84%E4%BC%B0">Assessing the Confidence of Classification Models<a class="anchor-link" href="#%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E7%9A%84%E5%8F%AF%E4%BF%A1%E5%BA%A6%E8%AF%84%E4%BC%B0">¶</a></h1>
</div>
</div>
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<h2 id="%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E4%B8%AD%E7%9A%84%E9%A2%84%E6%B5%8B%E5%87%86%E7%A1%AE%E7%8E%87">Predicted Probabilities in Classification Models<a class="anchor-link" href="#%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E4%B8%AD%E7%9A%84%E9%A2%84%E6%B5%8B%E5%87%86%E7%A1%AE%E7%8E%87">¶</a></h2>
</div>
</div><div class="jp-Cell jp-CodeCell jp-Notebook-cell ">
<div class="jp-Cell-inputWrapper">
<div class="jp-InputArea jp-Cell-inputArea">
<div class="jp-InputPrompt jp-InputArea-prompt">In [1]:</div>
<div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
<div class="CodeMirror cm-s-jupyter">
<div class=" highlight hl-ipython3"><pre><span></span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
<span class="kn">from</span> <span class="nn">sklearn.datasets</span> <span class="kn">import</span> <span class="n">make_blobs</span>
<span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span> <span class="n">make_blobs</span><span class="p">(</span><span class="n">n_samples</span><span class="o">=</span><span class="mi">200</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">centers</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">cluster_std</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">X</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">X</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">c</span><span class="o">=</span><span class="n">y</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="n">plt</span><span class="o">.</span><span class="n">cm</span><span class="o">.</span><span class="n">cool</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s1">'k'</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
<span class="c1"># Treat the darker (magenta) points as class 1 and the lighter (cyan) points as class 0</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="jp-Cell-outputWrapper">
<div class="jp-OutputArea jp-Cell-outputArea">
<div class="jp-OutputArea-child">
<div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
<div class="jp-RenderedImage jp-OutputArea-output ">
<img src="
"
>
</div>
</div>
</div>
</div>
</div><div class="jp-Cell jp-CodeCell jp-Notebook-cell ">
<div class="jp-Cell-inputWrapper">
<div class="jp-InputArea jp-Cell-inputArea">
<div class="jp-InputPrompt jp-InputArea-prompt">In [2]:</div>
<div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
<div class="CodeMirror cm-s-jupyter">
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># Import the Gaussian naive Bayes model</span>
<span class="kn">from</span> <span class="nn">sklearn.naive_bayes</span> <span class="kn">import</span> <span class="n">GaussianNB</span>
<span class="c1"># Split the dataset</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span>
<span class="n">X_train</span><span class="p">,</span> <span class="n">X_test</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">68</span><span class="p">)</span>
<span class="c1"># Train the Gaussian naive Bayes model</span>
<span class="n">gnb</span><span class="o">=</span><span class="n">GaussianNB</span><span class="p">()</span>
<span class="n">gnb</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
<span class="n">predict_prob</span> <span class="o">=</span> <span class="n">gnb</span><span class="o">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">X_test</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">predict_prob</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="c1">#(50, 2)</span>
<span class="n">predict_prob</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
<span class="c1"># Each row holds the probabilities of the two classes; for the first sample, 98.8% vs 1.2%</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="jp-Cell-outputWrapper">
<div class="jp-OutputArea jp-Cell-outputArea">
<div class="jp-OutputArea-child">
<div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
<div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain">
<pre>(50, 2)
</pre>
</div>
</div>
<div class="jp-OutputArea-child">
<div class="jp-OutputPrompt jp-OutputArea-prompt">Out[2]:</div>
<div class="jp-RenderedText jp-OutputArea-output jp-OutputArea-executeResult" data-mime-type="text/plain">
<pre>array([[0.98849996, 0.01150004],
[0.0495985 , 0.9504015 ],
[0.01648034, 0.98351966]])</pre>
</div>
</div>
</div>
</div>
</div><div class="jp-Cell jp-CodeCell jp-Notebook-cell ">
<div class="jp-Cell-inputWrapper">
<div class="jp-InputArea jp-Cell-inputArea">
<div class="jp-InputPrompt jp-InputArea-prompt">In [3]:</div>
<div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
<div class="CodeMirror cm-s-jupyter">
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># Visualize the predicted probabilities over the feature space</span>
<span class="n">x_min</span><span class="p">,</span> <span class="n">x_max</span><span class="o">=</span> <span class="n">X</span><span class="p">[:,</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">min</span><span class="p">()</span><span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">X</span><span class="p">[:,</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="o">+</span><span class="mf">0.5</span>
<span class="n">y_min</span><span class="p">,</span> <span class="n">y_max</span><span class="o">=</span> <span class="n">X</span><span class="p">[:,</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">min</span><span class="p">()</span><span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">X</span><span class="p">[:,</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="o">+</span><span class="mf">0.5</span>
<span class="n">xx</span><span class="p">,</span> <span class="n">yy</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">meshgrid</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">x_min</span><span class="p">,</span> <span class="n">x_max</span><span class="p">,</span> <span class="mf">0.02</span><span class="p">),</span>
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">y_min</span><span class="p">,</span> <span class="n">y_max</span><span class="p">,</span> <span class="mf">0.02</span><span class="p">))</span>
<span class="n">Z</span><span class="o">=</span><span class="n">gnb</span><span class="o">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">c_</span><span class="p">[</span><span class="n">xx</span><span class="o">.</span><span class="n">ravel</span><span class="p">(),</span> <span class="n">yy</span><span class="o">.</span><span class="n">ravel</span><span class="p">()])[:,</span><span class="mi">1</span><span class="p">]</span>
<span class="n">Z</span><span class="o">=</span><span class="n">Z</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">xx</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="c1"># Draw filled contours</span>
<span class="n">plt</span><span class="o">.</span><span class="n">contourf</span><span class="p">(</span><span class="n">xx</span><span class="p">,</span> <span class="n">yy</span><span class="p">,</span> <span class="n">Z</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="n">plt</span><span class="o">.</span><span class="n">cm</span><span class="o">.</span><span class="n">summer</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">X_train</span><span class="p">[:,</span><span class="mi">0</span><span class="p">],</span> <span class="n">X_train</span><span class="p">[:,</span><span class="mi">1</span><span class="p">],</span> <span class="n">c</span><span class="o">=</span><span class="n">y_train</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="n">plt</span><span class="o">.</span><span class="n">cm</span><span class="o">.</span><span class="n">cool</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s1">'k'</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">X_test</span><span class="p">[:,</span><span class="mi">0</span><span class="p">],</span> <span class="n">X_test</span><span class="p">[:,</span><span class="mi">1</span><span class="p">],</span> <span class="n">c</span><span class="o">=</span><span class="n">y_test</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="n">plt</span><span class="o">.</span><span class="n">cm</span><span class="o">.</span><span class="n">cool</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s1">'k'</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.6</span><span class="p">)</span>
<span class="c1"># Set the axis limits</span>
<span class="n">plt</span><span class="o">.</span><span class="n">xlim</span><span class="p">(</span><span class="n">xx</span><span class="o">.</span><span class="n">min</span><span class="p">(),</span> <span class="n">xx</span><span class="o">.</span><span class="n">max</span><span class="p">())</span>
<span class="n">plt</span><span class="o">.</span><span class="n">ylim</span><span class="p">(</span><span class="n">yy</span><span class="o">.</span><span class="n">min</span><span class="p">(),</span> <span class="n">yy</span><span class="o">.</span><span class="n">max</span><span class="p">())</span>
<span class="c1"># Hide the tick marks on both axes</span>
<span class="n">plt</span><span class="o">.</span><span class="n">xticks</span><span class="p">(())</span>
<span class="n">plt</span><span class="o">.</span><span class="n">yticks</span><span class="p">(())</span>
<span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">"GaussianNB predict_proba"</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
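<span class="c1"># The background color encodes the predicted probability. The semi-transparent points are the test set.</span>
<span class="c1"># Not every classification algorithm provides a predict_proba method.</span>
<span class="c1"># Another way to check the confidence of a classification is the decision function, decision_function.</span>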
</pre></div>
</div>
</div>
</div>
</div>
<div class="jp-Cell-outputWrapper">
<div class="jp-OutputArea jp-Cell-outputArea">
<div class="jp-OutputArea-child">
<div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
<div class="jp-RenderedImage jp-OutputArea-output ">
<img src="
"
>
</div>
</div>
</div>
</div>
</div>
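<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<p>As a small sketch of how predict_proba relates to predict, the following rebuilds the same blobs and split used above and checks two properties: every row of probabilities sums to 1, and the predicted label is the class with the largest probability. The variable names are illustrative.</p>

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

# Same data and split as in the cells above
X, y = make_blobs(n_samples=200, random_state=1, centers=2, cluster_std=5)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=68)
gnb = GaussianNB().fit(X_train, y_train)

proba = gnb.predict_proba(X_test)            # one column per class, ordered as gnb.classes_
row_sums = proba.sum(axis=1)                 # each row is a probability distribution
labels_from_proba = gnb.classes_[np.argmax(proba, axis=1)]
labels = gnb.predict(X_test)
confidence = proba.max(axis=1)               # probability assigned to the predicted class

print("classes:", gnb.classes_)
print("rows sum to 1:", np.allclose(row_sums, 1.0))
print("argmax matches predict:", np.array_equal(labels_from_proba, labels))
print("least confident test sample: {:0.3f}".format(confidence.min()))
```

<p>Samples whose largest probability is close to 0.5 lie near the boundary between the two background colors in the plot.</p>
</div>
</div>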
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<h2 id="%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E4%B8%AD%E7%9A%84%E5%86%B3%E5%AE%9A%E7%B3%BB%E6%95%B0">The Decision Function in Classification Models<a class="anchor-link" href="#%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E4%B8%AD%E7%9A%84%E5%86%B3%E5%AE%9A%E7%B3%BB%E6%95%B0">¶</a></h2>
</div>
</div>
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
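<p>For binary classification, decision_function returns a single value per sample: a positive value means the sample is assigned to class 1, a negative value means it is assigned to class 0.</p>
<p>Because Gaussian naive Bayes has no decision_function, we switch to a support vector machine (SVC) for this model.</p>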
</div>
</div><div class="jp-Cell jp-CodeCell jp-Notebook-cell ">
<div class="jp-Cell-inputWrapper">
<div class="jp-InputArea jp-Cell-inputArea">
<div class="jp-InputPrompt jp-InputArea-prompt">In [4]:</div>
<div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
<div class="CodeMirror cm-s-jupyter">
<div class=" highlight hl-ipython3"><pre><span></span><span class="kn">from</span> <span class="nn">sklearn.svm</span> <span class="kn">import</span> <span class="n">SVC</span>
<span class="n">svc</span><span class="o">=</span><span class="n">SVC</span><span class="p">()</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
<span class="n">dec_func</span><span class="o">=</span> <span class="n">svc</span><span class="o">.</span><span class="n">decision_function</span><span class="p">(</span><span class="n">X_test</span><span class="p">)</span>
<span class="c1"># Print the first 5 decision function values</span>
<span class="nb">print</span><span class="p">(</span><span class="n">dec_func</span><span class="p">[:</span><span class="mi">5</span><span class="p">])</span>
<span class="c1"># Negative values indicate one class, positive values the other.</span>
</pre></div>
</div>
</div>
</div>
</div>
<div class="jp-Cell-outputWrapper">
<div class="jp-OutputArea jp-Cell-outputArea">
<div class="jp-OutputArea-child">
<div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
<div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain">
<pre>[-1.36071347 1.53694862 1.78825594 -0.96133081 1.81826853]
</pre>
</div>
</div>
</div>
</div>
</div><div class="jp-Cell jp-CodeCell jp-Notebook-cell ">
<div class="jp-Cell-inputWrapper">
<div class="jp-InputArea jp-Cell-inputArea">
<div class="jp-InputPrompt jp-InputArea-prompt">In [5]:</div>
<div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
<div class="CodeMirror cm-s-jupyter">
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># Visualize how the decision function behaves</span>
<span class="n">Z</span><span class="o">=</span><span class="n">svc</span><span class="o">.</span><span class="n">decision_function</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">c_</span><span class="p">[</span><span class="n">xx</span><span class="o">.</span><span class="n">ravel</span><span class="p">(),</span> <span class="n">yy</span><span class="o">.</span><span class="n">ravel</span><span class="p">()])</span>
<span class="n">Z</span><span class="o">=</span><span class="n">Z</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">xx</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="c1"># Draw filled contours</span>
<span class="n">plt</span><span class="o">.</span><span class="n">contourf</span><span class="p">(</span><span class="n">xx</span><span class="p">,</span> <span class="n">yy</span><span class="p">,</span> <span class="n">Z</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="n">plt</span><span class="o">.</span><span class="n">cm</span><span class="o">.</span><span class="n">summer</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span>
<span class="c1"># Draw the scatter plots</span>
<span class="n">plt</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">X_train</span><span class="p">[:,</span><span class="mi">0</span><span class="p">],</span> <span class="n">X_train</span><span class="p">[:,</span><span class="mi">1</span><span class="p">],</span> <span class="n">c</span><span class="o">=</span><span class="n">y_train</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="n">plt</span><span class="o">.</span><span class="n">cm</span><span class="o">.</span><span class="n">cool</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s1">'k'</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">X_test</span><span class="p">[:,</span><span class="mi">0</span><span class="p">],</span> <span class="n">X_test</span><span class="p">[:,</span><span class="mi">1</span><span class="p">],</span> <span class="n">c</span><span class="o">=</span><span class="n">y_test</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="n">plt</span><span class="o">.</span><span class="n">cm</span><span class="o">.</span><span class="n">cool</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s1">'k'</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.6</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">xlim</span><span class="p">(</span><span class="n">xx</span><span class="o">.</span><span class="n">min</span><span class="p">(),</span> <span class="n">xx</span><span class="o">.</span><span class="n">max</span><span class="p">())</span>
<span class="n">plt</span><span class="o">.</span><span class="n">ylim</span><span class="p">(</span><span class="n">yy</span><span class="o">.</span><span class="n">min</span><span class="p">(),</span> <span class="n">yy</span><span class="o">.</span><span class="n">max</span><span class="p">())</span>
<span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">"SVC decision_function"</span><span class="p">)</span>
<span class="c1"># Hide the tick marks on both axes</span>
<span class="n">plt</span><span class="o">.</span><span class="n">xticks</span><span class="p">(())</span>
<span class="n">plt</span><span class="o">.</span><span class="n">yticks</span><span class="p">(())</span>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
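<span class="c1"># As in the previous plot, there is a diagonal boundary between the classes.</span>
<span class="c1"># However, the cores of the dark background regions sit in different places. Points near the center of a region are classified with high confidence, while the edges and the border between regions are "ambiguous" zones.</span>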
</pre></div>
</div>
</div>
</div>
</div>
<div class="jp-Cell-outputWrapper">
<div class="jp-OutputArea jp-Cell-outputArea">
<div class="jp-OutputArea-child">
<div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
<div class="jp-RenderedImage jp-OutputArea-output ">
<img src="
"
>
</div>
</div>
</div>
</div>
</div>
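<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<p>The link between the sign of decision_function and the predicted label can be checked directly. This is a minimal sketch that rebuilds the same blobs and split as above and uses SVC with its default settings; the variable names are illustrative.</p>

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Same data and split as in the cells above
X, y = make_blobs(n_samples=200, random_state=1, centers=2, cluster_std=5)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=68)
svc = SVC().fit(X_train, y_train)

dec = svc.decision_function(X_test)          # one signed value per sample in the binary case
# Positive values map to svc.classes_[1], negative values to svc.classes_[0]
labels_from_sign = svc.classes_[(dec > 0).astype(int)]
labels = svc.predict(X_test)

print("classes:", svc.classes_)
print("sign matches predict:", np.array_equal(labels_from_sign, labels))
# The magnitude is an uncalibrated confidence; it is not a probability
print("largest |decision value|: {:0.3f}".format(np.abs(dec).max()))
```

<p>Unlike predict_proba, these values are not bounded to the range 0 to 1, so they are best compared with each other rather than read as probabilities.</p>
</div>
</div>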
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<p>This example uses a binary classification task, but predict_proba and decision_function also apply to multiclass tasks.</p>
<ul>
<li>Interested readers can experiment by changing the centers parameter of make_blobs.</li>
</ul>
</div>
</div>
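<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<p>One way to try the multiclass case is sketched below with centers=3. It assumes scikit-learn's defaults, in particular SVC's decision_function_shape='ovr', under which both methods return one column per class. The names X3, gnb3, and svc3 are illustrative.</p>

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC

# Three clusters instead of two; everything else as in the binary example
X3, y3 = make_blobs(n_samples=200, random_state=1, centers=3, cluster_std=5)
X3_train, X3_test, y3_train, y3_test = train_test_split(X3, y3, random_state=68)

gnb3 = GaussianNB().fit(X3_train, y3_train)
svc3 = SVC().fit(X3_train, y3_train)

proba3 = gnb3.predict_proba(X3_test)         # shape (n_samples, n_classes)
dec3 = svc3.decision_function(X3_test)       # shape (n_samples, n_classes) with 'ovr'

print("predict_proba shape:", proba3.shape)
print("decision_function shape:", dec3.shape)
print("probabilities still sum to 1:", np.allclose(proba3.sum(axis=1), 1.0))
```

<p>With more than two classes, decision_function no longer returns a single signed value per sample, so the sign rule from the binary case does not carry over.</p>
</div>
</div>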
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<h2 id="%E5%85%B6%E5%AE%83%E6%A8%A1%E5%9E%8B%E8%AF%84%E4%BB%B7%E6%8C%87%E6%A0%87">Other Model Evaluation Metrics<a class="anchor-link" href="#%E5%85%B6%E5%AE%83%E6%A8%A1%E5%9E%8B%E8%AF%84%E4%BB%B7%E6%8C%87%E6%A0%87">¶</a></h2>
</div>
</div>
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
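<p>So far we have relied on .score():</p>
<ul>
<li>For classification models, it returns the classification accuracy.</li>
<li>For regression models, it returns the R<sup>2</sup> score, known as the coefficient of determination or goodness of fit.<ul>
<li>R<sup>2</sup> is 1 minus the ratio of the residual sum of squares (RSS) to the total sum of squares (TSS). For ordinary least squares with an intercept, evaluated on the training data, this equals the explained sum of squares (ESS) divided by TSS.</li>
</ul>
</li>
</ul>
<p>R<sup>2</sup> = 1 &minus; &Sigma;(y &minus; ŷ)<sup>2</sup> / &Sigma;(y &minus; ȳ)<sup>2</sup>, where ŷ is the predicted value and ȳ is the mean of the observed values.</p>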
</div>
</div>
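<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<p>The formula can be checked by hand against .score(). This sketch assumes synthetic data from make_regression and a plain LinearRegression model, neither of which comes from the notebook; they only serve to compare the manual computation with scikit-learn's.</p>

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in data (an assumption; not the dataset used in this notebook)
X, y = make_regression(n_samples=200, n_features=5, noise=20.0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

model = LinearRegression().fit(X_train, y_train)
y_hat = model.predict(X_test)

rss = np.sum((y_test - y_hat) ** 2)              # residual sum of squares
tss = np.sum((y_test - y_test.mean()) ** 2)      # total sum of squares
r2_manual = 1 - rss / tss

print("manual R^2:   {:0.6f}".format(r2_manual))
print("model.score:  {:0.6f}".format(model.score(X_test, y_test)))
print("r2_score:     {:0.6f}".format(r2_score(y_test, y_hat)))
```

<p>Note that the mean in the denominator is taken over the data being scored, so on a test set R<sup>2</sup> can be negative when the model does worse than predicting that mean.</p>
</div>
</div>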
<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<p>Other metrics:</p>
<ul>
<li>Precision</li>
<li>Recall</li>
<li>F1 score</li>
<li>ROC (receiver operating characteristic) curve</li>
<li>AUC (area under the curve)</li>
</ul>
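<p>As a rough sketch of how these metrics can be computed with sklearn.metrics, the following reuses the blob data and the Gaussian naive Bayes model from earlier in this notebook. Class 1 is treated as the positive class, which is scikit-learn's default for 0/1 labels.</p>

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

# Same data and split as in the classification cells above
X, y = make_blobs(n_samples=200, random_state=1, centers=2, cluster_std=5)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=68)
gnb = GaussianNB().fit(X_train, y_train)

y_pred = gnb.predict(X_test)                     # hard labels for precision, recall, F1
y_score = gnb.predict_proba(X_test)[:, 1]        # probability of class 1 for ROC AUC

precision = precision_score(y_test, y_pred)      # TP / (TP + FP)
recall = recall_score(y_test, y_pred)            # TP / (TP + FN)
f1 = f1_score(y_test, y_pred)                    # harmonic mean of precision and recall
auc = roc_auc_score(y_test, y_score)             # area under the ROC curve

print("precision: {:0.3f}".format(precision))
print("recall:    {:0.3f}".format(recall))
print("f1:        {:0.3f}".format(f1))
print("roc auc:   {:0.3f}".format(auc))
```

<p>Precision, recall, and F1 are computed from predicted labels, while ROC AUC needs a continuous score such as the output of predict_proba or decision_function.</p>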
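<p>These metrics are often combined with grid search to select a suitable model and parameters.</p>
<ul>
<li>With the GridSearchCV class, changing the scoring method only requires setting the scoring parameter.</li>
<li>For example, a random forest scored by roc_auc:<ul>
<li>grid=GridSearchCV(RandomForestClassifier(), param_grid=param_grid, scoring='roc_auc')</li>
</ul>
</li>
</ul>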
</div>
</div>
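<div class="jp-Cell-inputWrapper"><div class="jp-InputPrompt jp-InputArea-prompt">
</div><div class="jp-RenderedHTMLCommon jp-RenderedMarkdown jp-MarkdownOutput " data-mime-type="text/markdown">
<p>The one-line GridSearchCV call above leaves param_grid undefined. A runnable sketch might look like the following, where the param_grid values, cv=5, and random_state=0 are all hypothetical choices made for illustration and the blob data from earlier is reused.</p>

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import GridSearchCV, train_test_split

# Same data and split as in the classification cells above
X, y = make_blobs(n_samples=200, random_state=1, centers=2, cluster_std=5)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=68)

# Hypothetical grid; these values are illustrative, not tuned recommendations
param_grid = {"n_estimators": [10, 50], "max_depth": [2, 4]}
grid = GridSearchCV(RandomForestClassifier(random_state=0),
                    param_grid=param_grid, scoring="roc_auc", cv=5)
grid.fit(X_train, y_train)

# With scoring='roc_auc', both best_score_ and score() report ROC AUC, not accuracy
test_auc = grid.score(X_test, y_test)
manual_auc = roc_auc_score(y_test, grid.predict_proba(X_test)[:, 1])

print("best params:", grid.best_params_)
print("best CV roc_auc: {:0.3f}".format(grid.best_score_))
print("test roc_auc:    {:0.3f}".format(test_auc))
```

<p>The same pattern applies to other built-in scorer names such as 'precision', 'recall', and 'f1'.</p>
</div>
</div>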
</body>
</html>