메인 콘텐츠로 건너뛰기
단 몇 줄의 코드로 wandb를 사용하여 scikit-learn 모델의 성능을 시각화하고 비교할 수 있습니다. 예제 실행하기 →

시작하기

가입 및 API 키 생성

API 키는 W&B에 사용자 머신을 인증하는 데 사용됩니다. 사용자 프로필에서 API 키를 생성할 수 있습니다.
For a more streamlined approach, create an API key by going directly to User Settings. Copy the newly created API key immediately and save it in a secure location such as a password manager.
  1. 오른쪽 상단에 있는 사용자 프로필 아이콘을 클릭합니다.
  2. User Settings를 선택한 다음 API Keys 섹션으로 스크롤합니다.

wandb 라이브러리 설치 및 로그인

로컬에 wandb 라이브러리를 설치하고 로그인하려면 다음을 수행하세요:
  1. WANDB_API_KEY 환경 변수를 생성한 API 키로 설정합니다.
    export WANDB_API_KEY=<your_api_key>
    
  2. wandb 라이브러리를 설치하고 로그인합니다.
    pip install wandb
    
    wandb login
    

메트릭 로그

import wandb

wandb.init(project="visualize-sklearn") as run:

  y_pred = clf.predict(X_test)
  accuracy = sklearn.metrics.accuracy_score(y_true, y_pred)

  # 시간에 따른 메트릭을 로그하려면 run.log를 사용합니다.
  run.log({"accuracy": accuracy})

  # 또는 트레이닝 종료 시 최종 메트릭을 로그하려면 run.summary를 사용할 수도 있습니다.
  run.summary["accuracy"] = accuracy

차트 그리기

1단계: wandb 임포트 및 새로운 run 초기화

import wandb

run = wandb.init(project="visualize-sklearn")

2단계: 차트 시각화

개별 차트

모델을 트레이닝하고 예측을 수행한 후, wandb에서 차트를 생성하여 예측값을 분석할 수 있습니다. 지원되는 전체 차트 목록은 아래의 지원되는 차트 섹션을 참조하세요.
# 단일 차트 시각화
wandb.sklearn.plot_confusion_matrix(y_true, y_pred, labels)

모든 차트

W&B에는 여러 관련 차트를 한 번에 그려주는 plot_classifier와 같은 함수가 있습니다:
# 모든 분류기(classifier) 차트 시각화
wandb.sklearn.plot_classifier(
    clf,
    X_train,
    X_test,
    y_train,
    y_test,
    y_pred,
    y_probas,
    labels,
    model_name="SVC",
    feature_names=None,
)

# 모든 회귀(regression) 차트 시각화
wandb.sklearn.plot_regressor(reg, X_train, X_test, y_train, y_test, model_name="Ridge")

# 모든 클러스터링 차트 시각화
wandb.sklearn.plot_clusterer(
    kmeans, X_train, cluster_labels, labels=None, model_name="KMeans"
)

run.finish()

기존 Matplotlib 차트

Matplotlib으로 생성한 차트도 W&B 대시보드에 로그할 수 있습니다. 이를 위해서는 먼저 plotly를 설치해야 합니다.
pip install plotly
마지막으로 다음과 같이 W&B 대시보드에 차트를 로그할 수 있습니다:
import matplotlib.pyplot as plt
import wandb

with wandb.init(project="visualize-sklearn") as run:

  # 모든 plt.plot(), plt.scatter() 등을 여기서 수행합니다.
  # ...

  # plt.show() 대신 다음을 실행합니다:
  run.log({"plot": plt})

지원되는 차트

학습 곡선 (Learning curve)

Scikit-learn learning curve
다양한 길이의 데이터셋에서 모델을 트레이닝하고, 트레이닝 세트와 테스트 세트 모두에 대해 데이터셋 크기 대비 교차 검증 점수 차트를 생성합니다. wandb.sklearn.plot_learning_curve(model, X, y)
  • model (clf 또는 reg): 학습된 회귀 모델 또는 분류기를 인수로 받습니다.
  • X (arr): 데이터셋 피처.
  • y (arr): 데이터셋 라벨.

ROC

