Creating a Taxonomy for BBC News Articles (Part 1 (20220602))
Thursday, June 2, 2022

Creating the Conda environment and the ipykernel.
Refer to this link: Creating environment and kernel for sentence_transformers
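If you are setting the environment up from scratch, the steps boil down to creating a Conda environment, installing sentence-transformers and its companion libraries, and registering the environment as a Jupyter kernel. A rough sketch of the commands, where the environment name sbert and the Python version are placeholders you should adapt to the linked instructions:

conda create -n sbert python=3.9
conda activate sbert
pip install sentence-transformers pandas scikit-learn scipy matplotlib ipykernel
python -m ipykernel install --user --name sbert --display-name "Python (sbert)"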
Code

import pandas as pd
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
import warnings
warnings.filterwarnings('ignore')

# Load the training split of the BBC News data set.
df1 = pd.read_csv('bbc_news_train.csv')

Note: You can download the BBC News data set from this GitHub repository: bbc_news

df1.columns
# Index(['ArticleId', 'Text', 'Category'], dtype='object')

SentenceTransformer._version
# 1

# Load the pre-trained DistilBERT sentence-embedding model.
model = SentenceTransformer('distilbert-base-nli-mean-tokens')

%%time
# Encode each article text into a fixed-length sentence embedding.
sentences = df1['Text'].values.tolist()
sentences_df = pd.DataFrame({"sentences": sentences})
tokenized = sentences_df['sentences'].apply(lambda x: model.encode([x])[0])
# CPU times: total: 11min 6s
# Wall time: 5min 35s

# Spread the embeddings into one column per dimension.
tokenized_list = tokenized.values.tolist()
cols = ["feature_" + str(i) for i in range(768)]

Note: The DistilBERT-based SentenceTransformer produces 768-dimensional embeddings, hence the 768 feature columns.

tokenized_df = pd.DataFrame(columns=cols, data=tokenized_list)

import numpy as np
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram
from sklearn.cluster import AgglomerativeClustering


def plot_dendrogram(model, **kwargs):
    # Create linkage matrix and then plot the dendrogram

    # create the counts of samples under each node
    counts = np.zeros(model.children_.shape[0])
    n_samples = len(model.labels_)
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node
            else:
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    linkage_matrix = np.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)

    # Plot the corresponding dendrogram
    dendrogram(linkage_matrix, **kwargs)


X = tokenized_df

# Setting distance_threshold=0 ensures we compute the full tree.
# Use a separate variable name so the SentenceTransformer model is not overwritten.
cluster_model = AgglomerativeClustering(distance_threshold=0, n_clusters=None)
cluster_model = cluster_model.fit(X)

plt.title("Hierarchical Clustering Dendrogram")
# Plot the top three levels of the dendrogram.
plot_dendrogram(cluster_model, truncate_mode="level", p=3)
plt.xlabel("Number of points in node (or index of point if no parenthesis).")
plt.show()
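A quicker alternative to the row-by-row apply above is to hand the whole list of articles to model.encode, which batches the texts internally. This is a sketch under the assumption that df1 and the SentenceTransformer model from the cells above are still in memory; the batch_size of 32 is only an example value:

# Encode all articles in one batched call rather than one encode() call per row.
sentences = df1['Text'].tolist()
embeddings = model.encode(sentences, batch_size=32, show_progress_bar=True)

# Ask the model for its embedding width instead of hard-coding 768.
dim = model.get_sentence_embedding_dimension()
tokenized_df = pd.DataFrame(embeddings, columns=["feature_" + str(i) for i in range(dim)])

The clustering can also be run with a fixed number of clusters to get a flat assignment that can be inspected against the labelled categories. A minimal sketch assuming tokenized_df and df1 from above, with n_clusters=5 chosen purely as an example:

from sklearn.cluster import AgglomerativeClustering

# Cut the hierarchy at a fixed number of clusters and assign a label to every article.
flat_model = AgglomerativeClustering(n_clusters=5)
labels = flat_model.fit_predict(tokenized_df)

# Cross-tabulate the discovered clusters against the known Category column.
print(pd.crosstab(df1['Category'], labels))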