# Knowledge-graphs / rebel.py
# Last commit (khaerens): "Merge branch 'main' of https://huggingface.co/spaces/ml6team/Knowledge-graphs"
# Commit 154538c, file size 4.34 kB
from typing import List
from transformers import pipeline
from pyvis.network import Network
from functools import lru_cache
import spacy
import streamlit as st
# Hex colour used for a graph node, keyed by the entity label attached to it.
# Labels missing from this table fall back to a neutral grey at the call site.
# NOTE(review): the palette looks like spaCy displaCy's default entity
# colours -- confirm before relying on that.
DEFAULT_LABEL_COLORS = {
    # Named things: organisations, products, places, people, groups.
    "ORG": "#7aecec",
    "PRODUCT": "#bfeeb7",
    "GPE": "#feca74",
    "LOC": "#ff9561",
    "PERSON": "#aa9cfc",
    "NORP": "#c887fb",
    "FACILITY": "#9cc9cc",
    "EVENT": "#ffeb80",
    # Laws and languages share one colour.
    "LAW": "#ff8197",
    "LANGUAGE": "#ff8197",
    "WORK_OF_ART": "#f0d0ff",
    # Temporal expressions share one colour.
    "DATE": "#bfe1d9",
    "TIME": "#bfe1d9",
    # Numeric expressions share one colour.
    "MONEY": "#e4e7d2",
    "QUANTITY": "#e4e7d2",
    "ORDINAL": "#e4e7d2",
    "CARDINAL": "#e4e7d2",
    "PERCENT": "#e4e7d2",
}
def generate_knowledge_graph(texts: List[str], filename: str):
    """Build a relation graph from ``texts`` and render it to ``filename``.

    Entities are detected by running the spaCy pipeline over the lower-cased,
    newline-joined texts. Relation triplets are extracted separately for each
    text with ``generate_partial_graph``. Every distinct lower-cased head or
    tail becomes a node (coloured by its entity label when one was detected)
    and every distinct (head, tail, relation) combination becomes exactly one
    directed, labelled edge.

    Args:
        texts: Input passages; each one is sent to the triplet extractor on
            its own.
        filename: Output path handed to ``Network.show``.

    Returns:
        The list of lower-cased node names. Their order is not defined
        because they are de-duplicated through a ``set``.
    """
    # NOTE(review): load_spacy is not defined in the visible part of this
    # file; it is presumably provided elsewhere -- confirm.
    nlp = load_spacy()
    doc = nlp("\n".join(texts).lower())

    # Map entity text -> label. setdefault keeps the FIRST label seen for a
    # given text, and gives O(1) lookups when colouring the nodes below.
    entity_labels = {}
    for ent in doc.ents:
        entity_labels.setdefault(ent.text, ent.label_)

    triplets = []
    for text in texts:
        triplets.extend(generate_partial_graph(text))

    heads = [t["head"].lower() for t in triplets]
    tails = [t["tail"].lower() for t in triplets]
    nodes = list(set(heads + tails))

    net = Network(directed=True, width="700px", height="700px")
    for node in nodes:
        label = entity_labels.get(node)
        if label is None:
            # Not recognised as an entity: plain node with default colour.
            net.add_node(node, shape="circle")
        else:
            color = DEFAULT_LABEL_COLORS.get(label, "#666666")
            net.add_node(node, title=label, shape="circle", color=color)

    # De-duplicate edges on the same lower-cased names the edges are created
    # with. A tuple key (instead of string concatenation) cannot collide for
    # distinct triplets such as ("ab", "c") vs ("a", "bc").
    seen_edges = set()
    for triplet in triplets:
        head = triplet["head"].lower()
        tail = triplet["tail"].lower()
        relation = triplet["type"]
        edge_key = (head, tail, relation.lower())
        if edge_key in seen_edges:
            continue
        seen_edges.add(edge_key)
        net.add_edge(head, tail, title=relation, label=relation)

    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09
    )
    net.set_edge_smooth('dynamic')
    net.show(filename)
    return nodes
@lru_cache(maxsize=1)
def _load_triplet_extractor():
    """Create the REBEL text2text pipeline once and reuse it afterwards."""
    return pipeline(
        'text2text-generation',
        model='Babelscape/rebel-large',
        tokenizer='Babelscape/rebel-large'
    )


@lru_cache(maxsize=16)
def generate_partial_graph(text: str):
    """Extract relation triplets from a single piece of text.

    The text is run through the ``Babelscape/rebel-large`` generation
    pipeline, the generated token ids of the first result are decoded back
    to a string (special tokens included), and that string is parsed with
    ``extract_triplets``.

    Results are memoised for the 16 most recently used texts. The returned
    list object is therefore shared between repeated calls with the same
    text, so callers must not mutate it in place.

    Args:
        text: The passage to extract triplets from.

    Returns:
        A list of ``{'head', 'type', 'tail'}`` dicts; empty when the
        pipeline produced no output.
    """
    # The pipeline is built by a cached loader so that a cache miss on
    # ``text`` does not trigger a fresh load of the model.
    triplet_extractor = _load_triplet_extractor()
    triples = triplet_extractor(
        text,
        return_tensors=True,
        return_text=False)
    if not triples:
        return []
    token_ids = [triples[0]["generated_token_ids"]]
    extracted_text = triplet_extractor.tokenizer.batch_decode(token_ids)
    return extract_triplets(extracted_text[0])
def extract_triplets(text):
    """
    Parse decoded model output into a list of relation triplets.

    The input is a whitespace-separated token stream in which ``<triplet>``
    starts a new head, ``<subj>`` starts a tail and ``<obj>`` starts a
    relation name. ``<s>``, ``<pad>`` and ``</s>`` markers are removed before
    parsing. Each result is a dict with the keys ``'head'``, ``'type'`` and
    ``'tail'``; tokens seen before the first marker are ignored.
    """
    found = []
    head, rel, tail = '', '', ''
    # Which field plain tokens are currently appended to (None = no field yet).
    mode = None

    def record():
        # Snapshot the fields as they are at the moment of the call.
        found.append(
            {'head': head.strip(), 'type': rel.strip(), 'tail': tail.strip()})

    cleaned = text.strip()
    for marker in ("<s>", "<pad>", "</s>"):
        cleaned = cleaned.replace(marker, "")

    for tok in cleaned.split():
        if tok == "<triplet>":
            mode = 'head'
            # A pending relation means the previous triplet is complete.
            if rel:
                record()
                rel = ''
            head = ''
        elif tok == "<subj>":
            mode = 'tail'
            # Same head, new tail: emit the triplet collected so far.
            if rel:
                record()
            tail = ''
        elif tok == "<obj>":
            mode = 'rel'
            rel = ''
        elif mode == 'head':
            head += ' ' + tok
        elif mode == 'tail':
            tail += ' ' + tok
        elif mode == 'rel':
            rel += ' ' + tok

    # Flush the trailing triplet only when all three parts are present.
    if head and rel and tail:
        record()
    return found
if __name__ == "__main__":
    # Manual smoke run: build a graph from two short sentences and write it
    # to test.html.
    sample_sentences = ["The dog is happy", "The cat is sad"]
    generate_knowledge_graph(sample_sentences, "test.html")