""" RadFig VQA Image Filtering Model - Inference Script Classifies medical images as suitable/unsuitable for VQA tasks. """ import os import torch import torch.nn as nn import timm import cv2 import numpy as np import pandas as pd from PIL import Image from torch.utils.data import Dataset, DataLoader from albumentations import Compose, Resize, Normalize from albumentations.pytorch import ToTensorV2 from tqdm import tqdm class Config: """Configuration for inference""" model_name = "tf_efficientnetv2_s.in21k_ft_in1k" size = 512 batch_size = 32 num_workers = 4 target_size = 1 n_fold = 5 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') class TestDataset(Dataset): """Dataset for inference""" def __init__(self, image_paths, transform=None): self.image_paths = image_paths self.transform = transform def __len__(self): return len(self.image_paths) def __getitem__(self, idx): image_path = self.image_paths[idx] # Load image image = cv2.imread(image_path) if image is None: raise ValueError(f"Could not load image: {image_path}") image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if self.transform: augmented = self.transform(image=image) image = augmented['image'] return image def get_transforms(): """Get inference transforms""" return Compose([ Resize(Config.size, Config.size), Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ), ToTensorV2(), ]) class RadFigClassifier: """RadFig VQA Image Filtering Classifier""" def __init__(self, model_dir="models"): self.config = Config() self.model_dir = model_dir self.device = self.config.device self.model = None self.states = [] # Load model states self._load_model_states() def _load_model_states(self): """Load all fold model states""" self.states = [] for fold in range(self.config.n_fold): model_path = os.path.join( self.model_dir, f"{self.config.model_name}_fold{fold}_best_loss.pth" ) if not os.path.exists(model_path): raise FileNotFoundError(f"Model file not found: {model_path}") state = torch.load(model_path, map_location=self.device) self.states.append(state) print(f"Loaded {len(self.states)} model states from {self.model_dir}") def _create_model(self): """Create model architecture""" model = timm.create_model( model_name=self.config.model_name, num_classes=self.config.target_size, pretrained=False ) return model.to(self.device) def predict_batch(self, image_paths, return_probabilities=True): """ Predict on a batch of images Args: image_paths (list): List of image file paths return_probabilities (bool): If True, return probabilities. If False, return binary predictions. 
Returns: numpy.ndarray: Predictions (probabilities or binary) """ # Create dataset and dataloader dataset = TestDataset(image_paths, transform=get_transforms()) dataloader = DataLoader( dataset, batch_size=self.config.batch_size, shuffle=False, num_workers=self.config.num_workers, pin_memory=True ) # Create model model = self._create_model() all_predictions = [] # Inference loop with torch.no_grad(): for images in tqdm(dataloader, desc="Predicting"): images = images.to(self.device) # Ensemble predictions across all folds fold_predictions = [] for state in self.states: model.load_state_dict(state['model']) model.eval() outputs = model(images) probabilities = torch.sigmoid(outputs).cpu().numpy() fold_predictions.append(probabilities) # Average predictions across folds avg_predictions = np.mean(fold_predictions, axis=0) all_predictions.append(avg_predictions) # Concatenate all predictions predictions = np.concatenate(all_predictions, axis=0).flatten() if return_probabilities: return predictions else: return (predictions > 0.5).astype(int) def predict_single(self, image_path, return_probability=True): """ Predict on a single image Args: image_path (str): Path to image file return_probability (bool): If True, return probability. If False, return binary prediction. Returns: float or int: Prediction """ predictions = self.predict_batch([image_path], return_probabilities=return_probability) return predictions[0] def predict_directory(self, directory_path, output_csv=None, return_probabilities=True): """ Predict on all images in a directory Args: directory_path (str): Path to directory containing images output_csv (str, optional): Path to save results as CSV return_probabilities (bool): If True, return probabilities. If False, return binary predictions. Returns: pandas.DataFrame: Results with image paths and predictions """ # Get all image files image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'} image_paths = [] for filename in os.listdir(directory_path): if any(filename.lower().endswith(ext) for ext in image_extensions): image_paths.append(os.path.join(directory_path, filename)) if not image_paths: raise ValueError(f"No image files found in {directory_path}") print(f"Found {len(image_paths)} images in {directory_path}") # Get predictions predictions = self.predict_batch(image_paths, return_probabilities=return_probabilities) # Create results dataframe results = pd.DataFrame({ 'image_path': image_paths, 'filename': [os.path.basename(path) for path in image_paths], 'prediction': predictions, 'suitable_for_vqa': predictions > 0.9 if return_probabilities else predictions.astype(bool) }) # Sort by filename for consistency results = results.sort_values('filename').reset_index(drop=True) # Save to CSV if requested if output_csv: results.to_csv(output_csv, index=False) print(f"Results saved to {output_csv}") return results def main(): """Example usage""" import argparse parser = argparse.ArgumentParser(description="RadFig VQA Image Filtering Inference") parser.add_argument("--input", required=True, help="Input image file or directory") parser.add_argument("--models", default="models", help="Directory containing model files") parser.add_argument("--output", help="Output CSV file (for directory input)") parser.add_argument("--binary", action="store_true", help="Return binary predictions instead of probabilities") args = parser.parse_args() # Initialize classifier classifier = RadFigClassifier(model_dir=args.models) if os.path.isfile(args.input): # Single image prediction prediction = 
classifier.predict_single( args.input, return_probability=not args.binary ) if args.binary: result = "suitable" if prediction else "not suitable" print(f"Image: {args.input}") print(f"Prediction: {result} for VQA") else: print(f"Image: {args.input}") print(f"Probability suitable for VQA: {prediction:.4f}") print(f"Classification: {'suitable' if prediction > 0.9 else 'not suitable'}") elif os.path.isdir(args.input): # Directory prediction results = classifier.predict_directory( args.input, output_csv=args.output, return_probabilities=not args.binary ) # Print summary if args.binary: suitable_count = results['suitable_for_vqa'].sum() else: suitable_count = (results['prediction'] > 0.9).sum() total_count = len(results) print(f"\nSummary:") print(f"Total images: {total_count}") print(f"Suitable for VQA: {suitable_count}") print(f"Not suitable for VQA: {total_count - suitable_count}") print(f"Percentage suitable: {suitable_count/total_count*100:.1f}%") else: print(f"Error: {args.input} is not a valid file or directory") if __name__ == "__main__": main()
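# --- Usage sketches -----------------------------------------------------------
# The script filename ("inference.py") and the image/output paths below are
# illustrative assumptions, not taken from the repository; only the flags and
# method names come from the code above.
#
# CLI, using the flags defined in main():
#   python inference.py --input ./figures --models ./models --output results.csv
#   python inference.py --input ./figures/fig_001.png --binary
#
# Programmatic use (sigmoid probabilities in [0, 1]; higher means more suitable):
#   classifier = RadFigClassifier(model_dir="models")
#   prob = classifier.predict_single("figures/fig_001.png")
#   df = classifier.predict_directory("figures", output_csv="results.csv")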