{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Getting Started\n", "\n", "This notebook will show you how to build a complex pipeline using aikit and how to cross-validate it." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pclassnamesexagesibspparchticketfarecabinembarkedboatbodyhome_dest
01McCarthy, Mr. Timothy Jmale54.0001746351.8625E46SNaN175.0Dorchester, MA
11Fortune, Mr. Markmale64.01419950263.0000C23 C25 C27SNaNNaNWinnipeg, MB
21Sagesser, Mlle. Emmafemale24.000PC 1747769.3000B35C9NaNNaN
33Panula, Master. Urho Abrahammale2.041310129539.6875NaNSNaNNaNNaN
41Maioni, Miss. Robertafemale16.00011015286.5000B79S8NaNNaN
53Waelens, Mr. Achillemale22.0003457679.0000NaNSNaNNaNAntwerp, Belgium / Stanton, OH
63Reed, Mr. James GeorgemaleNaN003623167.2500NaNSNaNNaNNaN
71Swift, Mrs. Frederick Joel (Margaret Welles Ba...female48.0001746625.9292D17S8NaNBrooklyn, NY
81Smith, Mrs. Lucien Philip (Mary Eloise Hughes)female18.0101369560.0000C31S6NaNHuntington, WV
91Rowe, Mr. Alfred Gmale33.00011379026.5500NaNSNaN109.0London
\n", "
" ], "text/plain": [ " pclass name sex age \\\n", "0 1 McCarthy, Mr. Timothy J male 54.0 \n", "1 1 Fortune, Mr. Mark male 64.0 \n", "2 1 Sagesser, Mlle. Emma female 24.0 \n", "3 3 Panula, Master. Urho Abraham male 2.0 \n", "4 1 Maioni, Miss. Roberta female 16.0 \n", "5 3 Waelens, Mr. Achille male 22.0 \n", "6 3 Reed, Mr. James George male NaN \n", "7 1 Swift, Mrs. Frederick Joel (Margaret Welles Ba... female 48.0 \n", "8 1 Smith, Mrs. Lucien Philip (Mary Eloise Hughes) female 18.0 \n", "9 1 Rowe, Mr. Alfred G male 33.0 \n", "\n", " sibsp parch ticket fare cabin embarked boat body \\\n", "0 0 0 17463 51.8625 E46 S NaN 175.0 \n", "1 1 4 19950 263.0000 C23 C25 C27 S NaN NaN \n", "2 0 0 PC 17477 69.3000 B35 C 9 NaN \n", "3 4 1 3101295 39.6875 NaN S NaN NaN \n", "4 0 0 110152 86.5000 B79 S 8 NaN \n", "5 0 0 345767 9.0000 NaN S NaN NaN \n", "6 0 0 362316 7.2500 NaN S NaN NaN \n", "7 0 0 17466 25.9292 D17 S 8 NaN \n", "8 1 0 13695 60.0000 C31 S 6 NaN \n", "9 0 0 113790 26.5500 NaN S NaN 109.0 \n", "\n", " home_dest \n", "0 Dorchester, MA \n", "1 Winnipeg, MB \n", "2 NaN \n", "3 NaN \n", "4 NaN \n", "5 Antwerp, Belgium / Stanton, OH \n", "6 NaN \n", "7 Brooklyn, NY \n", "8 Huntington, WV \n", "9 London " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from aikit.datasets.datasets import load_dataset, DatasetEnum\n", "Xtrain, y_train, _ ,_ , _ = load_dataset(DatasetEnum.titanic)\n", "Xtrain.head(10)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0, 0, 1, 0, 1, 0, 0, 1, 1, 0], dtype=int64)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train[0:10]" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\r\n", "\r\n", "\r\n", "\r\n", "\r\n", "\r\n", "%3\r\n", "\r\n", "\r\n", "enc\r\n", "\r\n", "enc\r\n", "\r\n", "\r\n", "imp\r\n", "\r\n", "imp\r\n", "\r\n", 
"\r\n", "enc->imp\r\n", "\r\n", "\r\n", "\r\n", "\r\n", "rf\r\n", "\r\n", "rf\r\n", "\r\n", "\r\n", "imp->rf\r\n", "\r\n", "\r\n", "\r\n", "\r\n", "sel\r\n", "\r\n", "sel\r\n", "\r\n", "\r\n", "sel->enc\r\n", "\r\n", "\r\n", "\r\n", "\r\n", "vect\r\n", "\r\n", "vect\r\n", "\r\n", "\r\n", "vect->rf\r\n", "\r\n", "\r\n", "\r\n", "\r\n", "\r\n" ], "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from aikit.pipeline import GraphPipeline\n", "from aikit.transformers import ColumnsSelector, NumericalEncoder, NumImputer, CountVectorizerWrapper\n", "from sklearn.ensemble import RandomForestClassifier\n", "\n", "text_cols = [\"name\",\"ticket\"]\n", "non_text_cols = [c for c in Xtrain.columns if c not in text_cols]\n", "\n", "gpipeline = GraphPipeline(models = {\n", " \"sel\":ColumnsSelector(columns_to_use=non_text_cols),\n", " \"enc\":NumericalEncoder(columns_to_use=\"object\"),\n", " \"imp\":NumImputer(),\n", " \"vect\":CountVectorizerWrapper(analyzer=\"word\",columns_to_use=text_cols),\n", " \"rf\":RandomForestClassifier(n_estimators=100, random_state=123)\n", " },\n", " edges = [(\"sel\",\"enc\",\"imp\",\"rf\"),(\"vect\",\"rf\")])\n", "\n", "gpipeline.fit(Xtrain,y_train)\n", "gpipeline.graphviz" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "cv 0 started\n", "\n", "cv 1 started\n", "\n", "cv 2 started\n", "\n", "cv 3 started\n", "\n", "cv 4 started\n", "\n", "cv 5 started\n", "\n", "cv 6 started\n", "\n", "cv 7 started\n", "\n", "cv 8 started\n", "\n", "cv 9 started\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Done 10 out of 10 | elapsed: 13.3s finished\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
test_accuracytest_roc_auctest_neg_log_losstrain_accuracytrain_roc_auctrain_neg_log_lossfit_timescore_timen_test_samplesfold_nb
00.9809520.998095-0.1168521.01.0-0.0439280.8412720.1336421050
10.9619050.969512-0.4847641.01.0-0.0420550.9035840.1266631051
20.9619050.994474-0.1392151.01.0-0.0418640.7477450.1177021052
30.9904761.000000-0.1025791.01.0-0.0423160.7354970.1206381053
40.9523810.994284-0.1300341.01.0-0.0440320.7725310.1192711054
50.9619050.996570-0.1341161.01.0-0.0414990.7255580.1266951055
60.9714290.998476-0.1406611.01.0-0.0472360.7610990.1166981056
70.9619050.995617-0.1553531.01.0-0.0409470.7345430.1112881057
80.9615380.994386-0.1326301.01.0-0.0413350.7404710.1117821048
90.9807690.996903-0.1502821.01.0-0.0448300.7491130.1127771049
\n", "
" ], "text/plain": [ " test_accuracy test_roc_auc test_neg_log_loss train_accuracy \\\n", "0 0.980952 0.998095 -0.116852 1.0 \n", "1 0.961905 0.969512 -0.484764 1.0 \n", "2 0.961905 0.994474 -0.139215 1.0 \n", "3 0.990476 1.000000 -0.102579 1.0 \n", "4 0.952381 0.994284 -0.130034 1.0 \n", "5 0.961905 0.996570 -0.134116 1.0 \n", "6 0.971429 0.998476 -0.140661 1.0 \n", "7 0.961905 0.995617 -0.155353 1.0 \n", "8 0.961538 0.994386 -0.132630 1.0 \n", "9 0.980769 0.996903 -0.150282 1.0 \n", "\n", " train_roc_auc train_neg_log_loss fit_time score_time n_test_samples \\\n", "0 1.0 -0.043928 0.841272 0.133642 105 \n", "1 1.0 -0.042055 0.903584 0.126663 105 \n", "2 1.0 -0.041864 0.747745 0.117702 105 \n", "3 1.0 -0.042316 0.735497 0.120638 105 \n", "4 1.0 -0.044032 0.772531 0.119271 105 \n", "5 1.0 -0.041499 0.725558 0.126695 105 \n", "6 1.0 -0.047236 0.761099 0.116698 105 \n", "7 1.0 -0.040947 0.734543 0.111288 105 \n", "8 1.0 -0.041335 0.740471 0.111782 104 \n", "9 1.0 -0.044830 0.749113 0.112777 104 \n", "\n", " fold_nb \n", "0 0 \n", "1 1 \n", "2 2 \n", "3 3 \n", "4 4 \n", "5 5 \n", "6 6 \n", "7 7 \n", "8 8 \n", "9 9 " ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from aikit.cross_validation import cross_validation\n", "from sklearn.model_selection import StratifiedKFold\n", "\n", "cv = StratifiedKFold(10, shuffle=True, random_state=123)\n", "\n", "cv_res, yhat_proba = cross_validation(gpipeline, Xtrain, y_train,cv=cv, scoring=[\"accuracy\", \"roc_auc\", \"neg_log_loss\"], return_predict=True, method=\"predict_proba\")\n", "\n", "cv_res" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
01
01.000.00
10.880.12
20.050.95
30.930.07
40.070.93
51.000.00
61.000.00
70.030.97
80.060.94
90.980.02
\n", "
" ], "text/plain": [ " 0 1\n", "0 1.00 0.00\n", "1 0.88 0.12\n", "2 0.05 0.95\n", "3 0.93 0.07\n", "4 0.07 0.93\n", "5 1.00 0.00\n", "6 1.00 0.00\n", "7 0.03 0.97\n", "8 0.06 0.94\n", "9 0.98 0.02" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "yhat_proba.head(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using cross_validation you get in one call :\n", "* both train and test score\n", "* all the metrics\n", "* the probabilities predicted for each observation" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }