File size: 3,417 Bytes
1abc3a3
2177393
1abc3a3
fe86392
1abc3a3
c7109aa
 
 
68474cd
 
 
 
 
 
c7109aa
 
1abc3a3
c7109aa
1abc3a3
c7109aa
1abc3a3
c7109aa
1abc3a3
c7109aa
1abc3a3
 
 
 
 
c7109aa
1abc3a3
c7109aa
1abc3a3
 
c7109aa
1abc3a3
c7109aa
1abc3a3
c7109aa
 
 
 
 
1abc3a3
c7109aa
1abc3a3
c7109aa
1abc3a3
 
 
c7109aa
1abc3a3
 
c7109aa
 
1abc3a3
 
c7109aa
1abc3a3
c7109aa
1abc3a3
 
 
 
 
 
 
 
c7109aa
1abc3a3
c7109aa
 
 
1abc3a3
 
c7109aa
1abc3a3
 
c7109aa
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import streamlit as st
import os
from streamlit_chat import message
from langchain_groq import ChatGroq
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationSummaryMemory
from transformers import pipeline

from huggingface_hub import login

# Hugging Face access token, read from the "HF" environment variable; used to
# authenticate the download of the Prompt-Guard model below.
HUGGINGFACE_TOKEN = os.getenv("HF")


@st.cache_resource(show_spinner=False)
def _load_classifier(token):
    """Log in to the Hugging Face Hub (if a token is set) and build the guardrail classifier.

    Streamlit re-executes this script on every user interaction. Caching with
    ``st.cache_resource`` makes the login call and the model load happen once per
    server process instead of on every rerun. The spinner is disabled so that no
    page element is emitted before ``st.set_page_config`` runs.

    Args:
        token: Hugging Face token, or ``None``/empty when the env var is unset.

    Returns:
        A ``transformers`` text-classification pipeline for Prompt-Guard-86M.
    """
    if token:
        # login() with no token falls back to an interactive prompt, which would
        # block a headless Streamlit server; skip it and rely on whatever
        # credentials huggingface_hub can already find.
        login(token=token)
    return pipeline("text-classification", model="meta-llama/Prompt-Guard-86M")


# Initialize the text classifier for guardrails (consumed by getresponse()).
classifier = _load_classifier(HUGGINGFACE_TOKEN)

# Setting page title; issued before any other Streamlit call so it is the
# first command of every run.
st.set_page_config(page_title="Chat GPT Clone", page_icon=":robot_face:")

# Per-session state that must survive Streamlit's script reruns:
#   conversation - ConversationChain, created lazily by getresponse()
#   messages     - chat history, rendered alternately as user / assistant
#   API_Key      - the Groq key most recently entered in the sidebar
if 'conversation' not in st.session_state:
    st.session_state['conversation'] = None
if 'messages' not in st.session_state:
    st.session_state['messages'] = []
if 'API_Key' not in st.session_state:
    st.session_state['API_Key'] = ''

# Header
st.markdown("<h1 style='text-align: center;'>How can I assist you? </h1>", unsafe_allow_html=True)

# Sidebar configuration
st.sidebar.title("😎")
groq_api_key = st.sidebar.text_input(label="Groq API Key", type="password")
# Persist the entered key so it is what gets passed to getresponse().
st.session_state['API_Key'] = groq_api_key
summarise_button = st.sidebar.button("Summarise the conversation", key="summarise")
if summarise_button:
    active_conversation = st.session_state['conversation']
    if active_conversation is None:
        # The chain (and its summary memory) only exists after the first message,
        # so there is nothing to summarise yet; avoid dereferencing None.
        st.sidebar.write("There is nothing to summarise yet - send a message first.")
    else:
        st.sidebar.write("Nice chatting with you my friend ❤️:\n\n" + active_conversation.memory.buffer)

# Function to get response from the chatbot
def getresponse(userInput, api_key):
    """Return the assistant's reply to ``userInput``.

    The input is first screened by the module-level Prompt-Guard ``classifier``.
    Anything it labels ``JAILBREAK`` receives a fixed refusal and never reaches
    the LLM. Otherwise the message goes through the session's ConversationChain.
    The chain is created on first use and stored in
    ``st.session_state['conversation']``.

    Args:
        userInput: The user's message text.
        api_key: Groq API key used when the chain is first created. When empty,
            the key currently entered in the sidebar (``groq_api_key``) is used.

    Returns:
        The model's reply, or the refusal message for flagged input.
    """
    # The pipeline returns a list of result dicts; the first carries the label.
    verdict = classifier(userInput)[0]
    if verdict['label'] == "JAILBREAK":
        # Flagged input short-circuits with a predefined safe response.
        return "You are attempting jailbreak/prompt injection. I can't help you with that. Please ask another question."

    if st.session_state['conversation'] is None:
        # Honour the caller-supplied key; fall back to the sidebar value so that
        # callers passing '' still behave as before.
        llm = ChatGroq(model="Gemma2-9b-It", groq_api_key=api_key or groq_api_key)
        st.session_state['conversation'] = ConversationChain(
            llm=llm,
            verbose=True,
            memory=ConversationSummaryMemory(llm=llm),
        )

    # Generate a response using the (session-persistent) conversation chain.
    return st.session_state['conversation'].predict(input=userInput)

# Response container: the chat history is rendered here, above the input form.
response_container = st.container()
# User input container
container = st.container()

with container:
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_area("Your question goes here:", key='input', height=100)
        submit_button = st.form_submit_button(label='Send')

        # Ignore blank submissions instead of sending an empty prompt to the model.
        if submit_button and user_input.strip():
            # Get the reply BEFORE touching the history. If getresponse raises,
            # nothing is appended, so the list keeps the strict user/assistant
            # alternation that the renderer below relies on.
            model_response = getresponse(user_input, st.session_state['API_Key'])
            st.session_state['messages'].extend([user_input, model_response])

# Render the full history on every rerun, not only right after a submit, so
# other interactions (e.g. the sidebar "Summarise" button) do not blank the chat.
# Even indices are user turns, odd indices are assistant turns.
with response_container:
    for i, text in enumerate(st.session_state['messages']):
        if i % 2 == 0:
            message(text, is_user=True, key=str(i) + '_user')
        else:
            message(text, key=str(i) + '_AI')