Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
976 changes: 976 additions & 0 deletions 01_materials/labs/.ipynb_checkpoints/lab_1-checkpoint.ipynb

Large diffs are not rendered by default.

27 changes: 27 additions & 0 deletions 01_materials/labs/lab1test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TF_NUM_INTRAOP_THREADS"] = "1"
os.environ["TF_NUM_INTEROP_THREADS"] = "1"

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.utils import to_categorical
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

digits = load_digits()
X = digits.data
y = to_categorical(digits.target)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

model = Sequential([
Dense(64, activation="relu", input_shape=(64,)),
Dense(10, activation="softmax")
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

print("Starting training…")
model.fit(X_train, y_train, epochs=2, batch_size=32, validation_split=0.2)
print("Done")
1,117 changes: 1,002 additions & 115 deletions 01_materials/labs/lab_1.ipynb

Large diffs are not rendered by default.

147 changes: 121 additions & 26 deletions 01_materials/labs/lab_2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -36,9 +36,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQ0AAAEpCAYAAACJL3coAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAF3lJREFUeJzt3Q10FNX5x/EnEAIB8wYaIQUCIspLwPAuDQoKainQQk9pa6EFraJpEKi1cuxpBWohUqu1iuWtNHBECPVY0HoKERCSKqYQaNogbQB5B5FqIQlQBZP5n+f+u2k2L5CbZDfZme/nnDHsOLtzF7K/vXNn7jxhjuM4AgC11Ky2GwIAoQHAGj0NAFYIDQBWCA0AVggNAFYIDQBWCA0AVggNAFYIjQBZuXKlhIWFyZEjR8St6vMeR4wYIUlJSQ3ani5dusjUqVMb9DVRFaEBiMipU6dk8uTJcvPNN0tUVJTExsbK4MGDZdWqVcJMC3/hlR6jgXznO9+Rb33rW9KyZUv+TkPAxx9/LCdOnJCvf/3r0rlzZ7l8+bJs3rzZ9FwKCwtlwYIFjd3EJoPQCJDmzZubBaGhb9++sn37dr9106dPl3HjxskLL7wgTz31FP+e/8XhSRCP9/WYe+zYseaXc+DAgRIZGSl9+vQp/2X9wx/+YB63atVKBgwYIH/961/9XvPvf/+7+ea74YYbzDbt27eX+++/Xz755JMq+/ftQ7fr1q2bLF26VObOnWvaVNnq1avN/rQ9bdu2NT2k48eP1+l9v/766zJmzBhJSEgwvSzdt37gSktLq91+9+7d8sUvftHsu2vXrrJkyZIq23z22WcyZ84cufHGG81rdurUSR5//HGz/mo++OADs9SV/ptdvHhRLl26VOfXcBt6GkF28OBB+fa3vy0PPfSQOYb+5S9/ab7N9MPy4x//WL7//e+b7dLT0+Ub3/iG6Ro3a/b/2a7d5UOHDsl9991nAuP999+XZcuWmZ+5ubnlgaBh86UvfUk6dOgg8+bNMx/Yn/3sZ3LddddVac/8+fPlpz/9qdnXAw88IP/617/kxRdflNtvv928jh7b24blNddcI48++qj5+fbbb8uTTz4pxcXF8swzz/hte/bsWfnyl79s9n3vvffK73//e0lNTZWIiAgThqqsrEy+8pWvyDvvvCPTpk2Tnj17SkFBgfzqV7+S/fv3y4YNG67YnpEjR5qftR2s/c9//iMXLlyQ8+fPS3Z2tmRkZMjQoUNNqOG/9H4aaHgZGRl6nxLn8OHD5esSExPNuh07dpSvy8rKMusiIyOdo0ePlq9funSpWb9t27bydRcvXqyyn7Vr15rtcnJyyteNGzfOad26tXPy5MnydQcOHHDCw8PNtj5Hjhxxmjdv7syfP9/vNQsKCsy2ldfX5j1W18aHHnrItOfTTz8tXzd8+HDz3GeffbZ83WeffeYkJyc78fHxzqVLl8y6l19+2WnWrJnz5z//2e81lyxZYp7/7rvv+v39TpkyxW87XadLbaWnp5vX9S0jR450jh07VuvnewGHJ0HWq1cv883lM2TIEPPzzjvvNANwlddrz8Kn4rfdp59+agbvbr31VvN4z5495qf2KrZs2SLjx483hwg+2rUfPXq0X1v0cEi/yfWbXl/Lt2gvpnv37rJt2zbr91exjSUlJeb1brvtNtPF/+c//+m3bXh4uOlx+WgPQx+fOXPGHLaoV1991fQuevTo4ddG/ftSV2uj9jBsTglrj0d7dGvWrDE9Ql/vA//D4UmQVQwGFRMTY37qcXp167UL7/Pvf//bHG5kZmaaD1ZFRUVF5qeu119yDYnKKq87cOCAOZ2oAVGdFi1aWL47MYdKP/nJT8xhiR6SVNdGHw21Nm3a+K276aabzE/9oGsgahv/8Y9/VHtopSr/PdRXYmKiWXwBoodEo0aNMoeJHKL8P0IjyGo6o1LT+orXCGiPYMeOHfKjH/1IkpOTzZiB9hR0/EJ/2tLn6DjIxo0bq92/vr6Nc+fOyfDhwyU6OtqMoeggqA7Eai9o9uzZdW6jDg4/99xz1f7/ymHb0PQU7PLlyyUnJ0fuueeegO4rVBAaIUJ7HFu3bjU9DR1Y9NFv4ori4+PNB1UHXCurvE4/1BpKetbC9w1fH3rGRs/k6GGPDqT6HD58uMYLqnTQsWJvQwc3fWctfG3829/+ZgY0qzvzE2i+Q5PKvSQvY0wjRPh6ApWvTnz++eerbKfdaT2roB/KioGhPYqKvva1r5ntNYgqv64+ru5Urm0b9VTlb37zm2q3//zzz82p4Irb6mM9FNFTwL7e1cmTJ823fU1nOhrilKueNarOihUrTFj179//qq/hFfQ0QoR2+fXb+xe/+IW5WvELX/iCvPXWW9V+i+v1GPr/UlJSzClMHRxdtGiRmeuRn59fvp1+i//85z+XJ554wowh6OCpXkKtr7l+/XpzPP/YY4/Vuo16vUVcXJxMmTJFZsyYYT5sL7/8co2XYeuYxsKFC82+taezbt060z49jewbT9Era/VU7MMPP2wGPfU96fvRQVVdn5WVZa5Hqe8pVz31/O6775pDPR130vGj1157TXbt2iWPPPJItWNEntXYp2+8dsp1zJgxVbbV7dLS0vzW6fN0/TPPPFO+7sSJE86ECROc2NhYJyYmxpk4caJz6tQps92cOXP8nr9161anX79+TkREhNOtWzfnt7/9rfPDH/7QadWqVZX9v/baa86wYcOcNm3amKVHjx6mPYWFhdbvUU+B3nrrreYUckJCgvP444+Xn1auePpYT7n27t3bycvLc4YOHWrapX8/ixYtqrIfPf26cOFCs33Lli2duLg4Z8CAAc68efOcoqKiBjnl+tZbbzljx441bW7RooUTFRXlpKSkmPdYVlZ21ed7SZj+p7GDC8GhPQk9u1F5HASwwZiGS1W+tkCD4k9/+pOZkg7UBz0Nl9JLyH3zVI4ePSqLFy82czX00vCarssAaoOBUJfSAb21a9fK6dOnzSQvvQpVp3cTGKgvehoArDCmAcAKoQGgaY9p6FwCvVJRLyJqjMuCAVRPr77Qmcl60Z3vHi5NIjQ0MAI9yQhA3eld2zp27Nh0QkN7GL6G6aXRbqYzO4OtutvlBVJDlyGoDd/dzYJp0qRJ4nbFxcXmC933GW0yoeE7JNHAcHtoeOFO5I1x8+TGuK+F239XK7rasAEDoQCsEBoArBAaAKwQGgCsEBoArBAaAKwQGgACHxovvfSSuVu03vVai/rs3LmzLi8DwAuhoTd/1TqdWpBX61nccsstph5EQxetAeCS0NCiNQ8++KApQqwlBvWy5datW8vvfve7wLQQQOiGhtal0BqbWlej/AWaNTOP33vvvUC0D0ATYzX3RAvvas2J66+/3m+9Pq5c3NdH70upi0/l+p4AQkvAz56kp6ebYsa+hWnxgIdC49prrzWzGj/66CO/9fq4ffv21T5Hq3dpHUzfolPiAXgkNCIiIkyNTS1EXPFOXPpY73Zd0/Rw3zR4L0yHB9zO+n4aerpVa3Vq/czBgwebAsRahFfPpgBwP+vQ+OY3v2kqbD/55JOmpkZycrJs2rSpyuAoAHeq0527pk+fbhYA3sPcEwBWCA0AVggNAFYIDQBWCA0AVggNAFYIDQBWCA0AVsIcLRUdRDo1Xme76uQ1t89DWblyZdD3GRsbG9T9TZgwQbwgyB+TRlHbzyY9DQBWCA0AVggNAFYIDQBWCA0AVggNAFYIDQBWCA0AVggNAIENjZycHBk3bpwkJCRIWFiYbNiwwfYlAHgpNPTO41r0WSvHA/Ae6xsLjx492iwAvIkxDQCBL2FggwLQgLtQABpA0woNCkAD7hLwwxMtAK0LAI+Gxvnz5+XgwYPljw8fPiz5+fnStm1b6dy5c0O3D0Coh0ZeXp7ccccdflXklVaSb4zb2wFo4qExYsQIT9wvEUD1uE4DgBVCA4AVQgOAFUIDgBVCA4AVQgOAFUIDgBVCA0DTmnviZVOnTg36PufOnRvU/WnB4GDjyuPGRU8DgBVCA4AVQgOAFUIDgBVCA4AVQgOAFUIDgBVCA4AVQgNA4EIjPT1dBg0aJFFRURIfHy/jx4+XwsJCuz0C8E5oZGdnS1pamuTm5srmzZvl8uXLcvfdd5ui0AC8wWruyaZNm6rMAdAex+7du+X2229v6LYBcNuEtaKiIvNTa57UhFqugLvUeSC0rKxMZs2aJSkpKZKUlHTFcRCdCelbOnXqVNddAgjl0NCxjb1790pmZuYVt6OWK+AudTo8mT59urz55puSk5MjHTt2vOK21HIFPBwaWlntkUcekfXr18v27dula9eugWsZgNAPDT0kWbNmjbz++uvmWo3Tp0+b9TpWERkZGag2AgjVMY3FixebMyZaz7VDhw7ly7p16wLXQgChfXgCwNuYewLACqEBwAqhAcAKoQHACqEBwAqhAcAKoQHACqEBwAoFoF0mOTk5qPuLjY2VYOvSpUvQ94n/oacBwAqhAcAKoQHACqEBwAqhAcAKoQHACqEBwAqhAcAKoQEgsPcI7du3r0RHR5tl6NChsnHjRrs9AvBOaGiNk6efftrUbs3Ly5M777xTvvrVr8r7778fuBYCCN25J+PGjfN7PH/+fNP70CryvXv3bui2AXDThLXS0lJ59dVX5cKFC+YwpSYUgAY8PhBaUFAg11xzjSm3+PDDD5tqa7169apxewpAAx4PjZtvvlny8/PlL3/5i6SmpsqUKVNk3759NW5PAWjA44cnERERcuONN5o/DxgwQHbt2iW//vWvZenSpdVuTwFowF3qfZ1GWVmZGbcA4A1WPQ091Bg9erR07txZSkpKTDForR6flZUVuBYCCN3QOHPmjHz3u9+VDz/80FSK1wu9NDDuuuuuwLUQQOiGxooVKwLXEgAhgbknAKwQGgCsEBoArBAaAKwQGgCsEBoArBAaAKxQy9Vlxo8fH9T96RXBwTZixIig71MnaQZblyZas5aeBgArhAYAK4QGACuEBgBCA0Dg0NMAYIXQAGCF0ABghdAAYIXQABC80NC6rmFhYTJr1qz6vAwAL4SG1jvRWid6c2EA3lGn0Dh//rxMmjRJli9fLnFxcQ3fKgDuCo20tDQZM2aMjBo16qrbaiGl4uJivwWAh6bGZ2Zmyp49e8zhSW1oAeh58+bVpW0AQr2ncfz4cZk5c6a88sor0qpVq1o9hwLQgId7Grt37zZV1vr371++rrS0VHJycmTRokXmUKR58+Z+z6EANODh0Bg5cqQUFBT4rbvvvvukR48eMnv27CqBAcDjoREVFSVJSUl+69q0aSPt2rWrsh6AO3FFKIDg3li4MW4sC6Dx0NMAYIXQAGCF0ABghdAAYIXQAGCF0ABghdAAYCXMcRxHgkinxsfExEhRUZFER0cHc9dwiWAXuVbnzp0L+j63B/kaqNp+NulpALBCaACwQmgAsEJoALBCaACwQmgAsEJoALBCaACwQmgAsEJoAAhcaMydO9cUfK646J3IAXiH9T1Ce/fuLVu2bPnfC4TX+zajAEKI9SdeQ6J9+/aBaQ0A941pHDhwQBISEuSGG24wleOPHTt2xe0pAA14ODSGDBkiK1eulE2bNsnixYvl8OHDctttt0lJSckVC0DrdFvf0qlTp4ZoN4BQvJ+G3mMgMTFRnnvuOfne975XY09Dl4pz9jU4uJ8G6or7aTTu/TTqNYoZGxsrN910kxw8eLDGbSgADbhLva7TOH/+vHzwwQfSoUOHhmsRAPeExmOPPSbZ2dly5MgR2bFjh0yYMMFUir/33nsD10IATYrV4cmJEydMQHzyySdy3XXXybBhwyQ3N9f8GYA3WIVGZmZm4FoCICQw9wSAFUIDgBVCA4AVQgOAFUIDgBVCA4AVQgOAFe6g46ICvo2xz/z8fPHC32tycnLQ99lU0dMAYIXQAGCF0ABghdAAYIXQAGCF0ABghdAAQGgACBx6GgCsEBoAAhsaJ0+elMmTJ0u7du0kMjJS+vTpI3l5ebYvA8ALc0/Onj0rKSkpcscdd8jGjRvNDYW1TGNcXFzgWgggdENj4cKFpjpaRkZG+bquXbsGol0A3HB48sYbb8jAgQNl4sSJEh8fL/369ZPly5df8TkUgAY8HBqHDh0yhZ+7d+8uWVlZkpqaKjNmzJBVq1bV+BwKQAMeDo2ysjLp37+/LFiwwPQypk2bJg8++KAsWbKkxuc88cQTpqCsbzl+/HhDtBtAKISG1mzt1auX37qePXvKsWPHrlgAWitQV1wAeCQ09MxJYWGh37r9+/dLYmJiQ7cLgBtC4wc/+IGp3aqHJwcPHpQ1a9bIsmXLJC0tLXAtBBC6oTFo0CBZv369rF27VpKSkuSpp56S559/XiZNmhS4FgII7RsLjx071iwAvIm5JwCsEBoArBAaAKwQGgCsEBoArBAaAKwQGgCsUAA6gPTCt2ALdkHmLl26SLDNmjUr6PucO3du0PfZVNHTAGCF0ABghdAAYIXQAGCF0ABghdAAYIXQAGCF0ABghdAAELjQ0Kv/wsLCqizcIxTwDqvLyHft2iWlpaXlj/fu3St33XWXqbgGwBusQkMLPlf09NNPS7du3WT48OEN3S4AbhvTuHTpkqxevVruv/9+c4gCwBvqPMt1w4YNcu7cOZk6depVC0Dr4lNcXFzXXQII5Z7GihUrZPTo0ZKQkHDF7SgADbhLnULj6NGjsmXLFnnggQeuui0FoAF3qdPhSUZGhsTHx8uYMWOuuq0WgNYFgEd7GmVlZSY0pkyZIuHh3PgL8Brr0NDDkmPHjpmzJgC8x7qrcPfdd4vjOIFpDYAmj7knAKwQGgCsEBoArBAaAKwQGgCsEBoArBAaAKwE/ZJO3zUeXpjtevny5aDvU6/YDabPP/9cgq3irOlg8cLva/F/3+PVrsMKc4J8pdaJEyekU6dOwdwlAAvHjx+Xjh07Np3Q0G/CU6dOSVRUlNXNezQFNWz0DUVHR4tb8T7dozjEfmc1CkpKSsztLpo1a9Z0Dk+0MVdKsavRv/xQ+AeoL96ne0SH0O9sTEzMVbdhIBSAFUIDgDtDQ2/kM2fOHNff0If36R4tXfo7G/SBUAChLWR6GgCaBkIDgBVCA4AVQgOAO0PjpZdeMlXrW7VqJUOGDJGdO3eKW2hBqUGDBpmrZLU0xPjx46WwsFDcTmsB61XBs2bNErc5efKkTJ48Wdq1ayeRkZHSp08fycvLEzcIidBYt26dPProo+b01Z49e+SWW26Re+65R86cOSNukJ2dLWlpaZKbmyubN282E930Bs4XLlwQt9q1a5csXbpU+vbtK25z9uxZSUlJkRYtWsjGjRtl37598uyzz0pcXJy4ghMCBg8e7KSlpZU/Li0tdRISEpz09HTHjc6cOaOnwZ3s7GzHjUpKSpzu3bs7mzdvdoYPH+7MnDnTcZPZs2c7w4YNc9yqyfc0tDr97t27ZdSoUX7zV/Txe++9J25UVFRkfrZt21bcSHtVWp2v4r+pm7zxxhsycOBAmThxojnc7NevnyxfvlzcosmHxscffyylpaVy/fXX+63Xx6dPnxa30VnAeoyv3dukpCRxm8zMTHOIqeM4bnXo0CFZvHixdO/eXbKysiQ1NVVmzJghq1atEjegrmIT/Bbeu3evvPPOO+I2OkV85syZZtxGB7TdqqyszPQ0FixYYB5rT0P/TZcsWWLKmYa6Jt/TuPbaa6V58+by0Ucf+a3Xx+3btxc3mT59urz55puybdu2et0+oKnSw0wdvO7fv7+pA6yLDgK/8MIL5s/ao3SDDh06SK9evfzW9ezZ05QzdYMmHxoREREyYMAA2bp1q1+S6+OhQ4eKG+j0Hw2M9evXy9tvvy1du3YVNxo5cqQUFBRIfn5++aLfyJMmTTJ/1i8HN0hJSalyynz//v2SmJgoruCEgMzMTKdly5bOypUrnX379jnTpk1zYmNjndOnTztukJqa6sTExDjbt293Pvzww/Ll4sWLjtu58ezJzp07nfDwcGf+/PnOgQMHnFdeecVp3bq1s3r1ascNQiI01Isvvuh07tzZiYiIMKdgc3NzHbfQ7K5uycjIcNzOjaGh/vjHPzpJSUnmy65Hjx7OsmXLHLdgajwAd41pAGhaCA0AVggNAFYIDQBWCA0AVggNAFYIDQBWCA0AVggNAFYIDQBWCA0AVggNAGLj/wAZzdjvXmnBEAAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"sample_index = 45\n",
"plt.figure(figsize=(3, 3))\n",
Expand All @@ -58,7 +69,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -91,7 +102,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -101,18 +112,43 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"one_hot(n_classes=10, y=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"one_hot(n_classes=10, y=[0, 4, 9, 1])"
]
Expand Down Expand Up @@ -143,7 +179,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {
"collapsed": false
},
Expand All @@ -164,9 +200,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[9.99662391e-01 3.35349373e-04 2.25956630e-06]\n"
]
}
],
"source": [
"print(softmax([10, 2, -3]))"
]
Expand All @@ -181,9 +225,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[9.99662391e-01 3.35349373e-04 2.25956630e-06]\n",
" [2.47262316e-03 9.97527377e-01 1.38536042e-11]]\n"
]
}
],
"source": [
"X = np.array([[10, 2, -3],\n",
" [-1, 5, -20]])\n",
Expand All @@ -199,18 +252,36 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0\n"
]
}
],
"source": [
"print(np.sum(softmax([10, 2, -3])))"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"softmax of 2 vectors:\n",
"[[9.99662391e-01 3.35349373e-04 2.25956630e-06]\n",
" [2.47262316e-03 9.97527377e-01 1.38536042e-11]]\n"
]
}
],
"source": [
"print(\"softmax of 2 vectors:\")\n",
"X = np.array([[10, 2, -3],\n",
Expand All @@ -227,9 +298,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 18,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1. 1.]\n"
]
}
],
"source": [
"print(np.sum(softmax(X), axis=1))"
]
Expand All @@ -251,9 +330,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 19,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.01005033585350145\n"
]
}
],
"source": [
"def nll(Y_true, Y_pred):\n",
" Y_true = np.asarray(Y_true)\n",
Expand All @@ -279,9 +366,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 20,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.605170185988091\n"
]
}
],
"source": [
"print(nll([1, 0, 0], [0.01, 0.01, .98]))"
]
Expand Down Expand Up @@ -822,7 +917,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "DSI_participant",
"language": "python",
"name": "python3"
},
Expand All @@ -836,7 +931,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
"version": "3.10.17"
}
},
"nbformat": 4,
Expand Down
Loading