Spaces:
Configuration error
Configuration error
Tristan Thrush
committed on
Commit
·
90b6f98
1
Parent(s):
013ce7b
synched model memory with conversation, and make sure it is wiped for next hit
Browse files
app.py
CHANGED
|
@@ -98,6 +98,13 @@ chatbot_4 = ConversationChain(
|
|
| 98 |
memory=ConversationBufferMemory(ai_prefix="Assistant"),
|
| 99 |
)
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
demo = gr.Blocks()
|
| 102 |
|
| 103 |
with demo:
|
|
@@ -130,17 +137,17 @@ with demo:
|
|
| 130 |
response_3 = chatbot_3.predict(input=txt)
|
| 131 |
response_4 = chatbot_4.predict(input=txt)
|
| 132 |
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
|
| 139 |
state["cnt"] += 1
|
| 140 |
|
| 141 |
new_state_md = f"Inputs remaining in HIT: {state['cnt']}/{TOTAL_CNT}"
|
| 142 |
|
| 143 |
-
state["data"].append({"cnt": state["cnt"], "text": txt, "response_1": response_1, "response_2": response_2, "response_3": response_3, "response_4": response_4,"
|
| 144 |
state["past_user_inputs"].append(txt)
|
| 145 |
|
| 146 |
past_conversation_string = "<br />".join(["<br />".join(["π: " + user_input, "π€: " + model_response]) for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"] + [""])])
|
|
@@ -150,7 +157,7 @@ with demo:
|
|
| 150 |
done = state["cnt"] == TOTAL_CNT
|
| 151 |
state["generated_responses"].append(selected_response)
|
| 152 |
state["data"][-1]["selected_response"] = selected_response
|
| 153 |
-
state["data"][-1]["selected_model"] = state["data"][-1]["
|
| 154 |
if state["cnt"] == TOTAL_CNT:
|
| 155 |
# Write the HIT data to our local dataset because the worker has
|
| 156 |
# submitted everything now.
|
|
@@ -172,6 +179,21 @@ with demo:
|
|
| 172 |
else:
|
| 173 |
toggle_final_submit_preview = gr.update(visible=done)
|
| 174 |
toggle_final_submit = gr.update(visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
text_input = gr.update(visible=False) if done else gr.update(visible=True)
|
| 176 |
return gr.update(visible=False), gr.update(visible=True), text_input, gr.update(visible=False), state, gr.update(value=past_conversation_string), toggle_example_submit, toggle_final_submit, toggle_final_submit_preview,
|
| 177 |
|
|
|
|
| 98 |
memory=ConversationBufferMemory(ai_prefix="Assistant"),
|
| 99 |
)
|
| 100 |
|
| 101 |
+
model_id2model = {
|
| 102 |
+
"google/flan-t5-xl": chatbot_1,
|
| 103 |
+
"bigscience/bloom": chatbot_2,
|
| 104 |
+
"bigscience/T0_3B": chatbot_3,
|
| 105 |
+
"EleutherAI/gpt-j-6B": chatbot_4
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
demo = gr.Blocks()
|
| 109 |
|
| 110 |
with demo:
|
|
|
|
| 137 |
response_3 = chatbot_3.predict(input=txt)
|
| 138 |
response_4 = chatbot_4.predict(input=txt)
|
| 139 |
|
| 140 |
+
response2model_id = {}
|
| 141 |
+
response2model_id[response_1] = chatbot_1.llm.repo_id
|
| 142 |
+
response2model_id[response_2] = chatbot_2.llm.repo_id
|
| 143 |
+
response2model_id[response_3] = chatbot_3.llm.repo_id
|
| 144 |
+
response2model_id[response_4] = chatbot_4.llm.repo_id
|
| 145 |
|
| 146 |
state["cnt"] += 1
|
| 147 |
|
| 148 |
new_state_md = f"Inputs remaining in HIT: {state['cnt']}/{TOTAL_CNT}"
|
| 149 |
|
| 150 |
+
state["data"].append({"cnt": state["cnt"], "text": txt, "response_1": response_1, "response_2": response_2, "response_3": response_3, "response_4": response_4,"response2model_id": response2model_id})
|
| 151 |
state["past_user_inputs"].append(txt)
|
| 152 |
|
| 153 |
past_conversation_string = "<br />".join(["<br />".join(["π: " + user_input, "π€: " + model_response]) for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"] + [""])])
|
|
|
|
| 157 |
done = state["cnt"] == TOTAL_CNT
|
| 158 |
state["generated_responses"].append(selected_response)
|
| 159 |
state["data"][-1]["selected_response"] = selected_response
|
| 160 |
+
state["data"][-1]["selected_model"] = state["data"][-1]["response2model_id"][selected_response]
|
| 161 |
if state["cnt"] == TOTAL_CNT:
|
| 162 |
# Write the HIT data to our local dataset because the worker has
|
| 163 |
# submitted everything now.
|
|
|
|
| 179 |
else:
|
| 180 |
toggle_final_submit_preview = gr.update(visible=done)
|
| 181 |
toggle_final_submit = gr.update(visible=False)
|
| 182 |
+
|
| 183 |
+
if done:
|
| 184 |
+
# Wipe the memory completely because we will be starting a new hit soon.
|
| 185 |
+
chatbot_1.memory = ConversationBufferMemory(ai_prefix="Assistant")
|
| 186 |
+
chatbot_2.memory = ConversationBufferMemory(ai_prefix="Assistant")
|
| 187 |
+
chatbot_3.memory = ConversationBufferMemory(ai_prefix="Assistant")
|
| 188 |
+
chatbot_4.memory = ConversationBufferMemory(ai_prefix="Assistant")
|
| 189 |
+
else:
|
| 190 |
+
# Sync all of the model's memories with the conversation path that
|
| 191 |
+
# was actually taken.
|
| 192 |
+
chatbot_1.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
|
| 193 |
+
chatbot_2.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
|
| 194 |
+
chatbot_3.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
|
| 195 |
+
chatbot_4.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
|
| 196 |
+
|
| 197 |
text_input = gr.update(visible=False) if done else gr.update(visible=True)
|
| 198 |
return gr.update(visible=False), gr.update(visible=True), text_input, gr.update(visible=False), state, gr.update(value=past_conversation_string), toggle_example_submit, toggle_final_submit, toggle_final_submit_preview,
|
| 199 |
|