tenet commited on
Commit
b6b5359
·
verified ·
1 Parent(s): 4e40cb6

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +4 -18
main.py CHANGED
@@ -1,8 +1,7 @@
1
- import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import math
4
  import random
5
  import numpy as np
 
6
 
7
  # Define the Node class for MCTS
8
  class Node:
@@ -111,7 +110,7 @@ class GameState:
111
 
112
  # Initialize the RWKV model and tokenizer
113
  model_name = "BlinkDL/rwkv-4-raven"
114
- tokenizer = AutoTokenizer.from_pretrained(model_name)
115
  model = AutoModelForCausalLM.from_pretrained(model_name)
116
 
117
  # Generate Chain-of-Thought
@@ -129,21 +128,8 @@ def mcts_with_cot(initial_state):
129
  cot = generate_cot(best_state)
130
  return best_state, cot
131
 
132
- # Gradio Interface
133
  def run_mcts_cot(initial_board):
134
  initial_state = GameState(initial_board, 1)
135
  best_state, cot = mcts_with_cot(initial_state)
136
- return str(best_state), cot
137
-
138
- # Create the Gradio interface
139
- initial_board = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
140
- iface = gr.Interface(
141
- fn=run_mcts_cot,
142
- inputs=gr.inputs.JSON(),
143
- outputs=["text", "text"],
144
- title="RWKV CoT Demo for MCTS",
145
- description="This demo uses RWKV to generate Chain-of-Thought reasoning to guide the MCTS algorithm in a Tic-Tac-Toe game."
146
- )
147
-
148
- # Launch the interface
149
- iface.launch()
 
 
 
1
  import math
2
  import random
3
  import numpy as np
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
  # Define the Node class for MCTS
7
  class Node:
 
110
 
111
  # Initialize the RWKV model and tokenizer
112
  model_name = "BlinkDL/rwkv-4-raven"
113
+ tokenizer = AutoModelForCausalLM.from_pretrained(model_name)
114
  model = AutoModelForCausalLM.from_pretrained(model_name)
115
 
116
  # Generate Chain-of-Thought
 
128
  cot = generate_cot(best_state)
129
  return best_state, cot
130
 
131
+ # Function to be called by Gradio
132
  def run_mcts_cot(initial_board):
133
  initial_state = GameState(initial_board, 1)
134
  best_state, cot = mcts_with_cot(initial_state)
135
+ return str(best_state), cot