Commit 17620ad

Merge branch 'SakanaAI:main' into main
2 parents: dccbbe3 + a369276

File tree (3 files changed, +83 -10 lines): README.md, ai_scientist/llm.py, launch_scientist.py

README.md

Lines changed: 8 additions & 2 deletions
````diff
@@ -118,14 +118,20 @@ export VERTEXAI_LOCATION="REGION" # for Aider/LiteLLM call
 export VERTEXAI_PROJECT="PROJECT_ID" # for Aider/LiteLLM call
 ```
 
-#### DeepSeek API (DeepSeek-Coder-V2)
-
+#### DeepSeek API (deepseek-chat, deepseek-reasoner)
 By default, this uses the `DEEPSEEK_API_KEY` environment variable.
 
 #### OpenRouter API (Llama3.1)
 
 By default, this uses the `OPENROUTER_API_KEY` environment variable.
 
+#### Google Gemini
+We support Google Gemini models (e.g., "gemini-1.5-flash", "gemini-1.5-pro") via the [google-generativeai](https://pypi.org/project/google-generativeai) Python library. By default, it uses the environment variable:
+
+```bash
+export GEMINI_API_KEY="YOUR GEMINI API KEY"
+```
+
 #### Semantic Scholar API (Literature Search)
 
 Our code can also optionally use a Semantic Scholar API Key (`S2_API_KEY`) for higher throughput [if you have one](https://www.semanticscholar.org/product/api), though it should work without it in principle. If you have problems with Semantic Scholar, you can skip the literature search and citation phases of paper generation.
````
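As a quick sanity check that the key is visible to the [google-generativeai](https://pypi.org/project/google-generativeai) library, a minimal standalone sketch (not part of the repository; "gemini-1.5-flash" is just one of the two example model names in the new README section):

```python
import os

import google.generativeai as genai

# Assumes GEMINI_API_KEY was exported as shown in the README snippet above.
genai.configure(api_key=os.environ["GEMINI_API_KEY"])

# "gemini-1.5-flash" is one of the example model names from the README.
model = genai.GenerativeModel("gemini-1.5-flash")
response = model.generate_content("Reply with the single word: ready")
print(response.text)
```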

ai_scientist/llm.py

Lines changed: 50 additions & 8 deletions
```diff
@@ -5,19 +5,23 @@
 import anthropic
 import backoff
 import openai
+import google.generativeai as genai
+from google.generativeai.types import GenerationConfig
 
 MAX_NUM_TOKENS = 4096
 
 AVAILABLE_LLMS = [
+    # Anthropic models
     "claude-3-5-sonnet-20240620",
     "claude-3-5-sonnet-20241022",
+    # OpenAI models
     "gpt-4o-mini-2024-07-18",
     "gpt-4o-2024-05-13",
     "gpt-4o-2024-08-06",
     "o1-preview-2024-09-12",
     "o1-mini-2024-09-12",
     "o1-2024-12-17",
-    "deepseek-coder-v2-0724",
+    # OpenRouter models
     "llama3.1-405b",
     # Anthropic Claude models via Amazon Bedrock
     "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
@@ -31,6 +35,13 @@
     "vertex_ai/claude-3-5-sonnet-v2@20241022",
     "vertex_ai/claude-3-sonnet@20240229",
     "vertex_ai/claude-3-haiku@20240307",
+    # DeepSeek models
+    "deepseek-chat",
+    "deepseek-coder",
+    "deepseek-reasoner",
+    # Google Gemini models
+    "gemini-1.5-flash",
+    "gemini-1.5-pro",
 ]
 
 
```
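The two hunks above register the new model names. `AVAILABLE_LLMS` is the registry the rest of the code keys off, so a caller can reject unsupported names before any client is built; a hypothetical guard (a sketch, not code from the repository):

```python
from ai_scientist.llm import AVAILABLE_LLMS

def validate_model(name: str) -> str:
    """Fail fast with a readable message instead of failing deeper inside the pipeline."""
    if name not in AVAILABLE_LLMS:
        raise ValueError(f"Unknown model '{name}'. Supported: {', '.join(AVAILABLE_LLMS)}")
    return name

validate_model("deepseek-reasoner")  # added by this commit
validate_model("gemini-1.5-pro")     # added by this commit
```

The remaining hunks of `ai_scientist/llm.py` continue below.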

```diff
@@ -104,7 +115,6 @@ def get_batch_responses_from_llm(
         new_msg_history.append(hist)
 
     if print_debug:
-        # Just print the first one.
         print()
         print("*" * 20 + " LLM START " + "*" * 20)
         for j, msg in enumerate(new_msg_history[0]):
@@ -191,15 +201,14 @@ def get_response_from_llm(
             temperature=1,
             max_completion_tokens=MAX_NUM_TOKENS,
             n=1,
-            #stop=None,
             seed=0,
         )
         content = response.choices[0].message.content
         new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
-    elif model == "deepseek-coder-v2-0724":
+    elif model in ["meta-llama/llama-3.1-405b-instruct", "llama-3-1-405b-instruct"]:
         new_msg_history = msg_history + [{"role": "user", "content": msg}]
         response = client.chat.completions.create(
-            model="deepseek-coder",
+            model="meta-llama/llama-3.1-405b-instruct",
             messages=[
                 {"role": "system", "content": system_message},
                 *new_msg_history,
@@ -211,10 +220,10 @@ def get_response_from_llm(
         )
         content = response.choices[0].message.content
         new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
-    elif model in ["meta-llama/llama-3.1-405b-instruct", "llama-3-1-405b-instruct"]:
+    elif model in ["deepseek-chat", "deepseek-coder"]:
         new_msg_history = msg_history + [{"role": "user", "content": msg}]
         response = client.chat.completions.create(
-            model="meta-llama/llama-3.1-405b-instruct",
+            model=model,
             messages=[
                 {"role": "system", "content": system_message},
                 *new_msg_history,
@@ -226,6 +235,34 @@ def get_response_from_llm(
         )
         content = response.choices[0].message.content
         new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
+    elif model in ["deepseek-reasoner"]:
+        new_msg_history = msg_history + [{"role": "user", "content": msg}]
+        response = client.chat.completions.create(
+            model=model,
+            messages=[
+                {"role": "system", "content": system_message},
+                *new_msg_history,
+            ],
+            n=1,
+            stop=None,
+        )
+        content = response.choices[0].message.content
+        new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
+    elif "gemini" in model:
+        new_msg_history = msg_history + [{"role": "user", "content": msg}]
+        gemini_contents = [{"role": "system", "parts": system_message}]
+        for m in new_msg_history:
+            gemini_contents.append({"role": m["role"], "parts": m["content"]})
+        response = client.generate_content(
+            contents=gemini_contents,
+            generation_config=GenerationConfig(
+                temperature=temperature,
+                max_output_tokens=MAX_NUM_TOKENS,
+                candidate_count=1,
+            ),
+        )
+        content = response.text
+        new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
     else:
         raise ValueError(f"Model {model} not supported.")
 
@@ -287,7 +324,7 @@ def create_client(model):
     elif model in ["o1-preview-2024-09-12", "o1-mini-2024-09-12"]:
         print(f"Using OpenAI API with model {model}.")
         return openai.OpenAI(), model
-    elif model == "deepseek-coder-v2-0724":
+    elif model in ["deepseek-chat", "deepseek-reasoner"]:
         print(f"Using OpenAI API with {model}.")
         return openai.OpenAI(
             api_key=os.environ["DEEPSEEK_API_KEY"],
@@ -299,5 +336,10 @@ def create_client(model):
             api_key=os.environ["OPENROUTER_API_KEY"],
             base_url="https://openrouter.ai/api/v1"
         ), "meta-llama/llama-3.1-405b-instruct"
+    elif "gemini" in model:
+        print(f"Using Google Generative AI with model {model}.")
+        genai.configure(api_key=os.environ["GEMINI_API_KEY"])
+        client = genai.GenerativeModel(model)
+        return client, model
     else:
         raise ValueError(f"Model {model} not supported.")
```
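For reference, a minimal sketch of how the two new `create_client` branches are reached (a sketch, not code from the commit; it assumes `DEEPSEEK_API_KEY` and `GEMINI_API_KEY` are set, and the model names are the ones added to `AVAILABLE_LLMS` above):

```python
from ai_scientist.llm import create_client

# DeepSeek: an OpenAI-compatible client configured with DEEPSEEK_API_KEY.
ds_client, ds_model = create_client("deepseek-chat")

# Gemini: a google.generativeai.GenerativeModel configured with GEMINI_API_KEY.
gm_client, gm_model = create_client("gemini-1.5-flash")

print(type(ds_client).__name__, ds_model)  # expected: OpenAI deepseek-chat
print(type(gm_client).__name__, gm_model)  # expected: GenerativeModel gemini-1.5-flash
```

Either client can then be passed to `get_response_from_llm` together with the matching model name, which routes to the new `elif` branches shown above.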

launch_scientist.py

Lines changed: 25 additions & 0 deletions
```diff
@@ -98,6 +98,27 @@ def get_available_gpus(gpu_ids=None):
     return list(range(torch.cuda.device_count()))
 
 
+def check_latex_dependencies():
+    """
+    Check if required LaTeX dependencies are installed on the system.
+    Returns True if all dependencies are found, False otherwise.
+    """
+    import shutil
+    import sys
+
+    required_dependencies = ['pdflatex', 'chktex']
+    missing_deps = []
+
+    for dep in required_dependencies:
+        if shutil.which(dep) is None:
+            missing_deps.append(dep)
+
+    if missing_deps:
+        print("Error: Required LaTeX dependencies not found:", file=sys.stderr)
+        return False
+
+    return True
+
 def worker(
     queue,
     base_dir,
@@ -304,6 +325,10 @@ def do_idea(
 
     print(f"Using GPUs: {available_gpus}")
 
+    # Check LaTeX dependencies before proceeding
+    if args.writeup == "latex" and not check_latex_dependencies():
+        sys.exit(1)
+
     # Create client
     client, client_model = create_client(args.model)
 
```
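The new check prints a generic error message; to see which specific tools are missing on a given machine, the same standard-library call can be run standalone (a sketch, not part of the commit):

```python
import shutil

# Mirrors check_latex_dependencies(): both tools must be on PATH when args.writeup == "latex".
for tool in ("pdflatex", "chktex"):
    path = shutil.which(tool)
    print(f"{tool}: {path or 'NOT FOUND'}")
```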
