Note
# Note: in this part we cheat a little. We look at the data set to find how
# many categories are already defined for the news articles, then use
# SentenceTransformer to encode the articles and AgglomerativeClustering to
# identify hierarchical groups among them.
import pandas as pd
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from collections import Counter
import warnings
warnings.filterwarnings('ignore')
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram
from sklearn.cluster import AgglomerativeClustering
from time import perf_counter

df1 = pd.read_csv('bbc_news_train.csv')
print(df1.columns)
# Index(['ArticleId', 'Text', 'Category'], dtype='object')

# Peeking into the data set
print(Counter(df1['Category']))
# Counter({'business': 336, 'tech': 261, 'politics': 274, 'sport': 346,
#          'entertainment': 273})

# Sentence encoder. It gets its own name so that `model` below refers
# unambiguously to the clustering model.
encoder = SentenceTransformer('distilbert-base-nli-mean-tokens')

sentences = df1['Text'].values.tolist()
start = perf_counter()
# Encode all articles in one batched call rather than calling encode() once
# per row through pandas.apply. The per-row version reported
# "CPU times: total: 10min 56s, Wall time: 5min 29s".
embeddings = encoder.encode(sentences)
print(f"Encoding wall time: {perf_counter() - start:.1f} s")

# Column names follow the actual embedding width instead of a hard-coded 768,
# so swapping in a different encoder does not break the DataFrame.
cols = ["feature_" + str(i) for i in range(embeddings.shape[1])]
embedding_df = pd.DataFrame(columns=cols, data=embeddings)
X = embedding_df

# The "cheat": use the number of categories already labelled in the training
# data as the number of clusters.
n_categories = df1['Category'].nunique()

# Exactly one of n_clusters and distance_threshold may be set; the other must
# be None. Passing distance_threshold=0 together with n_clusters raises
# "ValueError: Exactly one of n_clusters and distance_threshold has to be set,
# and the other needs to be None." (Setting distance_threshold=0 with
# n_clusters=None would instead compute the full tree.)
# compute_distances defaults to False, in which case no dendrogram can be
# produced, so it is enabled here.
model = AgglomerativeClustering(
    distance_threshold=None,
    n_clusters=n_categories,
    compute_distances=True,
)
pred = model.fit_predict(X)
print(pred)
print(len(pred))
print(len(Counter(pred).keys()))
# Output:
# [0 1 3 ...
#  1 3 3]
# 1490
# 5
from time import time


def plot_dendrogram(model, **kwargs):
    """Plot a dendrogram for a fitted AgglomerativeClustering model.

    Builds a linkage matrix from the model's merge tree (``children_``), the
    merge distances (``distances_``) and the number of original samples under
    each merged node. It then passes that matrix to
    ``scipy.cluster.hierarchy.dendrogram``.

    Args:
        model: A fitted AgglomerativeClustering exposing ``children_``,
            ``labels_`` and ``distances_``. ``distances_`` is only available
            when the model was fitted with ``compute_distances=True`` (or with
            a ``distance_threshold``).
        **kwargs: Forwarded unchanged to ``dendrogram`` (for example
            ``truncate_mode``, ``p``).

    Returns:
        The float linkage matrix that was plotted. It has one row per merge
        (``model.children_.shape[0]`` rows) and the columns
        [child_a, child_b, distance, sample_count].

    Raises:
        AttributeError: If the model has no ``distances_`` attribute.
    """
    if not hasattr(model, "distances_"):
        raise AttributeError(
            "model has no distances_; fit AgglomerativeClustering with "
            "compute_distances=True (or a distance_threshold) to plot a "
            "dendrogram"
        )

    n_samples = len(model.labels_)
    # counts[i] = number of original samples under the node created by merge i.
    counts = np.zeros(model.children_.shape[0])
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node
            else:
                # Internal node: reuse the count computed for that earlier merge.
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    linkage_matrix = np.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)

    # Plot the corresponding dendrogram
    dendrogram(linkage_matrix, **kwargs)
    return linkage_matrix


# Attach each article's cluster label and save a timestamped copy of the data.
df1['pred'] = pred
df1.to_csv(f"output_{round(time())}.csv", index=False)

plt.title("Hierarchical Clustering Dendrogram")
# plot the top three levels of the dendrogram
linkage_matrix = plot_dendrogram(model, truncate_mode="level", p=3)
plt.xlabel("Number of points in node (or index of point if no parenthesis).")
plt.show()

# Analyzing the output in Excel
Table View
Now in Pivot Table
Thursday, June 2, 2022
Creating a Taxonomy for BBC News Articles (Part 2)
Subscribe to:
Post Comments (Atom)
No comments:
Post a Comment