"""Build an interactive knowledge graph (pyvis HTML) from free text.

Relation triplets are extracted with the Babelscape/rebel-large
text2text model, and graph nodes are coloured by the spaCy entity label
of the matching entity, when there is one.
"""

from functools import lru_cache
from typing import Dict, List

from pyvis.network import Network
import spacy
import streamlit as st
from transformers import pipeline

# Node colour per spaCy entity label.
DEFAULT_LABEL_COLORS = {
    "ORG": "#7aecec",
    "PRODUCT": "#bfeeb7",
    "GPE": "#feca74",
    "LOC": "#ff9561",
    "PERSON": "#aa9cfc",
    "NORP": "#c887fb",
    "FACILITY": "#9cc9cc",
    "EVENT": "#ffeb80",
    "LAW": "#ff8197",
    "LANGUAGE": "#ff8197",
    "WORK_OF_ART": "#f0d0ff",
    "DATE": "#bfe1d9",
    "TIME": "#bfe1d9",
    "MONEY": "#e4e7d2",
    "QUANTITY": "#e4e7d2",
    "ORDINAL": "#e4e7d2",
    "CARDINAL": "#e4e7d2",
    "PERCENT": "#e4e7d2",
}

# Colour used for entities whose label is missing from DEFAULT_LABEL_COLORS.
_FALLBACK_NODE_COLOR = "#666666"

_REBEL_MODEL = 'Babelscape/rebel-large'


def generate_knowledge_graph(texts: List[str], filename: str) -> List[str]:
    """Extract triplets from *texts* and render them as a pyvis graph.

    Args:
        texts: Passages to mine for (head, relation, tail) triplets. Each
            passage is sent to the relation extractor separately.
        filename: Path of the HTML file handed to ``Network.show``.

    Returns:
        The lower-cased node names of the graph, unique, in first-seen
        order (all heads first, then tails).
    """
    # NOTE(review): load_spacy is not defined in the visible source; it is
    # assumed to be provided elsewhere and to return a spaCy pipeline.
    nlp = load_spacy()
    doc = nlp("\n".join(texts).lower())

    # Entity text -> label. setdefault keeps the first label seen for a
    # given text, which is what the former list.index lookup returned.
    entity_labels: Dict[str, str] = {}
    for ent in doc.ents:
        entity_labels.setdefault(ent.text, ent.label_)

    triplets = []
    for text in texts:
        triplets.extend(generate_partial_graph(text))

    heads = [t["head"].lower() for t in triplets]
    tails = [t["tail"].lower() for t in triplets]
    # dict.fromkeys de-duplicates while keeping a deterministic order.
    nodes = list(dict.fromkeys(heads + tails))

    net = Network(directed=True, width="700px", height="700px")

    for node in nodes:
        label = entity_labels.get(node)
        if label is None:
            # Not a recognised entity: plain, untitled node.
            net.add_node(node, shape="circle")
        else:
            color = DEFAULT_LABEL_COLORS.get(label, _FALLBACK_NODE_COLOR)
            net.add_node(node, title=label, shape="circle", color=color)

    # De-duplicate edges on the same lower-cased names the nodes use, so
    # that case variants of one triplet do not yield parallel edges. A
    # tuple key also avoids collisions from plain string concatenation.
    seen_edges = set()
    for triplet in triplets:
        head = triplet["head"].lower()
        tail = triplet["tail"].lower()
        relation = triplet["type"]
        edge_key = (head, tail, relation.lower())
        if edge_key in seen_edges:
            continue
        seen_edges.add(edge_key)
        net.add_edge(head, tail, title=relation, label=relation)

    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09,
    )
    net.set_edge_smooth('dynamic')
    net.show(filename)
    return nodes


@lru_cache(maxsize=1)
def _load_triplet_extractor():
    """Build the REBEL text2text pipeline once and reuse it afterwards."""
    return pipeline(
        'text2text-generation',
        model=_REBEL_MODEL,
        tokenizer=_REBEL_MODEL,
    )


@lru_cache(maxsize=16)
def generate_partial_graph(text: str) -> List[Dict[str, str]]:
    """Return the relation triplets the REBEL model generates for *text*.

    Results are memoised per input string. The returned list is the
    cached object itself, so callers should not mutate it.

    Args:
        text: A single passage to run through the extractor.

    Returns:
        A list of ``{'head', 'type', 'tail'}`` dicts; empty when the
        pipeline produces no output.
    """
    # The pipeline is cached separately so that a cache miss here does
    # not rebuild it.
    triplet_extractor = _load_triplet_extractor()
    triples = triplet_extractor(
        text, return_tensors=True, return_text=False)

    if not triples:
        return []

    # Decode the raw token ids ourselves: extract_triplets needs the
    # special marker tokens to still be present in the text.
    token_ids = [triples[0]["generated_token_ids"]]
    decoded = triplet_extractor.tokenizer.batch_decode(token_ids)
    return extract_triplets(decoded[0])


def extract_triplets(text: str) -> List[Dict[str, str]]:
    """Parse REBEL's linearised output into a list of triplets.

    The text is read as a stream of whitespace-separated tokens.
    ``<triplet>`` starts a new head, ``<subj>`` starts a tail and
    ``<obj>`` starts a relation; every other token is appended to
    whichever part is currently open. The sequence markers ``<s>``,
    ``<pad>`` and ``</s>`` are removed first.

    Args:
        text: Decoded model output that still contains the marker tokens.

    Returns:
        A list of ``{'head', 'type', 'tail'}`` dicts with stripped values.
    """
    triplets = []
    subject, relation, object_ = '', '', ''
    # 'x' = no marker seen yet; 't' = head, 's' = tail, 'o' = relation.
    current = 'x'

    cleaned = (
        text.strip()
        .replace("<s>", "")
        .replace("<pad>", "")
        .replace("</s>", "")
    )
    for token in cleaned.split():
        if token == "<triplet>":
            current = 't'
            # A new head closes the triplet that was being built.
            if relation != '':
                triplets.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip(),
                })
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            # Another tail for the same head: emit the finished triplet
            # and keep the head for the next one.
            if relation != '':
                triplets.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip(),
                })
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        elif current == 't':
            subject += ' ' + token
        elif current == 's':
            object_ += ' ' + token
        elif current == 'o':
            relation += ' ' + token

    # Flush the last triplet, which has no following marker to close it.
    if subject != '' and relation != '' and object_ != '':
        triplets.append({
            'head': subject.strip(),
            'type': relation.strip(),
            'tail': object_.strip(),
        })
    return triplets


if __name__ == "__main__":
    generate_knowledge_graph(
        ["The dog is happy", "The cat is sad"], "test.html")