# train_resisc45.py: EfficientNet-B0 transfer learning on NWPU-RESISC45
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import matplotlib.pyplot as plt
import os
# Configuration
DATASET_DIR = r"C:\Path\To\NWPU-RESISC45"  # Path to the extracted dataset (45 class subfolders)
IMG_SIZE = (128, 128)
BATCH_SIZE = 16
NUM_CLASSES = 45
EPOCHS = 50
SEED = 42
# Sanity check: verify class folders (ignore any stray files in the dataset root)
class_names = sorted(
    d for d in os.listdir(DATASET_DIR)
    if os.path.isdir(os.path.join(DATASET_DIR, d))
)
print(f"Detected {len(class_names)} classes:", class_names)
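# Fail fast if the folder layout does not match the expected class count.
assert len(class_names) == NUM_CLASSES, (
    f"Expected {NUM_CLASSES} class folders, found {len(class_names)}"
)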
# Augmentation for training set
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal_and_vertical"),
    tf.keras.layers.RandomRotation(0.2),
    tf.keras.layers.RandomZoom(0.1),
])
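# These random layers only transform images when called with training=True,
# which the training-set map below passes explicitly; val_ds never goes
# through this pipeline, so validation images are never augmented.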
# Load datasets: an 80/20 train/validation split driven by the shared seed
train_ds = tf.keras.utils.image_dataset_from_directory(
    DATASET_DIR,
    validation_split=0.2,
    subset="training",
    seed=SEED,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    label_mode='categorical'
)
val_ds = tf.keras.utils.image_dataset_from_directory(
    DATASET_DIR,
    validation_split=0.2,
    subset="validation",
    seed=SEED,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    label_mode='categorical'
)
# Apply augmentation only to training
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.map(lambda x, y: (data_augmentation(x, training=True), y),
                        num_parallel_calls=AUTOTUNE)
# Prefetch so the input pipeline overlaps with training
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)
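# (Optional) If the decoded images fit in memory, a .cache() on the freshly
# loaded train_ds, placed before the augmentation map so cached images are
# still augmented fresh each epoch, avoids re-decoding files every epoch.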
# Build Model
base_model = EfficientNetB0(weights='imagenet', include_top=False, input_shape=(*IMG_SIZE, 3))
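# Note: Keras EfficientNet models include their own input rescaling, so the
# raw [0, 255] pixel values produced by image_dataset_from_directory can be
# fed in directly, with no separate preprocess_input step.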
base_model.trainable = False # Freeze base model initially
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.3)(x)
output = Dense(NUM_CLASSES, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=output)
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.summary()
# Callbacks
callbacks = [
    # Keep the weights with the best validation accuracy seen so far
    ModelCheckpoint("best_model.h5", save_best_only=True, monitor="val_accuracy", mode="max"),
    # Stop once val_loss (the default monitor) has not improved for 10 epochs
    EarlyStopping(patience=10, restore_best_weights=True),
    # Shrink the learning rate when val_loss plateaus
    ReduceLROnPlateau(factor=0.2, patience=5, min_lr=1e-6)
]
# Train the classification head (warmup phase; base frozen)
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=callbacks
)
# Optional: Fine-tune after warmup (unfreeze base model)
base_model.trainable = True
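# Note: because the head was built directly on base_model.output, unfreezing
# also puts the base's BatchNormalization layers into training mode. A common
# alternative (per the Keras transfer-learning guide) is to keep those BN
# layers frozen (trainable = False on BatchNormalization layers only).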
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),  # lower LR for fine-tuning
              loss='categorical_crossentropy',
              metrics=['accuracy'])
history_finetune = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    callbacks=callbacks
)
# Plot accuracy across both training phases (warmup + fine-tuning)
plt.plot(history.history["accuracy"] + history_finetune.history["accuracy"], label="train acc")
plt.plot(history.history["val_accuracy"] + history_finetune.history["val_accuracy"], label="val acc")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.title("Training vs Validation Accuracy")
plt.grid()
plt.savefig("training_accuracy.png")
plt.show()
# Save class names
with open("class_names.txt", "w") as f:
    for name in class_names:
        f.write(name + "\n")
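# Sanity-check the saved checkpoint: reload the best weights and re-evaluate
# on the validation split. (Added sketch; assumes the best_model.h5 file
# written by the ModelCheckpoint callback above.)
best_model = tf.keras.models.load_model("best_model.h5")
val_loss, val_acc = best_model.evaluate(val_ds)
print(f"Best checkpoint validation accuracy: {val_acc:.4f}")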