# hAIring/utils/llm.py
# (Hugging Face Spaces page header removed; commit a41f76e by plutoze:
#  "update: add env variables support and secrets management")
import time
import streamlit as st
from langchain.memory import ConversationBufferWindowMemory, ConversationBufferMemory, StreamlitChatMessageHistory
from langchain.prompts import MessagesPlaceholder
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import VoyageEmbeddings
from langchain.vectorstores import DeepLake
from langchain.retrievers.document_compressors import CohereRerank
from langchain.retrievers import ContextualCompressionRetriever
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from dotenv import load_dotenv
from operator import itemgetter
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
load_dotenv()
# (name, description) pairs for the structured fields the LLM must return.
_SCHEMA_FIELDS = [
    ("Name", "The candidate name."),
    ("Experience", "The work experiences of the candidate."),
    ("Skills", "The skills of the candidate."),
    ("Projects", "The projects of the candidate."),
    ("Summary", "Final conclusion about the chosen candidate as the best choice"),
]
response_schemas = [ResponseSchema(name=name, description=desc) for name, desc in _SCHEMA_FIELDS]
# Parser that turns the model's reply into a dict with the keys above,
# and the matching formatting instructions injected into the prompt.
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()
# System prompt establishing the hiring-manager persona.
# Fix: grammar/typos in the original prompt ("a experienced", "a relevant
# candidates", "the the filename"); instructions otherwise unchanged.
system_message = (
    "You are an experienced Hiring manager. "
    "You are looking for relevant candidates to fill a position in your company. Always look for Job roles, "
    "You have a list of requirements asked by the user. Look for major skills and projects in the document."
    " Always look for relevant experiences, projects and skills. "
    "and finally the link to the resume. IF the link is a PATH TO A FILE, then output only the filename not the path."
)
# Human prompt template; expects {context} (retrieved resumes), {query}
# (the user's job description) and {format_instructions} (partial-filled
# from the structured output parser).
# Fix: typo "decription" -> "description".
human_message = ("You are provided with a list of resumes. "
                 "Find the most relevant candidates for the job from "
                 "{context} "
                 "that suits the job description "
                 "{query}"
                 "Answer the query in a format that is easy to read."
                 "------------"
                 "{format_instructions}"
                 )
@st.cache_resource
def init_retriever(dataset_path="hub://p1utoze/default",
                   embeddings_model="voyage-lite-01",
                   ):
    """Build a cached Cohere-reranked retriever over a DeepLake dataset.

    Args:
        dataset_path: DeepLake hub path of the resume vector store.
        embeddings_model: Voyage AI embedding model name.

    Returns:
        ContextualCompressionRetriever wrapping the DeepLake retriever
        with a CohereRerank compressor.
    """
    embeddings = VoyageEmbeddings(model=embeddings_model, show_progress_bar=True)
    # Read-only connection; token comes from Streamlit secrets.
    db = DeepLake(token=st.secrets["deep_lake_token"], dataset_path=dataset_path, read_only=True, embedding=embeddings)
    retriever = db.as_retriever()
    # Fix: top_k must be an int, not the string "5" — a string would break
    # or be silently mishandled by the similarity search.
    retriever.search_kwargs['top_k'] = 5
    retriever.search_kwargs['distance_metric'] = 'cos'
    compressor = CohereRerank()
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )
    return compression_retriever
@st.cache_resource
def load_model():
    """Return a cached ChatOpenAI client (gpt-3.5-turbo, temperature 0.4)."""
    return ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.4)
@st.cache_resource
def load_memory():
    """Return conversation memory backed by Streamlit's session chat history.

    Fix: the original constructed a StreamlitChatMessageHistory but never
    attached it, so the buffer was not persisted in Streamlit session state
    across reruns. It is now wired in via ``chat_memory``.
    """
    msgs = StreamlitChatMessageHistory(key="langchain_messages")
    memory = ConversationBufferMemory(chat_memory=msgs, return_messages=True)
    return memory
def typewriter(text: str, title: str, speed: int) -> None:
    """Animate *text* word-by-word into a Streamlit placeholder.

    A block cursor ("▌") is shown after each partial render; the final
    render drops the cursor. ``speed`` is words per second.
    NOTE(review): ``title`` is accepted but never used — kept for
    call-site compatibility; confirm whether it can be dropped.
    """
    words = text.split()
    placeholder = st.empty()
    total = len(words)
    for count in range(total + 1):
        partial = " ".join(words[:count])
        if count < total:
            placeholder.markdown(partial + "▌")
            time.sleep(1 / speed)
        else:
            placeholder.markdown(partial)
def model_pipeline(memory):
    """Assemble the retrieval-augmented chat chain.

    Args:
        memory: conversation memory object exposing
            ``load_memory_variables`` with a "history" key (e.g. the
            ConversationBufferMemory returned by ``load_memory``).

    Returns:
        A runnable chain: parallel retrieval/query/history setup -> chat
        prompt -> chat model -> structured output parser (dict with the
        keys declared in ``response_schemas``).
    """
    compression_retriever = init_retriever(
        "hub://p1utoze/resumes",
        embeddings_model="voyage-lite-01",
    )
    chat = load_model()
    system_message_prompt = SystemMessagePromptTemplate.from_template(system_message)
    # format_instructions is fixed at build time; query/context are filled
    # per invocation.
    human_message_prompt = HumanMessagePromptTemplate.from_template(
        human_message,
        partial_variables={"format_instructions": format_instructions},
        input_variables=["query", "context"]
    )
    prompt = ChatPromptTemplate.from_messages(
        [system_message_prompt,
         MessagesPlaceholder(variable_name="history"),
         human_message_prompt]
    )
    # Fix: removed leftover debug print(memory.load_memory_variables({})).
    # The retriever receives the raw query string; memory supplies the chat
    # history under the "history" key expected by the MessagesPlaceholder.
    setup_and_retrieval = RunnableParallel(
        {"context": compression_retriever,
         "query": RunnablePassthrough(),
         "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history")
         },
    )
    chain = (
        setup_and_retrieval
        | prompt
        | chat
        | output_parser
    )
    return chain
# print(chain.invoke("My company are looking for RDBMS database designers with 3 years experience. He or she should be experienced in Java, SQL "))