File size: 3,382 Bytes
1abc3a3 2177393 1abc3a3 fe86392 1abc3a3 c7109aa 68474cd 26ea18c c7109aa 1abc3a3 c7109aa 1abc3a3 c7109aa 1abc3a3 c7109aa 1abc3a3 c7109aa 1abc3a3 c7109aa 1abc3a3 469b8a3 1abc3a3 c7109aa 1abc3a3 c7109aa 1abc3a3 c7109aa 1abc3a3 c7109aa 1abc3a3 c7109aa 1abc3a3 c7109aa 1abc3a3 c7109aa 1abc3a3 c7109aa 1abc3a3 c7109aa 1abc3a3 c7109aa 1abc3a3 c7109aa 1abc3a3 c7109aa 1abc3a3 c7109aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import streamlit as st
import os
from streamlit_chat import message
from langchain_groq import ChatGroq
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationSummaryMemory
from transformers import pipeline
from huggingface_hub import login
# The Hugging Face token (the Prompt-Guard model is gated) and the Groq API key
# are read from environment variables rather than hard-coded here.
HUGGINGFACE_TOKEN = os.getenv("HF")
groq_api_key = os.getenv("GROQ_API_KEY")


@st.cache_resource(show_spinner=False)
def _load_guard_classifier():
    """Log in to the Hugging Face Hub (when a token is set) and build the guardrail classifier.

    Streamlit re-executes this whole script on every user interaction; caching
    the pipeline as a resource means the model is loaded once per server process
    instead of on every rerun. The spinner is disabled so that no Streamlit
    element is emitted before ``st.set_page_config`` runs further down.

    Returns:
        A ``transformers`` text-classification pipeline for
        ``meta-llama/Prompt-Guard-86M``.
    """
    if HUGGINGFACE_TOKEN:
        # Without a token, huggingface_hub.login() falls back to an interactive
        # prompt, which cannot be answered inside a Streamlit server.
        login(token=HUGGINGFACE_TOKEN)
    return pipeline("text-classification", model="meta-llama/Prompt-Guard-86M")


# Initialize the text classifier for guardrails (cached across reruns).
classifier = _load_guard_classifier()
# Seed per-session state the first time this script runs for a session;
# later reruns keep whatever value is already stored under each key.
#   conversation -> ConversationChain, None until the first message is answered
#   messages     -> chat history, alternating user / assistant entries
#   API_Key      -> key passed to getresponse() on every submit
for _state_key, _state_default in (
    ('conversation', None),
    ('messages', []),
    ('API_Key', ''),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Setting page title and header
st.set_page_config(page_title="Chat GPT Clone", page_icon=":robot_face:")
st.markdown("<h1 style='text-align: center;'>How can I assist you? </h1>", unsafe_allow_html=True)

# Sidebar configuration
st.sidebar.title("😎")
summarise_button = st.sidebar.button("Summarise the conversation", key="summarise")
if summarise_button:
    conversation = st.session_state['conversation']
    if conversation is None:
        # The chain is only created once the first message has been answered,
        # so before that there is no memory to summarise (previously this
        # crashed with AttributeError on None.memory).
        st.sidebar.write("There is nothing to summarise yet - start chatting first.")
    else:
        # ConversationSummaryMemory keeps its running summary in ``buffer``.
        st.sidebar.write("Nice chatting with you my friend ❤️:\n\n" + conversation.memory.buffer)
# Function to get response from the chatbot
def getresponse(userInput, api_key):
    """Return the assistant's reply to ``userInput``, applying the guardrail first.

    The input is passed through the Prompt-Guard ``classifier``. If the
    predicted label is ``"JAILBREAK"``, a fixed refusal is returned and the LLM
    is never called. Otherwise the per-session ``ConversationChain`` produces
    the reply. That chain is created lazily on first use and stored in
    ``st.session_state['conversation']``.

    Args:
        userInput: The raw text the user submitted.
        api_key: Groq API key used when the chain is first created. When empty
            or None, the ``GROQ_API_KEY`` environment value (``groq_api_key``)
            is used instead. It is not consulted once the chain exists.

    Returns:
        The reply text: either the refusal message or the model's answer.
    """
    # Classify the input using guardrails; take the result for our single input.
    classification = classifier(userInput)[0]
    if classification['label'] == "JAILBREAK":
        # Short-circuit so a flagged prompt is never forwarded to the LLM.
        return "You are attempting jailbreak/prompt injection. I can't help you with that. Please ask another question."

    # Initialize the conversation chain once per session; it carries the summary memory.
    if st.session_state['conversation'] is None:
        # Prefer the caller-supplied key when non-empty, else the environment value.
        llm = ChatGroq(model="Gemma2-9b-It", groq_api_key=api_key or groq_api_key)
        st.session_state['conversation'] = ConversationChain(
            llm=llm,
            verbose=True,
            memory=ConversationSummaryMemory(llm=llm),
        )

    # Generate a response using the conversation chain.
    return st.session_state['conversation'].predict(input=userInput)
# Chat history is rendered into this container; it is created before the input
# container, so the conversation appears above the form.
response_container = st.container()
# User input container
container = st.container()

with container:
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_area("Your question goes here:", key='input', height=100)
        submit_button = st.form_submit_button(label='Send')

    # Ignore blank submissions: they would waste an LLM call and add empty bubbles.
    if submit_button and user_input.strip():
        # Get the reply (from the chatbot or the guardrail) BEFORE touching the
        # history. If getresponse() raises, nothing is appended, so the list
        # keeps the strict user/assistant alternation the renderer relies on.
        model_response = getresponse(user_input, st.session_state['API_Key'])
        st.session_state['messages'].append(user_input)
        st.session_state['messages'].append(model_response)

# Display the whole conversation on every rerun, not only right after a submit,
# so the history stays visible when e.g. the sidebar button is clicked.
# Even indices are user turns; odd indices are assistant turns.
with response_container:
    for idx, text in enumerate(st.session_state['messages']):
        if idx % 2 == 0:
            message(text, is_user=True, key=f"{idx}_user")
        else:
            message(text, key=f"{idx}_AI")
|