from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
import torch
import autopep8
import ast
import glob
import re
import os
from huggingface_hub import hf_hub_download


# ==========================
# Utility functions
# ==========================


def normalize_indentation(code):
    """
    Normalize indentation in example code by removing one excess level.

    All backslash characters are removed first. Once the first ``def`` line
    has been seen, every following non-blank line that starts with two tabs
    loses one tab, and every one that starts with eight spaces has them
    replaced by four spaces. Lines before the first ``def``, the ``def``
    lines themselves and blank lines are kept as-is.

    Args:
        code: Source text of an example, with escape sequences already
            decoded.

    Returns:
        The re-indented source text.
    """
    code = code.replace("\\", "")
    # Spelled out with multiplication so the widths are unambiguous.
    two_levels_spaces = " " * 8
    one_level_spaces = " " * 4

    fixed_lines = []
    indent_fix_mode = False
    for line in code.split("\n"):
        if line.strip().startswith("def "):
            fixed_lines.append(line)
            indent_fix_mode = True
        elif indent_fix_mode and line.strip():
            # For indented lines in a function
            if line.startswith("\t\t"):
                # Two tabs -> one tab
                fixed_lines.append("\t" + line[2:])
            elif line.startswith(two_levels_spaces):
                # 8 spaces (2 levels) -> 4 spaces
                fixed_lines.append(one_level_spaces + line[8:])
            else:
                fixed_lines.append(line)
        else:
            fixed_lines.append(line)

    return "\n".join(fixed_lines)


def clear_text(text):
    """
    Decode escape sequences in model output.

    The literal two-character sequences ``\\n`` and ``\\t`` become a real
    newline and tab; every other backslash is then removed.
    """
    temp_newline = "TEMP_NEWLINE_PLACEHOLDER"
    temp_tab = "TEMP_TAB_PLACEHOLDER"

    # Park the escapes we want to keep so the blanket backslash removal
    # below does not destroy them.
    text = text.replace("\\n", temp_newline)
    text = text.replace("\\t", temp_tab)
    text = text.replace("\\", "")
    text = text.replace(temp_newline, "\n")
    text = text.replace(temp_tab, "\t")

    return text


def encode_text(text):
    """
    Escape newlines and tabs as the literal sequences ``\\n`` and ``\\t``.

    This produces the single-line form used in the model input; all other
    characters, including existing backslashes, are left untouched.
    """
    text = text.replace("\n", "\\n")
    text = text.replace("\t", "\\t")
    return text


def format_code(code):
    """
    Format Python code using autopep8 with aggressive settings.

    After autopep8 runs, spaces just inside parentheses are removed, and
    spaces/tabs between an identifier and "(" are dropped
    (``foo (x)`` -> ``foo(x)``). These passes operate on raw text, so they
    can also touch string literals.

    Returns:
        The formatted code, or ``code`` unchanged (after printing the error)
        if formatting fails.
    """
    try:
        formatted_code = autopep8.fix_code(
            code,
            options={
                "aggressive": 2,
                "max_line_length": 88,
                "indent_size": 4,
            },
        )

        # Additional formatting for consistent spacing around parentheses.
        formatted_code = formatted_code.replace("( ", "(").replace(" )", ")")
        # Horizontal whitespace only: ``\s`` would also match a newline and
        # could glue an identifier ending one line to a "(" starting the next.
        formatted_code = re.sub(r"(\w+)[ \t]+\(", r"\1(", formatted_code)

        return formatted_code
    except Exception as e:
        # Best effort: fall back to the unformatted code rather than failing.
        print(f"Error formatting code: {str(e)}")
        return code


# Statement headers that must end with ":". ``else`` is matched as a whole
# word so identifiers such as ``elsewhere`` are not mistaken for it.
_BLOCK_HEADER_RE = re.compile(r"(?:if|elif|for|while|def|class) |else\b")
# An identifier followed by "(" with no ")" after it up to the end of the text.
_UNCLOSED_CALL_RE = re.compile(r"(\w+)\s*\([^)]*$")


def _parses_as_python(code):
    """Return True if ``code`` is syntactically valid Python source."""
    try:
        ast.parse(code)
    except (SyntaxError, ValueError):
        return False
    return True


