Problem Set 11

First, the exercises:

• What is the maximum depth of a decision tree trained on $N$ samples? The tree must make a proper split at every node, meaning both children are non-empty. Each child therefore has at least one fewer sample than its parent. Starting from $N$ samples at the root and shrinking by at least one sample per level until a single sample remains, the maximum depth is $N-1$.
• If we train a decision tree to an arbitrary depth, what will the training error be? It will be zero, provided no two training samples have identical features but different labels. A tree grown without a depth limit ends with leaves whose samples all share identical features. Under the assumption above, those samples also share the same label, so any of the standard leaf rules (voting, averaging) returns the correct response.
• How can we alter a loss function to help regularize a decision tree? One of the simplest ways is to add an increasing function of the node's depth $|D|$ to the loss, for example $\lambda |D|$ or $\lambda 2^{|D|}$, where $\lambda$ is a hyperparameter (probably very small). The growth of this penalty should be chosen so that it does not dominate the unregularized cost while splits are still improving that cost at the desired rate.
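The depth penalty from the last exercise can be sketched as a split criterion. This is a minimal illustration under stated assumptions, not the lab's method: the helper names `node_mse`, `regularized_split_loss` and `should_split` are hypothetical, the unregularized cost is taken to be the size-weighted MSE of predicting each child's mean, and the penalty is charged at the depth the children would sit at ($|D| + 1$ for a node at depth $|D|$).

```python
import numpy as np

def node_mse(y):
    # Loss of predicting the node's mean label for every sample in the node.
    return float(np.mean((y - np.mean(y)) ** 2))

def regularized_split_loss(y, mask, depth, lam=1e-3, exponential=False):
    # Size-weighted child MSE plus lam * (depth + 1) or lam * 2 ** (depth + 1).
    y = np.asarray(y, dtype=float)
    mask = np.asarray(mask, dtype=bool)
    n, n_pos = len(y), int(mask.sum())
    if n_pos == 0 or n_pos == n:
        return np.inf  # not a proper split: one child would be empty
    child_loss = (n_pos * node_mse(y[mask]) + (n - n_pos) * node_mse(y[~mask])) / n
    child_depth = depth + 1
    penalty = lam * (2 ** child_depth if exponential else child_depth)
    return child_loss + penalty

def should_split(y, mask, depth, lam=1e-3, exponential=False):
    # Split only if the penalized child loss beats the loss of leaving the node whole.
    parent_loss = node_mse(np.asarray(y, dtype=float))
    return regularized_split_loss(y, mask, depth, lam, exponential) < parent_loss

y = [0, 0, 1, 1]
mask = [False, False, True, True]   # a perfect split: child MSE is 0, parent MSE is 0.25
print(should_split(y, mask, depth=0, lam=0.01))                     # cheap near the root
print(should_split(y, mask, depth=10, lam=0.01, exponential=True))  # 0.01 * 2**11 outweighs 0.25
```

With the exponential penalty even a perfect split is refused deep in the tree, which is the intended regularizing effect; the linear penalty grows much more slowly and would only refuse it around depth 25 for this $\lambda$.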

Python Lab

Now let us load our standard libraries.

In [249]:
import numpy as np
import pandas as pd


Let us load the UCI credit card default dataset and split its numerical features into training and test sets.

In [250]:
big_df = pd.read_csv("UCI_Credit_Card.csv")

In [251]:
big_df.head()

Out[251]:
ID LIMIT_BAL SEX EDUCATION MARRIAGE AGE PAY_0 PAY_2 PAY_3 PAY_4 ... BILL_AMT4 BILL_AMT5 BILL_AMT6 PAY_AMT1 PAY_AMT2 PAY_AMT3 PAY_AMT4 PAY_AMT5 PAY_AMT6 default.payment.next.month
0 1 20000.0 2 2 1 24 2 2 -1 -1 ... 0.0 0.0 0.0 0.0 689.0 0.0 0.0 0.0 0.0 1
1 2 120000.0 2 2 2 26 -1 2 0 0 ... 3272.0 3455.0 3261.0 0.0 1000.0 1000.0 1000.0 0.0 2000.0 1
2 3 90000.0 2 2 2 34 0 0 0 0 ... 14331.0 14948.0 15549.0 1518.0 1500.0 1000.0 1000.0 1000.0 5000.0 0
3 4 50000.0 2 2 1 37 0 0 0 0 ... 28314.0 28959.0 29547.0 2000.0 2019.0 1200.0 1100.0 1069.0 1000.0 0
4 5 50000.0 1 2 1 57 -1 0 -1 0 ... 20940.0 19146.0 19131.0 2000.0 36681.0 10000.0 9000.0 689.0 679.0 0

5 rows × 25 columns

In [252]:
len(big_df)

Out[252]:
30000
In [253]:
len(big_df.dropna())

Out[253]:
30000
In [254]:
df = big_df.drop(labels = ['ID'], axis = 1)

In [255]:
labels = df['default.payment.next.month']
df.drop('default.payment.next.month', axis = 1, inplace = True)

In [256]:
num_samples = 25000

In [257]:
train_x, train_y = df[0:num_samples], labels[0:num_samples]

In [258]:
test_x, test_y = df[num_samples:], labels[num_samples:]

In [259]:
test_x.head()

Out[259]:
LIMIT_BAL SEX EDUCATION MARRIAGE AGE PAY_0 PAY_2 PAY_3 PAY_4 PAY_5 ... BILL_AMT3 BILL_AMT4 BILL_AMT5 BILL_AMT6 PAY_AMT1 PAY_AMT2 PAY_AMT3 PAY_AMT4 PAY_AMT5 PAY_AMT6
25000 410000.0 1 1 1 38 -1 -1 -1 -1 -2 ... 35509.0 0.0 0.0 0.0 0.0 35509.0 0.0 0.0 0.0 0.0
25001 260000.0 1 2 2 35 0 0 0 0 0 ... 297313.0 276948.0 2378.0 -2709.0 12325.0 6633.0 6889.0 1025.0 2047.0 194102.0
25002 50000.0 1 2 1 40 0 0 0 0 0 ... 11353.0 12143.0 11753.0 11922.0 1200.0 4000.0 2000.0 2000.0 1000.0 1000.0
25003 360000.0 1 3 1 37 -1 -1 -1 -2 -2 ... 0.0 0.0 0.0 0.0 303.0 0.0 0.0 0.0 0.0 860.0
25004 50000.0 1 3 1 49 0 0 0 0 0 ... 50076.0 48995.0 19780.0 15102.0 2000.0 5000.0 2305.0 3000.0 559.0 3000.0

5 rows × 23 columns

In [260]:
train_y.head()

Out[260]:
0    1
1    1
2    0
3    0
4    0
Name: default.payment.next.month, dtype: int64

Now let us write our transformer, which turns each feature into a set of boolean columns by thresholding it at its quantiles.

In [264]:
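class bin_transformer(object):
    """Binarize every column as x >= Q(q) for each interior quantile level q."""

    def __init__(self, df, num_quantiles = 2):
        # Interior levels 1/k, 2/k, ..., (k-1)/k; one row of thresholds per level.
        self.quantiles = df.quantile(np.linspace(1./num_quantiles, 1.-1./num_quantiles, num_quantiles-1))

    def transform(self, df):
        new = pd.DataFrame(index = df.index)
        fns = {}
        for col_name in df.columns:
            for ix, q in self.quantiles.iterrows():
                quart = q[col_name]
                new[col_name+str(ix)] = (df[col_name] >= quart)
                # Default arguments freeze the current column and threshold in each lambda;
                # a bare closure would see only the values from the last loop iteration.
                fns[col_name+str(ix)] = (col_name, lambda x, c = col_name, t = quart: x[c] >= t)
        return new, fns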
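The `fns` dictionary stores one lambda per indicator column. Python closures capture variables rather than values, so a lambda written as `lambda x: x[col_name]>=quart` inside the loops reads whatever `col_name` and `quart` hold when it is eventually called, which, once `transform` has returned, is the last column and the last threshold. A minimal, self-contained illustration of the pitfall and the usual default-argument fix (the names `late` and `bound` are just for illustration):

```python
# Late binding: each lambda looks up t when it is called, after the loop has finished.
late = [lambda x: x >= t for t in [1, 2, 3]]

# Default arguments are evaluated when each lambda is created, freezing the current t.
bound = [lambda x, t=t: x >= t for t in [1, 2, 3]]

print([f(2) for f in late])   # [False, False, False]: every lambda compares against 3
print([f(2) for f in bound])  # [True, True, False]: thresholds 1, 2 and 3 respectively
```

An equivalent alternative is a small factory function that returns the lambda, since each call creates a fresh scope.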

In [265]:
transformer = bin_transformer(df,5)

In [266]:
train_x_t, tr_fns = transformer.transform(train_x)

In [267]:
test_x_t, test_fns = transformer.transform(test_x)

In [268]:
train_x_t.head()

Out[268]:
LIMIT_BAL0.2 LIMIT_BAL0.4 LIMIT_BAL0.6 LIMIT_BAL0.8 SEX0.2 SEX0.4 SEX0.6 SEX0.8 EDUCATION0.2 EDUCATION0.4 ... PAY_AMT40.6 PAY_AMT40.8 PAY_AMT50.2 PAY_AMT50.4 PAY_AMT50.6 PAY_AMT50.8 PAY_AMT60.2 PAY_AMT60.4 PAY_AMT60.6 PAY_AMT60.8
0 False False False False True True True True True True ... False False True False False False True False False False
1 True True False False True True True True True True ... False False True False False False True True False False
2 True False False False True True True True True True ... False False True True False False True True True True
3 True False False False True True True True True True ... False False True True False False True True False False
4 True False False False True False False False True True ... True True True False False False True False False False

5 rows × 92 columns
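To see what `transform` produces, here is a minimal sketch of the same quantile thresholding on a hypothetical five-row toy frame (`toy`, `levels` and `indicators` are illustration names, not part of the lab). With five bins the interior levels are 0.2, 0.4, 0.6 and 0.8. Column `b` has many repeated values, so its 0.2 and 0.4 thresholds equal its minimum and those indicators are True for every row. A constant indicator cannot split anything, and heavily repeated columns in the real data, such as SEX, can end up with constant indicator columns in the same way.

```python
import numpy as np
import pandas as pd

toy = pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [1, 1, 1, 2, 2]})
levels = np.linspace(0.2, 0.8, 4)   # interior quantile levels for 5 bins
thresholds = toy.quantile(levels)   # rows: levels, columns: features (linear interpolation)

indicators = {}
for col in toy.columns:
    for level, t in zip(levels, thresholds[col].to_numpy()):
        # One boolean column per (feature, level): is the value at or above the quantile?
        indicators[f"{col}{level:.1f}"] = toy[col] >= t
ind = pd.DataFrame(indicators)

print(ind.sum())  # number of True values in each indicator column
```

Formatting the level with `:.1f` keeps the column names short even when `linspace` returns values such as 0.6000000000000001.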

In [269]:
tr_fns

Out[269]:
{'AGE0.2': ('AGE',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'AGE0.4': ('AGE',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'AGE0.6': ('AGE',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'AGE0.8': ('AGE',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT10.2': ('BILL_AMT1',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT10.4': ('BILL_AMT1',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT10.6': ('BILL_AMT1',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT10.8': ('BILL_AMT1',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT20.2': ('BILL_AMT2',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT20.4': ('BILL_AMT2',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT20.6': ('BILL_AMT2',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT20.8': ('BILL_AMT2',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT30.2': ('BILL_AMT3',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT30.4': ('BILL_AMT3',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT30.6': ('BILL_AMT3',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT30.8': ('BILL_AMT3',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT40.2': ('BILL_AMT4',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT40.4': ('BILL_AMT4',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT40.6': ('BILL_AMT4',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT40.8': ('BILL_AMT4',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT50.2': ('BILL_AMT5',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT50.4': ('BILL_AMT5',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT50.6': ('BILL_AMT5',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT50.8': ('BILL_AMT5',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT60.2': ('BILL_AMT6',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT60.4': ('BILL_AMT6',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT60.6': ('BILL_AMT6',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'BILL_AMT60.8': ('BILL_AMT6',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'EDUCATION0.2': ('EDUCATION',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'EDUCATION0.4': ('EDUCATION',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'EDUCATION0.6': ('EDUCATION',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'EDUCATION0.8': ('EDUCATION',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'LIMIT_BAL0.2': ('LIMIT_BAL',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'LIMIT_BAL0.4': ('LIMIT_BAL',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'LIMIT_BAL0.6': ('LIMIT_BAL',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'LIMIT_BAL0.8': ('LIMIT_BAL',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'MARRIAGE0.2': ('MARRIAGE',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'MARRIAGE0.4': ('MARRIAGE',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'MARRIAGE0.6': ('MARRIAGE',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'MARRIAGE0.8': ('MARRIAGE',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_00.2': ('PAY_0',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_00.4': ('PAY_0',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_00.6': ('PAY_0',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_00.8': ('PAY_0',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_20.2': ('PAY_2',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_20.4': ('PAY_2',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_20.6': ('PAY_2',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_20.8': ('PAY_2',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_30.2': ('PAY_3',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_30.4': ('PAY_3',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_30.6': ('PAY_3',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_30.8': ('PAY_3',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_40.2': ('PAY_4',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_40.4': ('PAY_4',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_40.6': ('PAY_4',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_40.8': ('PAY_4',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_50.2': ('PAY_5',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_50.4': ('PAY_5',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_50.6': ('PAY_5',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_50.8': ('PAY_5',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_60.2': ('PAY_6',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_60.4': ('PAY_6',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_60.6': ('PAY_6',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_60.8': ('PAY_6',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT10.2': ('PAY_AMT1',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT10.4': ('PAY_AMT1',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT10.6': ('PAY_AMT1',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT10.8': ('PAY_AMT1',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT20.2': ('PAY_AMT2',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT20.4': ('PAY_AMT2',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT20.6': ('PAY_AMT2',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT20.8': ('PAY_AMT2',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT30.2': ('PAY_AMT3',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT30.4': ('PAY_AMT3',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT30.6': ('PAY_AMT3',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT30.8': ('PAY_AMT3',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT40.2': ('PAY_AMT4',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT40.4': ('PAY_AMT4',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT40.6': ('PAY_AMT4',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT40.8': ('PAY_AMT4',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT50.2': ('PAY_AMT5',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT50.4': ('PAY_AMT5',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT50.6': ('PAY_AMT5',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT50.8': ('PAY_AMT5',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT60.2': ('PAY_AMT6',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT60.4': ('PAY_AMT6',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT60.6': ('PAY_AMT6',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'PAY_AMT60.8': ('PAY_AMT6',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'SEX0.2': ('SEX',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'SEX0.4': ('SEX',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'SEX0.6': ('SEX',
<function __main__.bin_transformer.transform.<locals>.<lambda>>),
'SEX0.8': ('SEX',
<function __main__.bin_transformer.transform.<locals>.<lambda>>)}

Now let us build some simple loss functions for 1-d binary labels. Each takes a predicted probability and an array of 0/1 labels.

In [270]:
def bdd_cross_entropy(pred, label):
    # Only the positive-class term -y*log(p); the 1e-20 offset keeps log away from log(0).
    return -np.mean(label*np.log(pred+10**(-20)))

In [271]:
def MSE(pred,label):
    return np.mean((pred-label)**2)

In [272]:
def acc(pred,label):
    # Accuracy of thresholding pred at 0.5; higher is better, unlike the two losses above.
    return np.mean((pred>=0.5)==(label == 1))
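`bdd_cross_entropy` keeps only the positive-class term $-y\log p$. For comparison, here is a sketch of the full binary cross-entropy (the name `binary_cross_entropy` and the `eps` clipping are assumptions, not part of the lab), evaluated on a toy label vector with the constant prediction $p = \bar y$ that a node predicts. At that prediction MSE reduces to $p(1-p)$ and the full cross-entropy to the binary entropy of $p$.

```python
import numpy as np

def binary_cross_entropy(pred, label, eps=1e-12):
    # Positives contribute -log(p) and negatives -log(1 - p); clipping avoids log(0).
    pred = np.clip(pred, eps, 1 - eps)
    return -np.mean(label * np.log(pred) + (1 - label) * np.log(1 - pred))

label = np.array([0, 0, 0, 1])
p = label.mean()                        # constant prediction: 0.25

print(np.mean((p - label) ** 2))        # MSE at the mean: p * (1 - p) = 0.1875
print(binary_cross_entropy(p, label))   # about 0.5623 nats, the entropy of Bernoulli(0.25)
print(-np.mean(label * np.log(p)))      # positive-class term alone: about 0.3466
```

Both MSE and the full cross-entropy, evaluated at the node mean, are concave in $p$, which is why splitting a node at its mean never increases them.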


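Now let us define `find_split`, which tries each boolean column as a split, predicts the mean label on each side, and returns the column with the lowest size-weighted loss (or `None` if no column beats the unsplit node).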

In [273]:
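def find_split(x, y, loss, verbose = False):
    # Greedily pick the boolean column whose split most reduces the size-weighted
    # loss of predicting each child's mean label.
    min_ax = None
    y = np.asarray(y)
    base_loss = loss(np.mean(y), y)
    min_loss = base_loss
    N = len(x)
    for col_name in x.columns:
        mask = np.asarray(x[col_name], dtype = bool)
        num_pos = mask.sum()
        num_neg = N - num_pos
        if num_pos == 0 or num_neg == 0:
            l = np.nan  # degenerate split: one side is empty; nan < min_loss is False
        else:
            l = (num_pos * loss(np.mean(y[mask]), y[mask])
                 + num_neg * loss(np.mean(y[~mask]), y[~mask])) / N
        if verbose:
            print("Column {0} split has improved loss {1}".format(col_name, base_loss-l))
        if l < min_loss:
            min_loss = l
            min_ax = col_name
    return min_ax, min_loss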
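For MSE with 0/1 labels the split loss has a closed form, which allows every column to be scored in one vectorized pass instead of a Python loop. Within a group of $n$ samples containing $k$ positives, the summed squared error around the group mean is $k - k^2/n$, because $\sum y^2 = \sum y = k$ for binary labels. The sketch below (the function name `mse_split_losses` is hypothetical) assumes a split is scored by the size-weighted loss of predicting each child's mean, marks degenerate columns as `nan`, and applies to MSE only.

```python
import numpy as np

def mse_split_losses(X, y):
    # Size-weighted MSE of every single-column split at once, for 0/1 labels.
    X = np.asarray(X, dtype=bool)
    y = np.asarray(y, dtype=float)
    N = len(y)
    n_pos = X.sum(axis=0).astype(float)   # samples on the True side of each column
    n_neg = N - n_pos
    k_pos = X.T.astype(float) @ y         # positives on the True side
    k_neg = y.sum() - k_pos
    with np.errstate(divide="ignore", invalid="ignore"):
        sse = (k_pos - k_pos ** 2 / n_pos) + (k_neg - k_neg ** 2 / n_neg)
    losses = sse / N
    losses[(n_pos == 0) | (n_neg == 0)] = np.nan   # degenerate splits
    return losses

X = np.array([[1, 1, 1], [1, 0, 1], [0, 1, 1], [0, 0, 1]], dtype=bool)
y = np.array([1, 1, 0, 0])
losses = mse_split_losses(X, y)
# Column 0 separates the labels perfectly (loss 0.0), column 1 is uninformative
# (0.25, the same as the unsplit node), and column 2 is constant (nan).
best = int(np.nanargmin(losses))
```

On the lab data, something like `mse_split_losses(train_x_t.values, train_y.values)` should give, column by column, the same split losses that `find_split(..., MSE)` computes, up to floating-point rounding.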


In [278]:
find_split(train_x_t, train_y, MSE, verbose = True)

Column LIMIT_BAL0.2 split has improved loss 0.0032026111833937665
Column LIMIT_BAL0.4 split has improved loss 0.0036568972936314725
Column LIMIT_BAL0.6 split has improved loss 0.002968295613244798
Column LIMIT_BAL0.8 split has improved loss 0.0017932272689534512
Column SEX0.2 split has improved loss nan
Column SEX0.4 split has improved loss 0.0002155159725325817
Column SEX0.6 split has improved loss 0.0002155159725325817
Column SEX0.8 split has improved loss 0.0002155159725325817
Column EDUCATION0.2 split has improved loss 2.3907091916103296e-05
Column EDUCATION0.4 split has improved loss 0.0004640208803457502
Column EDUCATION0.6 split has improved loss 0.0004640208803457502
Column EDUCATION0.8 split has improved loss 0.0004640208803457502
Column MARRIAGE0.2 split has improved loss 3.249086770407139e-05
Column MARRIAGE0.4 split has improved loss 3.249086770407139e-05
Column MARRIAGE0.6 split has improved loss 0.00014480802024710582
Column MARRIAGE0.8 split has improved loss 0.00014480802024710582
Column AGE0.2 split has improved loss 0.00020827243409601848
Column AGE0.4 split has improved loss 9.960157279997883e-06
Column AGE0.6 split has improved loss 4.117947427031976e-05
Column AGE0.8 split has improved loss 8.124558376762514e-05
Column PAY_00.2 split has improved loss 0.0008308647297025351
Column PAY_00.4 split has improved loss 0.0015859674093304799
Column PAY_00.6 split has improved loss 0.0015859674093304799
Column PAY_00.8 split has improved loss 0.0233441901232099
Column PAY_20.2 split has improved loss 0.0001876302154459608
Column PAY_20.4 split has improved loss 0.0013149408868615708
Column PAY_20.6 split has improved loss 0.0013149408868615708
Column PAY_20.8 split has improved loss 0.0013149408868615708
Column PAY_30.2 split has improved loss 0.00021161990136722708
Column PAY_30.4 split has improved loss 0.0014198151850400298
Column PAY_30.6 split has improved loss 0.0014198151850400298
Column PAY_30.8 split has improved loss 0.0014198151850400298
Column PAY_40.2 split has improved loss 0.00014702668290195176
Column PAY_40.4 split has improved loss 0.00110110252994014
Column PAY_40.6 split has improved loss 0.00110110252994014
Column PAY_40.8 split has improved loss 0.00110110252994014
Column PAY_50.2 split has improved loss 0.00011619585491220996
Column PAY_50.4 split has improved loss 0.0010161714350968298
Column PAY_50.6 split has improved loss 0.0010161714350968298
Column PAY_50.8 split has improved loss 0.0010161714350968298
Column PAY_60.2 split has improved loss 8.712841444399877e-05
Column PAY_60.4 split has improved loss 0.0008042740411955129
Column PAY_60.6 split has improved loss 0.0008042740411955129
Column PAY_60.8 split has improved loss 0.0008042740411955129
Column BILL_AMT10.2 split has improved loss 2.7784178714379548e-05
Column BILL_AMT10.4 split has improved loss 2.7479974147037733e-06
Column BILL_AMT10.6 split has improved loss 0.00016119710076237248
Column BILL_AMT10.8 split has improved loss 0.00014701547170556384
Column BILL_AMT20.2 split has improved loss 3.084283179743963e-06
Column BILL_AMT20.4 split has improved loss 8.943772085073798e-06
Column BILL_AMT20.6 split has improved loss 9.726534850079682e-05
Column BILL_AMT20.8 split has improved loss 0.00010606002405513792
Column BILL_AMT30.2 split has improved loss 1.2923238124601388e-05
Column BILL_AMT30.4 split has improved loss 1.7703438644989244e-05
Column BILL_AMT30.6 split has improved loss 5.7255048681065235e-05
Column BILL_AMT30.8 split has improved loss 0.0001309855396841031
Column BILL_AMT40.2 split has improved loss 1.5757691495227322e-05
Column BILL_AMT40.4 split has improved loss 4.552825379744441e-05
Column BILL_AMT40.6 split has improved loss 1.5198874939420515e-05
Column BILL_AMT40.8 split has improved loss 5.0114289508379484e-05
Column BILL_AMT50.2 split has improved loss 2.5681591047099772e-05
Column BILL_AMT50.4 split has improved loss 4.704167867383702e-05
Column BILL_AMT50.6 split has improved loss 7.104706449934106e-06
Column BILL_AMT50.8 split has improved loss 4.3105571633450523e-05
Column BILL_AMT60.2 split has improved loss 1.1070212941444169e-05
Column BILL_AMT60.4 split has improved loss 7.89154272805015e-05
Column BILL_AMT60.6 split has improved loss 1.1814881573091185e-08
Column BILL_AMT60.8 split has improved loss 2.5239692683282078e-05
Column PAY_AMT10.2 split has improved loss 0.0037338214207645604
Column PAY_AMT10.4 split has improved loss 0.0025958239617459855
Column PAY_AMT10.6 split has improved loss 0.0026823429188202186
Column PAY_AMT10.8 split has improved loss 0.0018958312033806879
Column PAY_AMT20.2 split has improved loss 0.0026740961203186864
Column PAY_AMT20.4 split has improved loss 0.002413157179305403
Column PAY_AMT20.6 split has improved loss 0.0022468726365486025
Column PAY_AMT20.8 split has improved loss 0.001975166038548831
Column PAY_AMT30.2 split has improved loss 0.002428229767294232
Column PAY_AMT30.4 split has improved loss 0.0017518742462117598
Column PAY_AMT30.6 split has improved loss 0.0019701352550884066
Column PAY_AMT30.8 split has improved loss 0.001262546288527816
Column PAY_AMT40.2 split has improved loss nan
Column PAY_AMT40.4 split has improved loss 0.0016828468720352097
Column PAY_AMT40.6 split has improved loss 0.001992080354384307
Column PAY_AMT40.8 split has improved loss 0.001509125693650304
Column PAY_AMT50.2 split has improved loss nan
Column PAY_AMT50.4 split has improved loss 0.0013987996480574194
Column PAY_AMT50.6 split has improved loss 0.0015722436951876306
Column PAY_AMT50.8 split has improved loss 0.0013770421197525917
Column PAY_AMT60.2 split has improved loss nan
Column PAY_AMT60.4 split has improved loss 0.0016595462220761747
Column PAY_AMT60.6 split has improved loss 0.0019163958366781864
Column PAY_AMT60.8 split has improved loss 0.0017002531220795536

Out[278]:
('PAY_00.8', 0.14999327547679009)
In [279]:
find_split(train_x_t, train_y, bdd_cross_entropy, verbose = 0)

Out[279]:
('PAY_00.8', 0.29117246034455752)
In [280]:
find_split(train_x_t, train_y, acc, verbose = 0)

Out[280]:
(None, 0.77688000000000001)
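`acc` is a score to maximize, while `find_split` keeps the column with the smallest value, so passing `acc` asks for the split with the worst accuracy. When each child predicts its own majority label, splitting can never lower training accuracy below the majority-class baseline (here about 0.777), so no column qualifies and `None` comes back. Switching to the error rate $1 - \text{acc}$ fixes the direction, but a split only lowers the error if it changes the majority label in some child; both sides of the PAY_00.8 and AGE0.2 splits above, for example, have mean label below 0.5. A minimal toy sketch (the names `error_rate`, `mse` and `split_loss` are hypothetical) showing that MSE rewards such a split while the error rate does not move:

```python
import numpy as np

def error_rate(pred, label):
    # 1 - accuracy: lower is better, so it can be minimized like the other losses.
    return 1.0 - np.mean((pred >= 0.5) == (label == 1))

def mse(pred, label):
    return np.mean((pred - label) ** 2)

def split_loss(loss, y, mask):
    # Size-weighted loss when each side of the split predicts its own mean label.
    left, right = y[mask], y[~mask]
    return (len(left) * loss(left.mean(), left) + len(right) * loss(right.mean(), right)) / len(y)

y = np.array([1, 1, 0, 0, 0, 0, 0, 0, 0, 0])
mask = np.arange(10) < 5   # left child [1, 1, 0, 0, 0], right child all zeros

print(error_rate(y.mean(), y), split_loss(error_rate, y, mask))  # both about 0.2: no change
print(mse(y.mean(), y), split_loss(mse, y, mask))                # about 0.16, then about 0.12
```

This is one reason smooth losses such as MSE or cross-entropy are usually preferred for growing trees even when accuracy is the final metric.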
In [281]:
np.mean(train_y[train_x_t['PAY_00.8']])

Out[281]:
0.19974715549936789
In [283]:
np.mean(train_y[~train_x_t['PAY_00.8']])

Out[283]:
0.13956871914783062
In [284]:
np.mean(train_y[train_x_t['AGE0.2']])

Out[284]:
0.21648985128130602
In [285]:
np.mean(train_y[~train_x_t['AGE0.2']])

Out[285]:
0.25453293550608219
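The cells above compute the mean label on each side of a split with two separate boolean-indexing calls. A single `groupby` on the indicator column returns both child means at once; for instance `train_y.groupby(train_x_t['AGE0.2']).mean()` would give the False and True rates together. A self-contained toy version (the series `y` and `flag` are illustration data):

```python
import pandas as pd

y = pd.Series([1, 0, 0, 1, 1, 0])
flag = pd.Series([True, True, False, False, True, False])

# Group the labels by the indicator; each group's mean is that child's prediction.
rates = y.groupby(flag).mean().to_dict()
print(rates)  # True side: rows 0, 1, 4 -> 2/3; False side: rows 2, 3, 5 -> 1/3
```

Because `groupby` aligns on the index, this relies on the labels and the indicator sharing the same index, which holds here since `train_x_t` is built with the index of `train_x`.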