diff --git a/02_activities/assignments/assignment_1.ipynb b/02_activities/assignments/assignment_1.ipynb index 28d4df017..0e659efdb 100644 --- a/02_activities/assignments/assignment_1.ipynb +++ b/02_activities/assignments/assignment_1.ipynb @@ -34,7 +34,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "4a3485d6-ba58-4660-a983-5680821c5719", "metadata": {}, "outputs": [], @@ -56,10 +56,1224 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "a431d282-f9ca-4d5d-8912-71ffc9d8ea19", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.microsoft.datawrangler.viewer.v0+json": { + "columns": [ + { + "name": "index", + "rawType": "int64", + "type": "integer" + }, + { + "name": "alcohol", + "rawType": "float64", + "type": "float" + }, + { + "name": "malic_acid", + "rawType": "float64", + "type": "float" + }, + { + "name": "ash", + "rawType": "float64", + "type": "float" + }, + { + "name": "alcalinity_of_ash", + "rawType": "float64", + "type": "float" + }, + { + "name": "magnesium", + "rawType": "float64", + "type": "float" + }, + { + "name": "total_phenols", + "rawType": "float64", + "type": "float" + }, + { + "name": "flavanoids", + "rawType": "float64", + "type": "float" + }, + { + "name": "nonflavanoid_phenols", + "rawType": "float64", + "type": "float" + }, + { + "name": "proanthocyanins", + "rawType": "float64", + "type": "float" + }, + { + "name": "color_intensity", + "rawType": "float64", + "type": "float" + }, + { + "name": "hue", + "rawType": "float64", + "type": "float" + }, + { + "name": "od280/od315_of_diluted_wines", + "rawType": "float64", + "type": "float" + }, + { + "name": "proline", + "rawType": "float64", + "type": "float" + }, + { + "name": "class", + "rawType": "int64", + "type": "integer" + } + ], + "ref": "40b2e746-2788-463c-9478-db19b5cc86f2", + "rows": [ + [ + "0", + "14.23", + "1.71", + "2.43", + "15.6", + "127.0", + "2.8", + "3.06", + "0.28", + "2.29", + "5.64", + "1.04", + "3.92", + "1065.0", + "0" + ], + [ + "1", + "13.2", + "1.78", + "2.14", + "11.2", + "100.0", + "2.65", + "2.76", + "0.26", + "1.28", + "4.38", + "1.05", + "3.4", + "1050.0", + "0" + ], + [ + "2", + "13.16", + "2.36", + "2.67", + "18.6", + "101.0", + "2.8", + "3.24", + "0.3", + "2.81", + "5.68", + "1.03", + "3.17", + "1185.0", + "0" + ], + [ + "3", + "14.37", + "1.95", + "2.5", + "16.8", + "113.0", + "3.85", + "3.49", + "0.24", + "2.18", + "7.8", + "0.86", + "3.45", + "1480.0", + "0" + ], + [ + "4", + "13.24", + "2.59", + "2.87", + "21.0", + "118.0", + "2.8", + "2.69", + "0.39", + "1.82", + "4.32", + "1.04", + "2.93", + "735.0", + "0" + ], + [ + "5", + "14.2", + "1.76", + "2.45", + "15.2", + "112.0", + "3.27", + "3.39", + "0.34", + "1.97", + "6.75", + "1.05", + "2.85", + "1450.0", + "0" + ], + [ + "6", + "14.39", + "1.87", + "2.45", + "14.6", + "96.0", + "2.5", + "2.52", + "0.3", + "1.98", + "5.25", + "1.02", + "3.58", + "1290.0", + "0" + ], + [ + "7", + "14.06", + "2.15", + "2.61", + "17.6", + "121.0", + "2.6", + "2.51", + "0.31", + "1.25", + "5.05", + "1.06", + "3.58", + "1295.0", + "0" + ], + [ + "8", + "14.83", + "1.64", + "2.17", + "14.0", + "97.0", + "2.8", + "2.98", + "0.29", + "1.98", + "5.2", + "1.08", + "2.85", + "1045.0", + "0" + ], + [ + "9", + "13.86", + "1.35", + "2.27", + "16.0", + "98.0", + "2.98", + "3.15", + "0.22", + "1.85", + "7.22", + "1.01", + "3.55", + "1045.0", + "0" + ], + [ + "10", + "14.1", + "2.16", + "2.3", + "18.0", + "105.0", + "2.95", + "3.32", + "0.22", + "2.38", + "5.75", + "1.25", + "3.17", + "1510.0", + "0" + ], + [ + "11", + "14.12", + "1.48", + "2.32", + "16.8", + "95.0", + "2.2", + "2.43", + "0.26", + "1.57", + "5.0", + "1.17", + "2.82", + "1280.0", + "0" + ], + [ + "12", + "13.75", + "1.73", + "2.41", + "16.0", + "89.0", + "2.6", + "2.76", + "0.29", + "1.81", + "5.6", + "1.15", + "2.9", + "1320.0", + "0" + ], + [ + "13", + "14.75", + "1.73", + "2.39", + "11.4", + "91.0", + "3.1", + "3.69", + "0.43", + "2.81", + "5.4", + "1.25", + "2.73", + "1150.0", + "0" + ], + [ + "14", + "14.38", + "1.87", + "2.38", + "12.0", + "102.0", + "3.3", + "3.64", + "0.29", + "2.96", + "7.5", + "1.2", + "3.0", + "1547.0", + "0" + ], + [ + "15", + "13.63", + "1.81", + "2.7", + "17.2", + "112.0", + "2.85", + "2.91", + "0.3", + "1.46", + "7.3", + "1.28", + "2.88", + "1310.0", + "0" + ], + [ + "16", + "14.3", + "1.92", + "2.72", + "20.0", + "120.0", + "2.8", + "3.14", + "0.33", + "1.97", + "6.2", + "1.07", + "2.65", + "1280.0", + "0" + ], + [ + "17", + "13.83", + "1.57", + "2.62", + "20.0", + "115.0", + "2.95", + "3.4", + "0.4", + "1.72", + "6.6", + "1.13", + "2.57", + "1130.0", + "0" + ], + [ + "18", + "14.19", + "1.59", + "2.48", + "16.5", + "108.0", + "3.3", + "3.93", + "0.32", + "1.86", + "8.7", + "1.23", + "2.82", + "1680.0", + "0" + ], + [ + "19", + "13.64", + "3.1", + "2.56", + "15.2", + "116.0", + "2.7", + "3.03", + "0.17", + "1.66", + "5.1", + "0.96", + "3.36", + "845.0", + "0" + ], + [ + "20", + "14.06", + "1.63", + "2.28", + "16.0", + "126.0", + "3.0", + "3.17", + "0.24", + "2.1", + "5.65", + "1.09", + "3.71", + "780.0", + "0" + ], + [ + "21", + "12.93", + "3.8", + "2.65", + "18.6", + "102.0", + "2.41", + "2.41", + "0.25", + "1.98", + "4.5", + "1.03", + "3.52", + "770.0", + "0" + ], + [ + "22", + "13.71", + "1.86", + "2.36", + "16.6", + "101.0", + "2.61", + "2.88", + "0.27", + "1.69", + "3.8", + "1.11", + "4.0", + "1035.0", + "0" + ], + [ + "23", + "12.85", + "1.6", + "2.52", + "17.8", + "95.0", + "2.48", + "2.37", + "0.26", + "1.46", + "3.93", + "1.09", + "3.63", + "1015.0", + "0" + ], + [ + "24", + "13.5", + "1.81", + "2.61", + "20.0", + "96.0", + "2.53", + "2.61", + "0.28", + "1.66", + "3.52", + "1.12", + "3.82", + "845.0", + "0" + ], + [ + "25", + "13.05", + "2.05", + "3.22", + "25.0", + "124.0", + "2.63", + "2.68", + "0.47", + "1.92", + "3.58", + "1.13", + "3.2", + "830.0", + "0" + ], + [ + "26", + "13.39", + "1.77", + "2.62", + "16.1", + "93.0", + "2.85", + "2.94", + "0.34", + "1.45", + "4.8", + "0.92", + "3.22", + "1195.0", + "0" + ], + [ + "27", + "13.3", + "1.72", + "2.14", + "17.0", + "94.0", + "2.4", + "2.19", + "0.27", + "1.35", + "3.95", + "1.02", + "2.77", + "1285.0", + "0" + ], + [ + "28", + "13.87", + "1.9", + "2.8", + "19.4", + "107.0", + "2.95", + "2.97", + "0.37", + "1.76", + "4.5", + "1.25", + "3.4", + "915.0", + "0" + ], + [ + "29", + "14.02", + "1.68", + "2.21", + "16.0", + "96.0", + "2.65", + "2.33", + "0.26", + "1.98", + "4.7", + "1.04", + "3.59", + "1035.0", + "0" + ], + [ + "30", + "13.73", + "1.5", + "2.7", + "22.5", + "101.0", + "3.0", + "3.25", + "0.29", + "2.38", + "5.7", + "1.19", + "2.71", + "1285.0", + "0" + ], + [ + "31", + "13.58", + "1.66", + "2.36", + "19.1", + "106.0", + "2.86", + "3.19", + "0.22", + "1.95", + "6.9", + "1.09", + "2.88", + "1515.0", + "0" + ], + [ + "32", + "13.68", + "1.83", + "2.36", + "17.2", + "104.0", + "2.42", + "2.69", + "0.42", + "1.97", + "3.84", + "1.23", + "2.87", + "990.0", + "0" + ], + [ + "33", + "13.76", + "1.53", + "2.7", + "19.5", + "132.0", + "2.95", + "2.74", + "0.5", + "1.35", + "5.4", + "1.25", + "3.0", + "1235.0", + "0" + ], + [ + "34", + "13.51", + "1.8", + "2.65", + "19.0", + "110.0", + "2.35", + "2.53", + "0.29", + "1.54", + "4.2", + "1.1", + "2.87", + "1095.0", + "0" + ], + [ + "35", + "13.48", + "1.81", + "2.41", + "20.5", + "100.0", + "2.7", + "2.98", + "0.26", + "1.86", + "5.1", + "1.04", + "3.47", + "920.0", + "0" + ], + [ + "36", + "13.28", + "1.64", + "2.84", + "15.5", + "110.0", + "2.6", + "2.68", + "0.34", + "1.36", + "4.6", + "1.09", + "2.78", + "880.0", + "0" + ], + [ + "37", + "13.05", + "1.65", + "2.55", + "18.0", + "98.0", + "2.45", + "2.43", + "0.29", + "1.44", + "4.25", + "1.12", + "2.51", + "1105.0", + "0" + ], + [ + "38", + "13.07", + "1.5", + "2.1", + "15.5", + "98.0", + "2.4", + "2.64", + "0.28", + "1.37", + "3.7", + "1.18", + "2.69", + "1020.0", + "0" + ], + [ + "39", + "14.22", + "3.99", + "2.51", + "13.2", + "128.0", + "3.0", + "3.04", + "0.2", + "2.08", + "5.1", + "0.89", + "3.53", + "760.0", + "0" + ], + [ + "40", + "13.56", + "1.71", + "2.31", + "16.2", + "117.0", + "3.15", + "3.29", + "0.34", + "2.34", + "6.13", + "0.95", + "3.38", + "795.0", + "0" + ], + [ + "41", + "13.41", + "3.84", + "2.12", + "18.8", + "90.0", + "2.45", + "2.68", + "0.27", + "1.48", + "4.28", + "0.91", + "3.0", + "1035.0", + "0" + ], + [ + "42", + "13.88", + "1.89", + "2.59", + "15.0", + "101.0", + "3.25", + "3.56", + "0.17", + "1.7", + "5.43", + "0.88", + "3.56", + "1095.0", + "0" + ], + [ + "43", + "13.24", + "3.98", + "2.29", + "17.5", + "103.0", + "2.64", + "2.63", + "0.32", + "1.66", + "4.36", + "0.82", + "3.0", + "680.0", + "0" + ], + [ + "44", + "13.05", + "1.77", + "2.1", + "17.0", + "107.0", + "3.0", + "3.0", + "0.28", + "2.03", + "5.04", + "0.88", + "3.35", + "885.0", + "0" + ], + [ + "45", + "14.21", + "4.04", + "2.44", + "18.9", + "111.0", + "2.85", + "2.65", + "0.3", + "1.25", + "5.24", + "0.87", + "3.33", + "1080.0", + "0" + ], + [ + "46", + "14.38", + "3.59", + "2.28", + "16.0", + "102.0", + "3.25", + "3.17", + "0.27", + "2.19", + "4.9", + "1.04", + "3.44", + "1065.0", + "0" + ], + [ + "47", + "13.9", + "1.68", + "2.12", + "16.0", + "101.0", + "3.1", + "3.39", + "0.21", + "2.14", + "6.1", + "0.91", + "3.33", + "985.0", + "0" + ], + [ + "48", + "14.1", + "2.02", + "2.4", + "18.8", + "103.0", + "2.75", + "2.92", + "0.32", + "2.38", + "6.2", + "1.07", + "2.75", + "1060.0", + "0" + ], + [ + "49", + "13.94", + "1.73", + "2.27", + "17.4", + "108.0", + "2.88", + "3.54", + "0.32", + "2.08", + "8.9", + "1.12", + "3.1", + "1260.0", + "0" + ] + ], + "shape": { + "columns": 14, + "rows": 178 + } + }, + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
alcoholmalic_acidashalcalinity_of_ashmagnesiumtotal_phenolsflavanoidsnonflavanoid_phenolsproanthocyaninscolor_intensityhueod280/od315_of_diluted_winesprolineclass
014.231.712.4315.6127.02.803.060.282.295.641.043.921065.00
113.201.782.1411.2100.02.652.760.261.284.381.053.401050.00
213.162.362.6718.6101.02.803.240.302.815.681.033.171185.00
314.371.952.5016.8113.03.853.490.242.187.800.863.451480.00
413.242.592.8721.0118.02.802.690.391.824.321.042.93735.00
.............................................
17313.715.652.4520.595.01.680.610.521.067.700.641.74740.02
17413.403.912.4823.0102.01.800.750.431.417.300.701.56750.02
17513.274.282.2620.0120.01.590.690.431.3510.200.591.56835.02
17613.172.592.3720.0120.01.650.680.531.469.300.601.62840.02
17714.134.102.7424.596.02.050.760.561.359.200.611.60560.02
\n", + "