def fix_common_syntax_issues(code):
    """
    Fix common syntax issues in generated code without modifying indentation.

    Code that already parses is returned unchanged: the heuristics below can
    break valid constructs (one-line ``if x: return y``, apostrophes inside
    strings, calls split over several lines). Otherwise, in order:

    1. append ":" to ``if``/``elif``/``else``/``for``/``while``/``def``/
       ``class`` lines that lack it (unless they end with a backslash);
    2. if a double or single quote occurs an odd number of times overall,
       close it on the first line where it is unbalanced;
    3. append ")" to lines that open a call without closing it, unless one
       of the next two lines starts with ")".
    """
    if _parses_as_python(code):
        return code

    fixed_lines = []
    for line in code.split("\n"):
        stripped = line.strip()
        if (
            _BLOCK_HEADER_RE.match(stripped)
            and not stripped.endswith(":")
            and not stripped.endswith("\\")
        ):
            line = line.rstrip() + ":"
        fixed_lines.append(line)
    code = "\n".join(fixed_lines)

    # Fix mismatched quotes
    for quote in ('"', "'"):
        if code.count(quote) % 2 != 0:
            lines = code.split("\n")
            for i, line in enumerate(lines):
                if line.count(quote) % 2 != 0:
                    lines[i] = line.rstrip() + quote
                    break
            code = "\n".join(lines)

    # Fix missing closing parentheses in function calls
    if _UNCLOSED_CALL_RE.search(code):
        lines = code.split("\n")
        for i, line in enumerate(lines):
            following = lines[i + 1 : i + 3]
            if _UNCLOSED_CALL_RE.search(line) and not any(
                nxt.strip().startswith(")") for nxt in following
            ):
                lines[i] = line.rstrip() + ")"
        code = "\n".join(lines)

    return code


def load_example_from_file(example_path):
    """
    Load an example from a file with format: description_BREAK_code
    where 'code' uses \\n and \\t for formatting.

    The escape sequences in the code part are decoded and its indentation
    normalized with ``normalize_indentation``.

    Returns:
        ``(description, code)``, or ``("", "")`` if the file cannot be read
        or does not contain exactly one ``_BREAK_`` separator.
    """
    try:
        with open(example_path, "r", encoding="utf-8") as f:
            content = f.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error loading example file {example_path}: {str(e)}")
        return "", ""

    parts = content.split("_BREAK_")
    if len(parts) != 2:
        print(f"Invalid format in example file: {example_path}")
        return "", ""

    description = parts[0].strip()
    code = parts[1].strip()
    code = code.replace("\\n", "\n").replace("\\t", "\t")
    code = normalize_indentation(code)
    return description, code


def find_example_files():
    """
    Find all raw.in example files in the examples directory.

    Sorted so that example numbering is stable across platforms (glob
    returns files in arbitrary order).
    """
    return sorted(glob.glob("examples/*/raw.in"))


# ==========================
# Load model from HF Hub
# ==========================

BASE_MODEL_ID = "Salesforce/codet5p-770m"
FINETUNED_REPO_ID = "OSS-forge/codet5p-770m-pyresbugs"
FINETUNED_FILENAME = "pytorch_model.bin"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Loading tokenizer from base model: {BASE_MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

print(f"Loading base model: {BASE_MODEL_ID}")
model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_ID)
model.to(device)

print(f"Downloading fine-tuned weights from repo: {FINETUNED_REPO_ID}")
ckpt_path = hf_hub_download(FINETUNED_REPO_ID, FINETUNED_FILENAME)

print(f"Loading state_dict from: {ckpt_path}")
# NOTE(review): torch.load unpickles the downloaded file; on PyTorch versions
# where ``weights_only`` defaults to False this can execute code embedded in
# the checkpoint. Consider ``weights_only=True`` once the checkpoint is
# confirmed to hold only tensors/primitive containers.
state_dict = torch.load(ckpt_path, map_location="cpu")
# Some checkpoints wrap the weights under a "model_state_dict" key.
if "model_state_dict" in state_dict:
    state_dict = state_dict["model_state_dict"]

missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(
    f"Loaded fine-tuned weights. Missing keys: {len(missing)}, "
    f"unexpected keys: {len(unexpected)}"
)

model.eval()


# ==========================
# Gradio logic
# ==========================

# State variables
# NOTE(review): these are module-level, so every browser session served by
# this process shares them; concurrent users can pick up each other's
# "previous bugged code". Per-session gr.State would isolate them.
current_code = None
bug_counter = 0


