본문 바로가기

대외활동/ABC 지역주도형 청년 취업역량강화 ESG 지원산업

[ABC 220923] 지도 학습 알고리즘 - 결정 트리

반응형

결정 트리

  • 결정 트리 decision tree는 분류와 회귀 문제에 널리 사용하는 모델
    • 기본적으로 결정 트리는 결정에 다다르기 위해 예/아니오 질문을 이어 나가면서 학습
    • 모델을 직접 만드는 대신 지도 학습 방식으로 데이터로부터 학습
      • 맨 위 노드 : 루트 노드
      • 특히, 마지막 노드는 리프 leaf 라고도 함
      • 트리의 노드는 질문이나 정답을 담는 네모 상자
      • 엣지 edge는 질문의 답과 다음 질문을 연결

결정 트리 - 결정 트리의 복잡도 제어하기

  • 결정 트리의 복잡도 제어하기
    • 일반적으로 트리 만들기를 모든 리프 노드가 순수 노드가 될 때까지 진행하면 모델이 매우 복잡해지고 훈련 데이터에 과대 적합 됨 → 순수 노드로 이루어진 트리는 훈련 세트에 100% 정확하게 맞는다는 의미
  • 과대 적합을 막는 전략은 크게 두가지
    • 트리 생성은 일찍 중단하는 전략 (사전 가지치기 pre-pruning)
    • 트리를 만든 후 데이터 포인트가 적은 노드를 삭제하거나 병합하는 전략 (사후 가지치기 post-pruning) 또는 그냥 가지치기 (pruning)
    • scikit-learn에서 결정 트리는 DecisionTreeRefressor 와 DecisionTreeClassifier에 구현
    • scikit-learn은 사전 가지치기만 지원
  • 결정 트리 분석
    • 트리 모듈의 export_graphviz함수를 이용해 트리를 시각화
    • 트리를 시각화하면 알고리즘의 예측이 어떻게 이뤄지는지 잘 이해할 수 있으면 비전문가에게 머신러닝 알고리즘을 설명하기에 좋음
    • 단점 사전 가지치기를 해도 과대 적합 경향이 있음

결정 트리 - DecisionTreeClassifier 예제

  • DecisionTreeClassifier 사용하여 Breast Cancer 데이터셋을 사용하여 유방암(1), 정상(0) 예측 및 성능 평가
    • Breast Cancer 데이터셋 (출처 : UCI ML Repository)
    • Breast Cancer 데이터 전처리 (데이터 타입 변환 및 결측치 제거, 설명 변수 데이터를 정규화)
    • 데이터셋 분리 (훈련 셋, 테스트 셋)
      • → Decision Tree 분류 모형 - sklearn 사용
      • → 속성으로는 criterion = 'entropy',max_depth=5 적용
    • max_depth 설정 값을 변경하며, 최적의 모델 찾기

DecisionTreeClassifier 유방암 예측

문제 정의 : DecisionTreeClassifier를 사용하여 유방암 양성(2), 악성(4) 예측

기본 라이브러리 임포트

from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn import metrics

import pandas as pd
import numpy as np

한글 깨짐 방지

import matplotlib as mpl
import matplotlib.pyplot as plt
 
%config InlineBackend.figure_format = 'retina'
 
!apt -qq -y install fonts-nanum
 
import matplotlib.font_manager as fm
fontpath = '/usr/share/fonts/truetype/nanum/NanumBarunGothic.ttf'
font = fm.FontProperties(fname=fontpath, size=9)
plt.rc('font', family='NanumBarunGothic') 
mpl.font_manager._rebuild()

데이터 준비하기

# 외부에 있는 파일을 링크로 가져오기
# UCI ML Repository 제공하는 Breast Cancer 데이터셋 가져오기
uci_path = 'https://archive.ics.uci.edu/ml/machine-learning-databases/\
breast-cancer-wisconsin/breast-cancer-wisconsin.data'

df = pd.read_csv(uci_path, header=None)

# 열(컬럼) 이름을 지정하자
df.columns = ['id', 'clump', 'cell_size', 'cell_shape', 'adhesion', 'epithlial',
              'bare_nuclei', 'chromatin', 'normal_nucleoli', 'mitoses', 'class']

df.head()

데이터 확인하기

df.info()

# bare_nuclei 컬럼의 자료형이 object -> int
# 1) unique()함수를 이용해서 bare_nuclei 컬럼의 고유값을 확인
df['bare_nuclei'].unique()

# 2) '?' -> np.nan으로 변경하고 수를 확인하자 -> 16개의 물음표를 np.nan으로 변환
df['bare_nuclei'].replace('?', np.nan, inplace=True)
df['bare_nuclei'].isna().sum()

# 3) NaN 데이터 삭제 하고 전체 수 확인
df.dropna(subset=['bare_nuclei'], axis=0, inplace=True)
df.info()

# 4) bare_nuclei 컬럼의 형변환 object -> int
df['bare_nuclei'] = df['bare_nuclei'].astype('int')
df.info()

데이터 분리하기

X = df[['clump', 'cell_size', 'cell_shape', 'adhesion', 'epithlial','bare_nuclei', 'chromatin', 'normal_nucleoli', 'mitoses']]
y = df['class']
X

# X 독립변수 데이터를 정규화
X = preprocessing.StandardScaler().fit(X).transform(X)
X

# train, test set 분리 (7:3)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=7)

print('X_train.shape : ', X_train.shape)
print('X_test.shape : ', X_test.shape)

DecisionTree 분류 모델 설정

# 모델 객체 생성 (최적의 속성을 찾기 위해서 criterion='entropy' 적용) + 적정한 레벨 값을 찾는 것이 중요
tree_model = tree.DecisionTreeClassifier(criterion='entropy', max_depth=5)

모델 학습하기

tree_model.fit(X_train, y_train)

모델 학습하기

y_pred = tree_model.predict(X_test)

모델 성능 평가

print('훈련 세트 정확도 : {:.2f}%'.format(tree_model.score(X_train, y_train)*100))
print('테스트 세트 정확도 : {:.2f}%'.format(tree_model.score(X_test, y_test)*100))

tree_report = metrics.classification_report(y_test, y_pred)
print(tree_report)

결정 트리 그래프 그리기

from sklearn.tree import export_graphviz
export_graphviz(tree_model, out_file='tree.dot', class_names=['악성', '양성'], 
                feature_names=df.columns[1:10], impurity=False, filled=True)
import graphviz

with open('tree.dot') as f :
  dot_graph = f.read()
display(graphviz.Source(dot_graph))

LIST