emrecan commited on
Commit
ddcecdf
·
1 Parent(s): ca9d3af

first working app

Browse files
Files changed (3) hide show
  1. app.py +89 -38
  2. models.py +1 -1
  3. requirements.txt +1 -1
app.py CHANGED
@@ -1,51 +1,37 @@
 
1
  import pandas as pd
2
  import streamlit as st
3
  import plotly.express as px
4
  from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS
 
5
 
6
- st.title("Zero-shot Turkish Text Classification")
 
7
 
8
- method_selection = st.radio(
9
- "Select a zero-shot classification method.",
10
- [
11
- METHOD_OPTIONS["nli"],
12
- METHOD_OPTIONS["nsp"],
13
- ],
14
- )
15
 
16
- if method_selection == METHOD_OPTIONS["nli"]:
17
- model = st.selectbox(
18
- "Select a natural language inference model.", NLI_MODEL_OPTIONS
19
- )
20
- if method_selection == METHOD_OPTIONS["nsp"]:
21
- model = st.selectbox(
22
- "Select a BERT model for next sentence prediction.", NSP_MODEL_OPTIONS
23
- )
24
 
25
- st.header("Configure prompts and labels")
26
- col1, col2 = st.columns(2)
27
 
28
- with col1:
29
- st.subheader("Candidate labels")
30
- labels = st.text_area(
31
- label="These are the labels that the model will try to predict for the given text input. Your input labels should be comma separated and meaningful.",
32
- value="spor,dünya,siyaset,ekonomi,kültür ve sanat",
33
- )
34
- st.header("Make predictions")
35
- st.text_area("", value="Enter some text to classify.")
36
- st.button("Predict")
 
 
37
 
38
- with col2:
39
- st.subheader("Prompt template")
40
- prompt_template = st.text_area(
41
- label="Prompt template is used to transform NLI and NSP tasks into a general-use zero-shot classifier. Models replace {} with the labels that you have given.",
42
- value="Bu metin {} kategorisine aittir",
43
  )
44
- st.header("")
45
- probs = [0.86, 0.10, 0.01, 0.02, 0.01]
46
- data = pd.DataFrame(
47
- {"labels": labels.split(","), "probability": probs}
48
- ).sort_values(by="probability", ascending=False)
49
  chart = px.bar(
50
  data,
51
  x="probability",
@@ -67,4 +53,69 @@ with col2:
67
  "showlegend": False,
68
  }
69
  )
70
- st.plotly_chart(chart)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
  import pandas as pd
3
  import streamlit as st
4
  import plotly.express as px
5
  from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS
6
+ from zeroshot_turkish.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier
7
 
8
+ if "current_model" not in st.session_state:
9
+ st.session_state["current_model"] = None
10
 
11
+ if "current_model_option" not in st.session_state:
12
+ st.session_state["current_model_option"] = None
 
 
 
 
 
13
 
14
+ if "current_method_option" not in st.session_state:
15
+ st.session_state["current_method_option"] = None
 
 
 
 
 
 
16
 
 
 
17
 
18
+ def load_model(model_option: str, method_option: str, random_state: int = 0):
19
+ with st.spinner("Loading selected model..."):
20
+ if method_option == "Natural Language Inference":
21
+ st.session_state.current_model = NLIZeroshotClassifier(
22
+ model_name=model_option, random_state=random_state
23
+ )
24
+ else:
25
+ st.session_state.current_model = NSPZeroshotClassifier(
26
+ model_name=model_option, random_state=random_state
27
+ )
28
+ st.success("Model loaded!")
29
 
30
+
31
+ def visualize_output(labels: list[str], probabilities: list[float]):
32
+ data = pd.DataFrame({"labels": labels, "probability": probabilities}).sort_values(
33
+ by="probability", ascending=False
 
34
  )
 
 
 
 
 
35
  chart = px.bar(
36
  data,
37
  x="probability",
 
53
  "showlegend": False,
54
  }
55
  )