def generate_bugged_code(description, code, chat_history, is_first_time):
    """
    Run one round of bug injection and append it to the conversation.

    On the first round the model receives ``code``. On later rounds it
    receives the bugged code from the previous round, so bugs accumulate.

    Args:
        description: Natural-language description of the bug to inject.
        code: Original code; only used on the first round.
        chat_history: Current chatbot messages (role/content dicts) or None.
        is_first_time: True to start over (clears history and state).

    Returns:
        ``(chat_history, update clearing the description box, False)``.
    """
    global current_code, bug_counter

    if chat_history is None:
        chat_history = []

    if is_first_time:
        bug_counter = 0
        current_code = None
        chat_history = []

    bug_counter += 1

    if bug_counter == 1:
        input_for_model = code
        input_type = "original"
    else:
        if current_code is None:
            return chat_history, gr.update(value=""), False
        input_for_model = current_code
        input_type = "previous bugged code"

    print(f"Using {input_type} - counter: {bug_counter}\n{input_for_model}")

    encoded_code = encode_text(input_for_model)
    combined_input = f"Description: {description} _BREAK_ Code: {encoded_code}"

    encoded = tokenizer(
        combined_input,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(device)

    try:
        print("Starting generation...")
        with torch.no_grad():
            # Greedy decoding. The attention mask is passed explicitly so
            # generate() does not have to infer it; ``early_stopping`` is not
            # passed because it only applies to beam search.
            outputs = model.generate(
                input_ids=encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                max_new_tokens=256,
                num_beams=1,
                do_sample=False,
            )
        print("Generation done.")
    except Exception as e:
        print("Generation error:", repr(e))
        raise

    bugged_code_escaped = tokenizer.decode(outputs[0], skip_special_tokens=True)
    bugged_code = clear_text(bugged_code_escaped)
    bugged_code = fix_common_syntax_issues(bugged_code)
    bugged_code = format_code(bugged_code)

    current_code = bugged_code

    user_message = f"**Description**: {description}"
    if input_type == "original":
        user_message += f"\n\n**Original code**:\n```python\n{input_for_model}\n```"
    else:
        user_message += (
            f"\n\n**Previous bugged code**:\n```python\n{input_for_model}\n```"
        )

    ai_message = f"**Bugged code**:\n```python\n{bugged_code}\n```"

    chat_history = chat_history + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ai_message},
    ]

    return chat_history, gr.update(value=""), False


def reset_interface():
    """
    Clear the bug-injection state.

    Returns:
        Outputs for (chatbot, description box, is_first flag): an empty
        history, a cleared description, and True.
    """
    global current_code, bug_counter
    current_code = None
    bug_counter = 0
    return [], gr.update(value=""), True


example_files = find_example_files()
example_names = [
    f"Example {i+1}: {os.path.basename(os.path.dirname(f))}"
    for i, f in enumerate(example_files)
]


def load_example(example_index):
    """Return ``(description, code)`` for an example, or ``("", "")`` if out of range."""
    if example_index < len(example_files):
        return load_example_from_file(example_files[example_index])
    return "", ""


with gr.Blocks(title="Software-Fault Injection from NL") as demo:
    gr.Markdown("# 🐞 Software-Fault Injection from Natural Language")
    gr.Markdown(
        "Generate Python code with specific bugs based on a description and original code. "
        "The model used is **BugGen (CodeT5+ 770M, PyResBugs)**."
    )

    with gr.Row():
        with gr.Column(scale=2):
            description_input = gr.Textbox(
                label="Bug Description",
                placeholder="Describe the type of bug to introduce...",
                lines=3,
            )
            code_input = gr.Code(
                label="Original Code",
                language="python",
                lines=12,
            )
            is_first = gr.State(True)

            submit_btn = gr.Button("Generate Bugged Code")
            reset_btn = gr.Button("Start Over")

            gr.Markdown("### Examples")
            example_buttons = [gr.Button(name) for name in example_names]

        with gr.Column(scale=3):
            chat_output = gr.Chatbot(
                label="Conversation",
                height=500,
            )

    # Bind the loop index as a default argument to avoid late binding.
    for i, btn in enumerate(example_buttons):
        btn.click(
            fn=lambda i=i: load_example(i),
            outputs=[description_input, code_input],
        )

    submit_btn.click(
        fn=generate_bugged_code,
        inputs=[description_input, code_input, chat_output, is_first],
        outputs=[chat_output, description_input, is_first],
    )

    reset_btn.click(
        fn=reset_interface,
        outputs=[chat_output, description_input, is_first],
    )

print("Launching Gradio interface...")
demo.queue(max_size=10).launch()