Scikit-learn ROC curve
ROC 곡선은 재현율(TPR, y축) 대비 위양성률(FPR, x축)을 나타냅니다. 이상적인 점수는 왼쪽 상단 지점인 TPR = 1, FPR = 0입니다. 일반적으로 ROC 곡선 아래 면적(AUC-ROC)을 계산하며, AUC-ROC가 클수록 성능이 좋습니다. wandb.sklearn.plot_roc(y_true, y_probas, labels)
  • y_true (arr): 테스트 세트 라벨.
  • y_probas (arr): 테스트 세트 예측 확률.
  • labels (list): 타겟 변수(y)의 이름이 지정된 라벨 목록.

클래스 비율 (Class proportions)

Scikit-learn classification properties
트레이닝 및 테스트 세트에서 타겟 클래스의 분포를 그립니다. 불균형 클래스를 탐지하고 특정 클래스가 모델에 불균형한 영향을 미치지 않는지 확인하는 데 유용합니다. wandb.sklearn.plot_class_proportions(y_train, y_test, ['dog', 'cat', 'owl'])
  • y_train (arr): 트레이닝 세트 라벨.
  • y_test (arr): 테스트 세트 라벨.
  • labels (list): 타겟 변수(y)의 이름이 지정된 라벨 목록.

PR 곡선 (Precision recall curve)

Scikit-learn precision-recall curve
다양한 임계값에 대해 정밀도(precision)와 재현율(recall) 사이의 트레이드오프를 계산합니다. 곡선 아래 면적이 넓을수록 높은 재현율과 높은 정밀도를 모두 나타내며, 높은 정밀도는 낮은 위양성률과 관련이 있고 높은 재현율은 낮은 위음성률과 관련이 있습니다. 두 항목 모두에서 높은 점수는 분류기가 정확한 결과(높은 정밀도)를 반환할 뿐만 아니라 모든 양성 결과의 대부분(높은 재현율)을 반환하고 있음을 보여줍니다. PR 곡선은 클래스가 매우 불균형할 때 유용합니다. wandb.sklearn.plot_precision_recall(y_true, y_probas, labels)
  • y_true (arr): 테스트 세트 라벨.
  • y_probas (arr): 테스트 세트 예측 확률.
  • labels (list): 타겟 변수(y)의 이름이 지정된 라벨 목록.

피처 중요도 (Feature importances)