178 rows × 14 columns

\n", + "
" + ], + "text/plain": [ + " alcohol malic_acid ash alcalinity_of_ash magnesium total_phenols \\\n", + "0 14.23 1.71 2.43 15.6 127.0 2.80 \n", + "1 13.20 1.78 2.14 11.2 100.0 2.65 \n", + "2 13.16 2.36 2.67 18.6 101.0 2.80 \n", + "3 14.37 1.95 2.50 16.8 113.0 3.85 \n", + "4 13.24 2.59 2.87 21.0 118.0 2.80 \n", + ".. ... ... ... ... ... ... \n", + "173 13.71 5.65 2.45 20.5 95.0 1.68 \n", + "174 13.40 3.91 2.48 23.0 102.0 1.80 \n", + "175 13.27 4.28 2.26 20.0 120.0 1.59 \n", + "176 13.17 2.59 2.37 20.0 120.0 1.65 \n", + "177 14.13 4.10 2.74 24.5 96.0 2.05 \n", + "\n", + " flavanoids nonflavanoid_phenols proanthocyanins color_intensity hue \\\n", + "0 3.06 0.28 2.29 5.64 1.04 \n", + "1 2.76 0.26 1.28 4.38 1.05 \n", + "2 3.24 0.30 2.81 5.68 1.03 \n", + "3 3.49 0.24 2.18 7.80 0.86 \n", + "4 2.69 0.39 1.82 4.32 1.04 \n", + ".. ... ... ... ... ... \n", + "173 0.61 0.52 1.06 7.70 0.64 \n", + "174 0.75 0.43 1.41 7.30 0.70 \n", + "175 0.69 0.43 1.35 10.20 0.59 \n", + "176 0.68 0.53 1.46 9.30 0.60 \n", + "177 0.76 0.56 1.35 9.20 0.61 \n", + "\n", + " od280/od315_of_diluted_wines proline class \n", + "0 3.92 1065.0 0 \n", + "1 3.40 1050.0 0 \n", + "2 3.17 1185.0 0 \n", + "3 3.45 1480.0 0 \n", + "4 2.93 735.0 0 \n", + ".. ... ... ... \n", + "173 1.74 740.0 2 \n", + "174 1.56 750.0 2 \n", + "175 1.56 835.0 2 \n", + "176 1.62 840.0 2 \n", + "177 1.60 560.0 2 \n", + "\n", + "[178 rows x 14 columns]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from sklearn.datasets import load_wine\n", "\n", @@ -91,12 +1305,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "56916892", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of observations: 178\n" + ] + } + ], "source": [ - "# Your answer here" + "num_observations = wine_df.shape[0]\n", + "print(f\"Number of observations: {num_observations}\")" ] }, { @@ -109,12 +1332,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "df0ef103", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of variables: 14\n" + ] + } + ], "source": [ - "# Your answer here" + "num_variables = wine_df.shape[1]\n", + "print(f\"Number of variables: {num_variables}\")" ] }, { @@ -127,12 +1359,24 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "47989426", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "int64\n", + "Unique levels: [np.int64(0), np.int64(1), np.int64(2)]\n", + "The type of class is a categorical variable representing different wine categories with 3 levels.\n" + ] + } + ], "source": [ - "# Your answer here" + "print(wine_df['class'].dtype) \n", + "print(\"Unique levels:\", sorted(wine_df['class'].unique()))\n", + "print(\"The type of class is a categorical variable representing different wine categories with 3 levels.\")" ] }, { @@ -146,12 +1390,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "bd7b0910", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of predictor variables: 13\n" + ] + } + ], "source": [ - "# Your answer here" + "num_predictors = wine_df.shape[1] - 1\n", + "print(f\"Number of predictor variables: {num_predictors}\")" ] }, { @@ -175,10 +1428,37 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "id": "cc899b59", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " alcohol malic_acid ash alcalinity_of_ash magnesium \\\n", + "0 1.518613 -0.562250 0.232053 -1.169593 1.913905 \n", + "1 0.246290 -0.499413 -0.827996 -2.490847 0.018145 \n", + "2 0.196879 0.021231 1.109334 -0.268738 0.088358 \n", + "3 1.691550 -0.346811 0.487926 -0.809251 0.930918 \n", + "4 0.295700 0.227694 1.840403 0.451946 1.281985 \n", + "\n", + " total_phenols flavanoids nonflavanoid_phenols proanthocyanins \\\n", + "0 0.808997 1.034819 -0.659563 1.224884 \n", + "1 0.568648 0.733629 -0.820719 -0.544721 \n", + "2 0.808997 1.215533 -0.498407 2.135968 \n", + "3 2.491446 1.466525 -0.981875 1.032155 \n", + "4 0.808997 0.663351 0.226796 0.401404 \n", + "\n", + " color_intensity hue od280/od315_of_diluted_wines proline \n", + "0 0.251717 0.362177 1.847920 1.013009 \n", + "1 -0.293321 0.406051 1.113449 0.965242 \n", + "2 0.269020 0.318304 0.788587 1.395148 \n", + "3 1.186068 -0.427544 1.184071 2.334574 \n", + "4 -0.319276 0.362177 0.449601 -0.037874 \n" + ] + } + ], "source": [ "# Select predictors (excluding the last column)\n", "predictors = wine_df.iloc[:, :-1]\n", @@ -204,7 +1484,7 @@ "id": "403ef0bb", "metadata": {}, "source": [ - "> Your answer here..." + "Since we are using Euclidean distance to measure how close values are to each other, if we have unstandardized values, then predictors measured with larger numbers will dominate distances, and bias results towards those predictors. Standardizing all predictors allows for equal contribution independent of the scale of values used in predictors." ] }, { @@ -220,7 +1500,7 @@ "id": "fdee5a15", "metadata": {}, "source": [ - "> Your answer here..." + "Class is the categorical response, and so standardizing these values will distort the meaning of the dependent variable we are trying to predict." ] }, { @@ -236,7 +1516,7 @@ "id": "f0676c21", "metadata": {}, "source": [ - "> Your answer here..." + "Having a random seed allow for randomness to be used in the analysis, while also allowing for reproducibility of the results if run at a later point or by someone else. The number doesnt matter, any value will work, since it just identifies the type of randomness that must be replicated." ] }, { @@ -251,7 +1531,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "id": "72c101f2", "metadata": {}, "outputs": [], @@ -260,8 +1540,13 @@ "np.random.seed(123)\n", "\n", "# split the data into a training and testing set. hint: use train_test_split !\n", - "\n", - "# Your code here ..." + "X_train, X_test, y_train, y_test = train_test_split(\n", + " predictors_standardized, \n", + " wine_df['class'],\n", + " train_size = 0.75,\n", + " shuffle = True,\n", + " stratify = wine_df['class']\n", + ")" ] }, { @@ -284,12 +1569,29 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "id": "08818c64", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best k: 7\n" + ] + } + ], "source": [ - "# Your code here..." + "knn = KNeighborsClassifier()\n", + "\n", + "param_grid = {'n_neighbors': range(1, 51)}\n", + "\n", + "grid_search = GridSearchCV(knn, param_grid, cv=10)\n", + "grid_search.fit(X_train, y_train)\n", + "\n", + "best_k = grid_search.best_params_['n_neighbors']\n", + "best_k\n", + "print(f\"Best k: {grid_search.best_params_['n_neighbors']}\")" ] }, { @@ -305,12 +1607,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "id": "ffefa9f2", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test set accuracy with best k=7: 0.9333\n" + ] + } + ], "source": [ - "# Your code here..." + "best_knn = KNeighborsClassifier(n_neighbors=best_k)\n", + "best_knn.fit(X_train, y_train)\n", + "\n", + "best_knn.score(X_test, y_test)\n", + "\n", + "# Evaluate and print the test accuracy\n", + "test_accuracy = best_knn.score(X_test, y_test)\n", + "print(f\"Test set accuracy with best k={best_k}: {test_accuracy:.4f}\")" ] }, { @@ -365,7 +1682,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.10.4", + "display_name": "lcr-env (3.11.13)", "language": "python", "name": "python3" }, @@ -379,12 +1696,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.19" - }, - "vscode": { - "interpreter": { - "hash": "497a84dc8fec8cf8d24e7e87b6d954c9a18a327edc66feb9b9ea7e9e72cc5c7e" - } + "version": "3.11.13" } }, "nbformat": 4,