Função “plot_tree”

A função plot_tree do módulo sklearn.tree é utilizada para visualizar a estrutura de uma árvore de decisão criada com o algoritmo DecisionTreeClassifier ou DecisionTreeRegressor do scikit-learn. Essa função permite que os usuários observem graficamente como a árvore foi construída e como os dados são divididos em cada nó da árvore.

Sintaxe

from sklearn.tree import plot_tree
 
plot_tree(decision_tree, feature_names=None, class_names=None, filled=True, rounded=True, impurity=True)

Parâmetros:

  • decision_tree: O modelo de árvore de decisão criado pelo DecisionTreeClassifier ou DecisionTreeRegressor que se deseja visualizar.

  • feature_names (opcional): Uma lista contendo os nomes das características (atributos) usados na árvore. Se não for fornecido, serão usados nomes genéricos como “feature_0”, “feature_1”, etc.

  • class_names (opcional): Uma lista contendo os nomes das classes de saída do modelo, usadas em problemas de classificação. Se não for fornecido, serão usados valores numéricos para representar as classes.

  • filled (opcional): Um valor booleano que indica se os nós da árvore serão coloridos para mostrar a classe de destino mais frequente para problemas de classificação ou a média dos valores de destino para problemas de regressão.

  • rounded (opcional): Um valor booleano que indica se os nós da árvore terão bordas arredondadas.

  • impurity (opcional): Um valor booleano que indica se a impureza dos nós será mostrada. A impureza é uma medida de quão misturados estão os dados em cada nó da árvore.

Exemplo

from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
 
# Carregar o conjunto de dados de crédito
with open("credit.pkl", "rb") as file:
    X_credit_train, y_credit_train, X_credit_test, y_credit_test = pickle.load(file)
 
# Criar o modelo de árvore de decisão
credit_tree = DecisionTreeClassifier(criterion="entropy", random_state=0)
credit_tree.fit(X_credit_train, y_credit_train)
 
# Visualizar a árvore de decisão
forecasters = ["income", "age", "loan"]
 
fig, axis = plt.subplots(nrows=1, ncols=1, figsize=(20, 20))
 
credit_class_names = [str(i) for i in credit_tree.classes_]
tree.plot_tree(credit_tree, feature_names=forecasters, class_names=credit_class_names, filled=True)

Saída:

No exemplo acima, carregamos o conjunto de dados de crédito, criamos um modelo de árvore de decisão e o ajustamos aos dados. Em seguida, utilizamos a função plot_tree para visualizar a árvore criada. O parâmetro feature_names é usado para fornecer os nomes das características do conjunto de dados, o parâmetro class_names é usado para fornecer os nomes das classes de saída e os parâmetros filled, rounded e impurity são usados para estilizar a visualização da árvore.

Conclusão

A visualização da árvore de decisão mostra como os dados são divididos em cada nó, quais as características são usadas para fazer as divisões e a classe de destino associada a cada nó terminal. Essa visualização é uma ferramenta útil para entender como a árvore toma decisões com base nos dados de entrada.