AlexFocus committed on
Commit
3b2a4e6
Β·
1 Parent(s): 1e835c9

image generation added

Browse files
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  openai==2.8.0
2
- gradio[mcp]>=5.49.0
 
 
1
  openai==2.8.0
2
+ gradio[mcp]>=5.49.0
3
+ huggingface-hub>=0.20.0
src/learnbee/constants.py CHANGED
@@ -39,15 +39,8 @@ LANGUAGES = [
39
  "French",
40
  "German",
41
  "Italian",
42
- "Portuguese",
43
  "Chinese",
44
  "Japanese",
45
- "Korean",
46
- "Arabic",
47
- "Russian",
48
- "Dutch",
49
- "Polish",
50
- "Turkish",
51
  "Hindi"
52
  ]
53
 
 
39
  "French",
40
  "German",
41
  "Italian",
 
42
  "Chinese",
43
  "Japanese",
 
 
 
 
 
 
44
  "Hindi"
45
  ]
46
 
src/learnbee/image_generator.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Image generation module using Hugging Face Inference API."""
2
+
3
+ import os
4
+ from typing import Optional
5
+ from dotenv import load_dotenv
6
+ from huggingface_hub import InferenceClient
7
+
8
+ load_dotenv()
9
+
10
+
11
class ImageGenerator:
    """Generate child-friendly educational images using the Hugging Face Inference API."""

    def __init__(self, model: str = "black-forest-labs/FLUX.1-schnell"):
        """
        Initialize the image generator.

        Args:
            model: Hugging Face model ID for text-to-image generation.

        Raises:
            ValueError: If the HF_TOKEN environment variable is not set.
        """
        self.api_token = os.getenv("HF_TOKEN")
        self.model = model

        if not self.api_token:
            raise ValueError("HF_TOKEN not found in environment variables")

        # Initialize Hugging Face Inference Client
        self.client = InferenceClient(token=self.api_token)

    def generate_image(self, prompt: str) -> Optional[bytes]:
        """
        Generate an image from a text prompt.

        Args:
            prompt: Text description of the image to generate.

        Returns:
            PNG image bytes if successful, None otherwise (errors are
            logged to stdout rather than raised, so callers can degrade
            gracefully).
        """
        # Enhance prompt for child-friendly, educational content
        enhanced_prompt = self._enhance_prompt_for_children(prompt)

        try:
            # text_to_image returns a PIL Image object
            image = self.client.text_to_image(
                enhanced_prompt,
                model=self.model
            )

            # Convert PIL Image to PNG bytes in memory
            from io import BytesIO
            buffer = BytesIO()
            image.save(buffer, format='PNG')
            return buffer.getvalue()

        except Exception as e:
            # Best-effort: report the failure and signal it via None
            print(f"Error generating image: {str(e)}")
            return None

    def _enhance_prompt_for_children(self, prompt: str) -> str:
        """
        Enhance the prompt to ensure child-friendly, educational images.

        Args:
            prompt: Original prompt.

        Returns:
            Prompt with style modifiers appended.
        """
        # Style modifiers steering the model toward kid-safe output
        enhancements = [
            "child-friendly",
            "colorful",
            "educational illustration",
            "cartoon style",
            "bright and cheerful"
        ]

        # Combine original prompt with enhancements
        return f"{prompt}, {', '.join(enhancements)}"

    def detect_image_request(self, message: str) -> Optional[str]:
        """
        Detect if a message contains an image request and extract the subject.

        Args:
            message: User's message (English or Spanish).

        Returns:
            Lowercased subject to generate an image for, or None if no
            request is detected.
        """
        message_lower = message.lower()

        # Keywords (English and Spanish) that indicate an image request
        image_keywords = [
            "show me", "muéstrame", "muestra",
            "draw", "dibuja", "dibujar",
            "picture of", "imagen de", "foto de",
            "what does", "cómo es", "como es",
            "i want to see", "quiero ver",
            "can you show", "puedes mostrar"
        ]

        # English and Spanish articles to drop from the start of the subject.
        # NOTE: only *leading* articles are stripped; the previous blanket
        # str.replace("a ", "") mangled words (e.g. "panda bear" -> "pandbear").
        articles = {"a", "an", "the", "un", "una", "el", "la"}

        for keyword in image_keywords:
            if keyword in message_lower:
                # Remove the keyword; the remainder is the subject
                # (simplified extraction - could be improved with NLP)
                words = message_lower.replace(keyword, "").strip().split()

                # Drop leading articles only
                while words and words[0] in articles:
                    words.pop(0)

                subject = " ".join(words)
                if subject:
                    return subject

        return None
src/learnbee/prompts.py CHANGED
@@ -76,6 +76,13 @@ def generate_tutor_system_prompt(
76
  "- Be warm, enthusiastic, and patient. Show excitement about problem-solving!\n"
77
  "- Use the child's name when possible (refer to them as 'you' or 'little learner').\n\n"
78
 
 
 
 
 
 
 
 
79
  "TEACHING STRATEGIES BY DIFFICULTY LEVEL:\n"
80
  f"- {difficulty_level.upper()} level:\n"
81
  f"{difficulty_instruction}\n"
 
76
  "- Be warm, enthusiastic, and patient. Show excitement about problem-solving!\n"
77
  "- Use the child's name when possible (refer to them as 'you' or 'little learner').\n\n"
78
 
79
+ "IMAGE GENERATION CAPABILITY:\n"
80
+ "- You can suggest the child ask to see images of things they're learning about!\n"
81
+ "- When discussing visual concepts (animals, objects, places), encourage them: 'Would you like to see what a [subject] looks like?'\n"
82
+ "- The child can request images by saying things like 'show me a dinosaur' or 'I want to see a rocket'\n"
83
+ "- Images help visual learners understand concepts better - use this feature to enhance learning!\n"
84
+ "- After an image is shown, ask questions about what they see in the image.\n\n"
85
+
86
  "TEACHING STRATEGIES BY DIFFICULTY LEVEL:\n"
87
  f"- {difficulty_level.upper()} level:\n"
88
  f"{difficulty_instruction}\n"
src/learnbee/tutor_handlers.py CHANGED
@@ -11,6 +11,9 @@ from learnbee.prompts import generate_tutor_system_prompt
11
 
12
  from learnbee.session_state import SessionState
13
  from learnbee.gamification import GamificationTracker
 
 
 
14
 
15
 
16
  def load_lesson_content(lesson_name, selected_tutor, selected_language, progress=gr.Progress()):
@@ -239,7 +242,7 @@ def custom_respond(
239
  session_state, gamification_tracker
240
  ):
241
  """