56
+ return chart
57
+
58
+
59
+ st.title("Zero-shot Turkish Text Classification")
60
+ method_option = st.radio(
61
+ "Select a zero-shot classification method.",
62
+ [
63
+ METHOD_OPTIONS["nli"],
64
+ METHOD_OPTIONS["nsp"],
65
+ ],
66
+ )
67
+ if method_option == METHOD_OPTIONS["nli"]:
68
+ model_option = st.selectbox(
69
+ "Select a natural language inference model.",
70
+ NLI_MODEL_OPTIONS,
71
+ )
72
+ if method_option == METHOD_OPTIONS["nsp"]:
73
+ model_option = st.selectbox(
74
+ "Select a BERT model for next sentence prediction.",
75
+ NSP_MODEL_OPTIONS,
76
+ )
77
+
78
+ if model_option != st.session_state.current_model_option:
79
+ st.session_state.current_model_option = model_option
80
+ st.session_state.current_method_option = method_option
81
+ load_model(
82
+ st.session_state.current_model_option, st.session_state.current_method_option
83
+ )
84
+
85
+
86
+ st.header("Configure prompts and labels")
87
+ col1, col2 = st.columns(2)
88
+ col1.subheader("Candidate labels")
89
+ labels = col1.text_area(
90
+ label="These are the labels that the model will try to predict for the given text input. Your input labels should be comma separated and meaningful.",
91
+ value="spor,dünya,siyaset,ekonomi,sanat",
92
+ key="current_labels",
93
+ )
94
+
95
+ col1.header("Make predictions")
96
+ text = col1.text_area(
97
+ "Enter a sentence or a paragraph to classify.",
98
+ value="Ian Anderson, Jethro Tull konserinde yan flüt çalarak zeybek oynadı.",
99
+ key="current_text",
100
+ )
101
+ col2.subheader("Prompt template")
102
+ prompt_template = col2.text_area(
103
+ label="Prompt template is used to transform NLI and NSP tasks into a general-use zero-shot classifier. Models replace {} with the labels that you have given.",
104
+ value="Bu metin {} kategorisine aittir",
105
+ key="current_template",
106
+ )
107
+ col2.header("")
108
+ make_pred = col1.button("Predict")
109
+ if make_pred:
110
+ prediction = st.session_state.current_model.predict_on_texts(
111
+ [st.session_state.current_text],
112
+ candidate_labels=st.session_state.current_labels.split(","),
113
+ prompt_template=st.session_state.current_template,
114
+ )
115
+ if "scores" in prediction[0]:
116
+ chart = visualize_output(prediction[0]["labels"], prediction[0]["scores"])
117
+ elif "probabilities" in prediction[0]:
118
+ chart = visualize_output(
119
+ prediction[0]["labels"], prediction[0]["probabilities"]
120
+ )
121
+ col2.plotly_chart(chart)
models.py CHANGED
@@ -1,6 +1,6 @@
1
  METHOD_OPTIONS = {
2
  "nli": "Natural Language Inference",
3
- "nsp": "Next Sentence Prediction (not supported yet)",
4
  }
5
 
6
  NLI_MODEL_OPTIONS = [
 
1
  METHOD_OPTIONS = {
2
  "nli": "Natural Language Inference",
3
+ "nsp": "Next Sentence Prediction",
4
  }
5
 
6
  NLI_MODEL_OPTIONS = [
requirements.txt CHANGED
@@ -24,7 +24,7 @@ catalogue==2.0.6
24
  certifi==2021.10.8
25
  cffi==1.15.0
26
  charset-normalizer==2.0.7
27
- click==8.0.3
28
  codecarbon==1.2.0
29
  commonmark==0.9.1
30
  configparser==5.1.0
 
24
  certifi==2021.10.8
25
  cffi==1.15.0
26
  charset-normalizer==2.0.7
27
+ click
28
  codecarbon==1.2.0
29
  commonmark==0.9.1
30
  configparser==5.1.0