Anirudh Esthuri committed on
Commit 16ab50a · 1 Parent(s): 3a73f5d

Switch Gemini models to use Google API directly instead of AWS Bedrock

Files changed (3):
  1. llm.py +91 -80
  2. model_config.py +4 -13
  3. requirements.txt +1 -0
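
The switch replaces bedrock_runtime.invoke_model calls with the google-generativeai SDK. For orientation, a minimal sketch of the direct call pattern the new code adopts (the model ID and prompt here are illustrative, and GOOGLE_API_KEY must be set in the environment):

    import os
    import google.generativeai as genai

    # Configure the SDK with the key the new code reads from the environment.
    genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

    # Single-turn generation against one of the newly listed model IDs.
    model = genai.GenerativeModel("gemini-2.5-flash")
    response = model.generate_content("Say hello in one sentence.")
    print(response.text)
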
llm.py CHANGED
@@ -8,6 +8,39 @@ import requests
 from dotenv import load_dotenv
 from model_config import MODEL_TO_PROVIDER, MODEL_TO_INFERENCE_PROFILE_ARN
 
+# Lazy initialization of Google Gemini client
+_google_client = None
+
+def get_google_client():
+    """Get or create the Google Gemini client with proper error handling."""
+    global _google_client
+    if _google_client is None:
+        try:
+            import google.generativeai as genai
+        except ImportError:
+            raise ValueError(
+                "google-generativeai package not installed. "
+                "Please add 'google-generativeai' to requirements.txt"
+            )
+
+        google_api_key = os.getenv("GOOGLE_API_KEY", "").strip()
+        if not google_api_key:
+            raise ValueError(
+                "Google API key not found. Please set GOOGLE_API_KEY "
+                "as a secret in Hugging Face Spaces settings."
+            )
+
+        try:
+            genai.configure(api_key=google_api_key)
+            _google_client = genai
+        except Exception as e:
+            raise ValueError(
+                f"Failed to initialize Google Gemini client: {str(e)}. "
+                "Please verify your GOOGLE_API_KEY is correct."
+            ) from e
+
+    return _google_client
+
 # ──────────────────────────────────────────────────────────────
 # Load environment variables
 load_dotenv()
@@ -177,101 +210,64 @@ def chat(messages, persona):
         print("Using google (Gemini): ", MODEL_STRING)
         t0 = time.time()
 
-        # Add system prompt for better behavior
-        system_prompt = ""
-
-        # Convert messages to Gemini format
-        # Gemini uses "user" and "model" roles, and content is an array
-        gemini_messages = []
-        for msg in messages:
-            role = msg.get("role", "user")
-            # Gemini uses "model" instead of "assistant"
-            if role == "assistant":
-                role = "model"
-            gemini_messages.append({
-                "role": role,
-                "parts": [{"text": msg["content"]}]
-            })
-
         try:
-            bedrock_runtime = get_bedrock_client()
+            genai = get_google_client()
 
-            # Use inference profile ARN if available (for provisioned throughput models)
-            # Otherwise use modelId (for on-demand models)
-            invoke_kwargs = {
-                "contentType": "application/json",
-                "accept": "application/json",
-                "body": json.dumps(
-                    {
-                        "contents": gemini_messages,
-                        "generationConfig": {
-                            "maxOutputTokens": 4000,
-                            "temperature": 0.3,
-                        }
-                    }
-                ),
-            }
+            # Get the model
+            model = genai.GenerativeModel(MODEL_STRING)
 
-            # Add system instruction if provided
-            if system_prompt:
-                invoke_kwargs["body"] = json.dumps(
-                    {
-                        "contents": gemini_messages,
-                        "systemInstruction": {
-                            "parts": [{"text": system_prompt}]
-                        },
-                        "generationConfig": {
-                            "maxOutputTokens": 4000,
-                            "temperature": 0.3,
-                        }
-                    }
-                )
+            # Convert messages to Gemini format
+            # Gemini API expects a chat history format
+            chat_history = []
+            for msg in messages:
+                role = msg.get("role", "user")
+                content = msg.get("content", "")
+                # Gemini uses "model" instead of "assistant"
+                if role == "assistant":
+                    role = "model"
+                chat_history.append({
+                    "role": role,
+                    "parts": [content]
+                })
 
-            # Check if this model has an inference profile ARN (provisioned throughput)
-            # For provisioned throughput, use the ARN as the modelId
-            if MODEL_STRING in MODEL_TO_INFERENCE_PROFILE_ARN:
-                invoke_kwargs["modelId"] = MODEL_TO_INFERENCE_PROFILE_ARN[MODEL_STRING]
-            else:
-                invoke_kwargs["modelId"] = MODEL_STRING
+            # Start a chat session with history
+            chat = model.start_chat(history=chat_history[:-1] if len(chat_history) > 1 else [])
 
-            response = bedrock_runtime.invoke_model(**invoke_kwargs)
+            # Send the last message
+            last_message = chat_history[-1]["parts"][0] if chat_history else ""
+            response = chat.send_message(
+                last_message,
+                generation_config=genai.types.GenerationConfig(
+                    max_output_tokens=4000,
+                    temperature=0.3,
+                )
+            )
 
             dt = time.time() - t0
-            body = json.loads(response["body"].read())
+            text = response.text.strip()
+
+            # Calculate tokens (approximate)
+            total_tok = len(text.split())
+
+            return text, dt, total_tok, (total_tok / dt if dt else total_tok)
         except ValueError as e:
             # Re-raise ValueError (credential errors) as-is
             raise
         except Exception as e:
             error_msg = str(e)
-            if "ValidationException" in error_msg and "model identifier is invalid" in error_msg:
+            if "API key" in error_msg or "invalid" in error_msg.lower() or "401" in error_msg or "403" in error_msg:
                 raise ValueError(
-                    f"Invalid Bedrock model ID: '{MODEL_STRING}'. "
-                    f"Error: {error_msg}. "
-                    "Please verify the model ID is correct and the model is available in your AWS region. "
-                    "Common Gemini model IDs: 'google.gemini-pro-v1' or 'google.gemini-2.0-flash-exp'"
+                    f"Google API authentication failed: {error_msg}. "
+                    "Please verify your GOOGLE_API_KEY secret is correct and has Gemini API access."
                 ) from e
-            elif "UnrecognizedClientException" in error_msg or "invalid" in error_msg.lower():
+            elif "not found" in error_msg.lower() or "404" in error_msg:
                 raise ValueError(
-                    f"AWS Bedrock authentication failed: {error_msg}. "
-                    "Please verify your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY secrets "
-                    "are correct and have Bedrock access permissions."
+                    f"Invalid Gemini model ID: '{MODEL_STRING}'. "
+                    f"Error: {error_msg}. "
+                    "Please verify the model ID is correct. "
+                    "Common Gemini model IDs: 'gemini-3.0-pro', 'gemini-2.5-flash', 'gemini-1.5-pro', 'gemini-1.5-flash'"
                 ) from e
             raise
-
-        # Extract text from Gemini response
-        # Gemini response format: {"candidates": [{"content": {"parts": [{"text": "..."}]}}]}
-        text = ""
-        if "candidates" in body and len(body["candidates"]) > 0:
-            candidate = body["candidates"][0]
-            if "content" in candidate and "parts" in candidate["content"]:
-                for part in candidate["content"]["parts"]:
-                    if "text" in part:
-                        text += part["text"]
-
-        text = text.strip()
-        total_tok = len(text.split())
-
-        return text, dt, total_tok, (total_tok / dt if dt else total_tok)
     elif provider == "deepseek":
         print("Using deepseek: ", MODEL_STRING)
         t0 = time.time()
@@ -477,8 +473,8 @@ def check_credentials():
     # print(f"Ollama connection failed: {e}")
     # return False
 
-    # Check if using Bedrock providers (anthropic, google, meta, mistral, deepseek)
-    bedrock_providers = ["anthropic", "google"]
+    # Check if using Bedrock providers (anthropic, meta, mistral, deepseek)
+    bedrock_providers = ["anthropic"]
    if MODEL_TO_PROVIDER.get(MODEL_STRING) in bedrock_providers:
        # Test AWS Bedrock connection by trying to invoke a simple model
        try:
@@ -519,6 +515,21 @@ def check_credentials():
            return False
        return True
 
+    # For Google Gemini, check API key
+    if MODEL_TO_PROVIDER.get(MODEL_STRING) == "google":
+        required = ["GOOGLE_API_KEY"]
+        missing = [var for var in required if not os.getenv(var)]
+        if missing:
+            print(f"Missing environment variables: {missing}")
+            return False
+        # Try to initialize the client to verify the key works
+        try:
+            get_google_client()
+            return True
+        except Exception as e:
+            print(f"Google API client initialization failed: {e}")
+            return False
+
    return True
 
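The multi-turn flow in chat() above (start_chat with the prior history, then send_message for the final turn) can be exercised on its own. A sketch under the same assumptions as before (GOOGLE_API_KEY set; the model ID and prompts are illustrative):

    import os
    import google.generativeai as genai

    genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
    model = genai.GenerativeModel("gemini-2.5-flash")

    # Prior turns use the "user"/"model" roles that chat() converts to.
    history = [
        {"role": "user", "parts": ["What is 2 + 2?"]},
        {"role": "model", "parts": ["2 + 2 = 4."]},
    ]
    chat = model.start_chat(history=history)

    # The final message is sent separately, mirroring chat_history[:-1] above.
    response = chat.send_message(
        "Now double it.",
        generation_config=genai.types.GenerationConfig(
            max_output_tokens=4000,
            temperature=0.3,
        ),
    )
    print(response.text)
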
model_config.py CHANGED
@@ -11,8 +11,8 @@ PROVIDER_MODEL_MAP = {
         "anthropic.claude-opus-4-20250514-v1:0",
     ],
     "google": [
-        "google.gemini-3.0-pro-v1:0",
-        "google.gemini-2.5-flash-v1:0",
+        "gemini-3.0-pro",
+        "gemini-2.5-flash",
     ],
 }
 
@@ -32,8 +32,8 @@ MODEL_DISPLAY_NAMES = {
     "anthropic.claude-haiku-4-5-20251001-v1:0": "AWS Bedrock - Anthropic - Claude Haiku 4.5",
     "anthropic.claude-sonnet-4-5-20250929-v1:0": "AWS Bedrock - Anthropic - Claude Sonnet 4.5",
     "anthropic.claude-opus-4-20250514-v1:0": "AWS Bedrock - Anthropic - Claude Opus 4",
-    "google.gemini-3.0-pro-v1:0": "AWS Bedrock - Google - Gemini 3.0 Pro",
-    "google.gemini-2.5-flash-v1:0": "AWS Bedrock - Google - Gemini 2.5 Flash",
+    "gemini-3.0-pro": "Google - Gemini 3.0 Pro",
+    "gemini-2.5-flash": "Google - Gemini 2.5 Flash",
 }
 
 MODEL_CHOICES = [model for models in PROVIDER_MODEL_MAP.values() for model in models]
@@ -58,12 +58,3 @@ opus_arn = os.getenv("BEDROCK_OPUS_4_ARN", "").strip()
 if opus_arn:
     MODEL_TO_INFERENCE_PROFILE_ARN["anthropic.claude-opus-4-20250514-v1:0"] = opus_arn
 
-# Gemini 3.0 Pro
-gemini_3_arn = os.getenv("BEDROCK_GEMINI_3_ARN", "").strip()
-if gemini_3_arn:
-    MODEL_TO_INFERENCE_PROFILE_ARN["google.gemini-3.0-pro-v1:0"] = gemini_3_arn
-
-# Gemini 2.5 Flash
-gemini_2_5_arn = os.getenv("BEDROCK_GEMINI_2_5_ARN", "").strip()
-if gemini_2_5_arn:
-    MODEL_TO_INFERENCE_PROFILE_ARN["google.gemini-2.5-flash-v1:0"] = gemini_2_5_arn
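MODEL_TO_PROVIDER itself is not shown in this diff, but since llm.py dispatches on it, a derivation from PROVIDER_MODEL_MAP along these lines would make the new bare model IDs resolve to the "google" branch. This is a hedged sketch, not the file's actual code:

    # Hypothetical inversion of provider -> models into model -> provider.
    PROVIDER_MODEL_MAP = {
        "anthropic": ["anthropic.claude-opus-4-20250514-v1:0"],
        "google": ["gemini-3.0-pro", "gemini-2.5-flash"],
    }
    MODEL_TO_PROVIDER = {
        model: provider
        for provider, models in PROVIDER_MODEL_MAP.items()
        for model in models
    }
    assert MODEL_TO_PROVIDER.get("gemini-2.5-flash") == "google"
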
requirements.txt CHANGED
@@ -11,3 +11,4 @@ tiktoken
 pydantic
 boto3
 huggingface_hub
+google-generativeai
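
A quick way to confirm the new dependency resolves after pip install -r requirements.txt (a sketch; the distribution name comes from the line added above):

    from importlib.metadata import version

    import google.generativeai  # the import get_google_client() relies on
    print("google-generativeai", version("google-generativeai"))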