242
- Custom respond function with educational system prompt, adaptive personalization, and gamification.
243
 
244
  Args:
245
  message: User's message
@@ -260,6 +263,43 @@ def custom_respond(
260
 
261
  if not lesson_content:
262
  lesson_content = get_lesson_content(lesson_name, LESSON_CONTENT_MAX_LENGTH)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
  # Analyze the child's message for adaptive personalization
265
  message_analysis = session_state.analyze_message(message)
@@ -291,12 +331,25 @@ def custom_respond(
291
  adaptive_context=adaptive_context
292
  )
293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  # Call the respond method with educational system prompt
295
  call_llm = LLMCall()
296
  response_text = ""
297
  for response in call_llm.respond(
298
  message,
299
- history,
300
  system_prompt=system_prompt,
301
  tutor_name=selected_tutor,
302
  difficulty_level=difficulty_level
 
11
 
12
  from learnbee.session_state import SessionState
13
  from learnbee.gamification import GamificationTracker
14
+ from learnbee.image_generator import ImageGenerator
15
+ import os
16
+ import time
17
 
18
 
19
  def load_lesson_content(lesson_name, selected_tutor, selected_language, progress=gr.Progress()):
 
242
  session_state, gamification_tracker
243
  ):
244
  """
245
+ Custom respond function with educational system prompt, adaptive personalization, gamification, and image generation.
246
 
247
  Args:
248
  message: User's message
 
263
 
264
  if not lesson_content:
265
  lesson_content = get_lesson_content(lesson_name, LESSON_CONTENT_MAX_LENGTH)
266
+
267
+ # Check if message contains an image request
268
+ image_gen = ImageGenerator()
269
+ image_subject = image_gen.detect_image_request(message)
270
+
271
+ if image_subject:
272
+ # Generate image
273
+ yield "🎨 Generating image of {}... Please wait!".format(image_subject), gamification_tracker.get_progress_html()
274
+
275
+ image_bytes = image_gen.generate_image(image_subject)
276
+
277
+ if image_bytes:
278
+ # Save image to generated_images directory with absolute path
279
+ images_dir = os.path.abspath("./generated_images")
280
+ os.makedirs(images_dir, exist_ok=True)
281
+ timestamp = int(time.time())
282
+ image_filename = f"image_{timestamp}.png"
283
+ image_path = os.path.join(images_dir, image_filename)
284
+
285
+ with open(image_path, "wb") as f:
286
+ f.write(image_bytes)
287
+
288
+ # Award star for creative request
289
+ gamification_tracker.award_star("Creative image request!")
290
+ session_state.total_messages += 1
291
+
292
+ # Return message with gr.Image component for proper display
293
+ # First yield the text message
294
+ yield f"Here's your image of {image_subject}! 🎨", gamification_tracker.get_progress_html()
295
+
296
+ # Then yield the image as a separate message using gr.Image
297
+ import gradio as gr
298
+ yield gr.Image(value=image_path, label=image_subject, show_label=False, height=400), gamification_tracker.get_progress_html()
299
+ return
300
+ else:
301
+ yield f"Sorry, I couldn't generate the image right now. Let's continue learning! 😊", gamification_tracker.get_progress_html()
302
+ return
303
 
304
  # Analyze the child's message for adaptive personalization
305
  message_analysis = session_state.analyze_message(message)
 
331
  adaptive_context=adaptive_context
332
  )
333
 
334
+ # Filter history to remove gr.Image components (they can't be sent to LLM)
335
+ # Keep only text messages for LLM context
336
+ filtered_history = []
337
+ for msg in history:
338
+ if isinstance(msg, dict):
339
+ # Check if content is a string (text message)
340
+ if isinstance(msg.get("content"), str):
341
+ filtered_history.append(msg)
342
+ # Skip messages with gr.Image or other components
343
+ elif isinstance(msg, (list, tuple)) and len(msg) == 2:
344
+ # Old tuple format - keep it
345
+ filtered_history.append(msg)
346
+
347
  # Call the respond method with educational system prompt
348
  call_llm = LLMCall()
349
  response_text = ""
350
  for response in call_llm.respond(
351
  message,
352
+ filtered_history, # Use filtered history without images
353
  system_prompt=system_prompt,
354
  tutor_name=selected_tutor,
355
  difficulty_level=difficulty_level