Scikit-learn feature importance chart
분류 작업에서 각 피처의 중요도를 평가하고 차트로 그립니다. 트리 계열 모델과 같이 feature_importances_ 속성을 가진 분류기에서만 작동합니다. wandb.sklearn.plot_feature_importances(model, ['width', 'height, 'length'])
  • model (clf): 학습된 분류기를 인수로 받습니다.
  • feature_names (list): 피처의 이름 목록. 피처 인덱스를 해당 이름으로 교체하여 차트를 더 읽기 쉽게 만듭니다.

보정 곡선 (Calibration curve)

Scikit-learn calibration curve
분류기의 예측 확률이 얼마나 잘 보정되었는지, 그리고 보정되지 않은 분류기를 보정하는 방법을 시각화합니다. 베이스라인 로지스틱 회귀 모델, 인수로 전달된 모델, 그리고 해당 모델의 등조성(isotonic) 보정과 시그모이드(sigmoid) 보정에 의해 추정된 예측 확률을 비교합니다. 보정 곡선이 대각선에 가까울수록 좋습니다. 전치된 시그모이드 형태의 곡선은 과적합된 분류기를 나타내며, 시그모이드 형태의 곡선은 과소적합된 분류기를 나타냅니다. 모델의 등조성 및 시그모이드 보정을 트레이닝하고 곡선을 비교함으로써 모델의 과적합 또는 과소적합 여부를 파악하고, 어떤 보정(시그모이드 또는 등조성)이 이를 해결하는 데 도움이 될지 결정할 수 있습니다. 자세한 내용은 sklearn 문서를 확인하세요. wandb.sklearn.plot_calibration_curve(clf, X, y, 'RandomForestClassifier')
  • model (clf): 학습된 분류기를 인수로 받습니다.
  • X (arr): 트레이닝 세트 피처.
  • y (arr): 트레이닝 세트 라벨.
  • model_name (str): 모델 이름. 기본값은 ‘Classifier’입니다.

혼동 행렬 (Confusion matrix)

Scikit-learn confusion matrix
분류의 정확도를 평가하기 위해 혼동 행렬을 계산합니다. 모델 예측의 품질을 평가하고 모델이 틀리는 예측 패턴을 찾는 데 유용합니다. 대각선은 실제 라벨과 예측 라벨이 일치하는, 즉 모델이 맞게 예측한 경우를 나타냅니다. wandb.sklearn.plot_confusion_matrix(y_true, y_pred, labels)
  • y_true (arr): 테스트 세트 라벨.
  • y_pred (arr): 테스트 세트 예측 라벨.
  • labels (list): 타겟 변수(y)의 이름이 지정된 라벨 목록.

요약 메트릭 (Summary metrics)

Scikit-learn summary metrics
  • 분류를 위해 mse, mae, r2 점수와 같은 요약 메트릭을 계산합니다.
  • 회귀를 위해 f1, 정확도, 정밀도, 재현율과 같은 요약 메트릭을 계산합니다.
wandb.sklearn.plot_summary_metrics(model, X_train, y_train, X_test, y_test)
  • model (clf 또는 reg): 학습된 회귀 모델 또는 분류기를 인수로 받습니다.
  • X (arr): 트레이닝 세트 피처.
  • y (arr): 트레이닝 세트 라벨.
    • X_test (arr): 테스트 세트 피처.
  • y_test (arr): 테스트 세트 라벨.

엘보우 차트 (Elbow plot)

Scikit-learn elbow plot
클러스터 수에 따른 설명된 분산의 백분율과 트레이닝 시간을 측정하고 시각화합니다. 최적의 클러스터 수를 선택하는 데 유용합니다. wandb.sklearn.plot_elbow_curve(model, X_train)
  • model (clusterer): 학습된 클러스터를 인수로 받습니다.
  • X (arr): 트레이닝 세트 피처.

실루엣 차트 (Silhouette plot)

Scikit-learn silhouette plot
한 클러스터의 각 점이 인접한 클러스터의 점들과 얼마나 가까운지 측정하고 시각화합니다. 클러스터의 두께는 클러스터 크기에 해당합니다. 세로선은 모든 점의 평균 실루엣 점수를 나타냅니다. +1에 가까운 실루엣 계수는 샘플이 인접한 클러스터에서 멀리 떨어져 있음을 나타냅니다. 0은 샘플이 두 인접한 클러스터 사이의 결정 경계에 있거나 매우 가까움을 나타내며, 음수 값은 해당 샘플이 잘못된 클러스터에 할당되었을 수 있음을 나타냅니다. 일반적으로 모든 실루엣 클러스터 점수가 평균(빨간색 선 위)보다 높고 1에 최대한 가깝기를 원합니다. 또한 데이터의 내재된 패턴을 반영하는 클러스터 크기를 선호합니다. wandb.sklearn.plot_silhouette(model, X_train, ['spam', 'not spam'])
  • model (clusterer): 학습된 클러스터를 인수로 받습니다.
  • X (arr): 트레이닝 세트 피처.
    • cluster_labels (list): 클러스터 라벨의 이름 목록. 클러스터 인덱스를 해당 이름으로 교체하여 차트를 더 읽기 쉽게 만듭니다.

이상치 후보 차트 (Outlier candidates plot)

Scikit-learn outlier plot
쿡의 거리(Cook’s distance)를 통해 데이터 포인트가 회귀 모델에 미치는 영향력을 측정합니다. 영향력이 크게 치우친 인스턴스는 잠재적인 이상치일 수 있습니다. 이상치 탐지에 유용합니다. wandb.sklearn.plot_outlier_candidates(model, X, y)
  • model (regressor): 학습된 분류기를 인수로 받습니다.
  • X (arr): 트레이닝 세트 피처.
  • y (arr): 트레이닝 세트 라벨.

잔차 차트 (Residuals plot)

Scikit-learn residuals plot
예측된 타겟 값(y축) 대비 실제 값과 예측된 타겟 값의 차이(x축), 그리고 잔차 오차의 분포를 측정하고 시각화합니다. 일반적으로 잘 적합된 모델의 잔차는 무작위로 분포되어야 합니다. 좋은 모델은 무작위 오차를 제외한 데이터셋의 대부분의 현상을 설명할 수 있기 때문입니다. wandb.sklearn.plot_residuals(model, X, y)
  • model (regressor): 학습된 분류기를 인수로 받습니다.
  • X (arr): 트레이닝 세트 피처.
  • y (arr): 트레이닝 세트 라벨. 질문이 있으시면 언제든지 Slack 커뮤니티에서 문의해 주세요.

예제