|
|
import tensorflow as tf
|
|
|
from tensorflow.keras.applications import EfficientNetB0
|
|
|
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
|
|
|
from tensorflow.keras.models import Model
|
|
|
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
|
|
|
import matplotlib.pyplot as plt
|
|
|
import os
|
|
|
|
|
|
|
|
|
# --- Configuration --------------------------------------------------------
# Root of the NWPU-RESISC45 dataset: one subdirectory per scene class.
DATASET_DIR = r"C:\Path\To\NWPU-RESISC45"

# Input geometry and training hyperparameters.
IMG_SIZE = (128, 128)   # (height, width) fed to the network
BATCH_SIZE = 16
NUM_CLASSES = 45        # NWPU-RESISC45 scene categories
EPOCHS = 50             # upper bound; EarlyStopping may halt sooner
SEED = 42               # fixed seed so the train/val split is reproducible
|
|
|
|
|
|
|
|
|
# Derive the class list from the dataset's subdirectory names.
# Only keep directories: a stray file in DATASET_DIR (README, archive,
# .DS_Store) would otherwise inflate the class count and shift the label
# order. Sorting matches the alphabetical ordering that
# image_dataset_from_directory uses for its labels.
class_names = sorted(
    entry for entry in os.listdir(DATASET_DIR)
    if os.path.isdir(os.path.join(DATASET_DIR, entry))
)
print(f"Detected {len(class_names)} classes:", class_names)
|
|
|
|
|
|
|
|
|
# Stochastic augmentation pipeline for training batches. Each layer is a
# no-op when called with training=False, so it is safe to keep in the graph.
_augment_layers = [
    tf.keras.layers.RandomFlip("horizontal_and_vertical"),  # mirror both axes
    tf.keras.layers.RandomRotation(0.2),                    # rotate up to ±20% of a full turn
    tf.keras.layers.RandomZoom(0.1),                        # zoom in/out up to 10%
]
data_augmentation = tf.keras.Sequential(_augment_layers)
|
|
|
|
|
|
|
|
|
# Training split: 80% of the images. The shared SEED makes this split
# deterministic and disjoint from the validation subset built below.
# tf.keras.utils.image_dataset_from_directory is the current home of this
# helper; the tf.keras.preprocessing namespace is deprecated in TF2.
train_ds = tf.keras.utils.image_dataset_from_directory(
    DATASET_DIR,
    validation_split=0.2,
    subset="training",
    seed=SEED,                 # must match the val_ds call for a clean split
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    label_mode='categorical',  # one-hot labels for categorical_crossentropy
)
|
|
|
|
|
|
# Validation split: the remaining 20%. Same SEED and split fraction as
# train_ds so the two subsets partition the dataset without overlap.
# tf.keras.utils.image_dataset_from_directory is the current home of this
# helper; the tf.keras.preprocessing namespace is deprecated in TF2.
val_ds = tf.keras.utils.image_dataset_from_directory(
    DATASET_DIR,
    validation_split=0.2,
    subset="validation",
    seed=SEED,                 # must match the train_ds call
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    label_mode='categorical',  # one-hot labels for categorical_crossentropy
)
|
|
|
|
|
|
|
|
|
# Apply augmentation on the fly; training=True forces the random layers to
# be active inside the tf.data pipeline. num_parallel_calls lets tf.data
# augment several batches concurrently instead of serially on one thread.
train_ds = train_ds.map(
    lambda x, y: (data_augmentation(x, training=True), y),
    num_parallel_calls=tf.data.AUTOTUNE,
)
|
|
|
|
|
|
|
|
|
# Prefetch so host-side input preparation overlaps with device compute.
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.prefetch(AUTOTUNE)
val_ds = val_ds.prefetch(AUTOTUNE)
|
|
|
|
|
|
|
|
|
# ImageNet-pretrained EfficientNetB0 backbone without its classifier head.
# NOTE(review): Keras EfficientNet variants include input rescaling
# internally, so the pipeline feeds raw [0, 255] pixels — confirm for the
# installed TF version.
base_model = EfficientNetB0(
    weights='imagenet',
    include_top=False,
    input_shape=(*IMG_SIZE, 3),
)
base_model.trainable = False  # phase 1: train only the new head

# New classification head on top of the frozen backbone.
features = GlobalAveragePooling2D()(base_model.output)
features = Dropout(0.3)(features)  # regularize the small head
predictions = Dense(NUM_CLASSES, activation='softmax')(features)

model = Model(inputs=base_model.input, outputs=predictions)
|
|
|
|
|
|
# Phase-1 compile: Adam at its default learning rate (1e-3), one-hot labels.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
model.summary()
|
|
|
|
|
|
|
|
|
# All three callbacks watch the same quantity. The original left
# EarlyStopping and ReduceLROnPlateau on their default monitor ("val_loss")
# while the checkpoint tracked "val_accuracy", so the saved "best" model
# and the stopping / LR decisions could disagree.
callbacks = [
    # Persist the weights with the highest validation accuracy seen so far.
    ModelCheckpoint("best_model.h5", save_best_only=True,
                    monitor="val_accuracy", mode="max"),
    # Stop after 10 stagnant epochs and roll back to the best weights.
    EarlyStopping(monitor="val_accuracy", mode="max",
                  patience=10, restore_best_weights=True),
    # Cut the LR by 5x after 5 stagnant epochs, bottoming out at 1e-6.
    ReduceLROnPlateau(monitor="val_accuracy", mode="max",
                      factor=0.2, patience=5, min_lr=1e-6),
]
|
|
|
|
|
|
|
|
|
# Phase 1: fit only the classification head (backbone frozen above).
history = model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=callbacks,
)
|
|
|
|
|
|
|
|
|
# Phase 2: fine-tune the backbone at a much smaller learning rate.
base_model.trainable = True
# Keep BatchNormalization layers frozen: letting their statistics update on
# a comparatively small dataset destabilizes EfficientNet fine-tuning (per
# the Keras transfer-learning guide).
for layer in base_model.layers:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = False

# Recompile so the trainability change takes effect; 1e-5 keeps updates
# small enough not to wreck the pretrained weights.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-5),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
|
|
|
|
|
|
# Phase 2 training run: a short 10-epoch fine-tuning pass.
history_finetune = model.fit(
    train_ds,
    epochs=10,
    validation_data=val_ds,
    callbacks=callbacks,
)
|
|
|
|
|
|
|
|
|
# Plot both training phases as one continuous accuracy curve.
train_acc = history.history["accuracy"] + history_finetune.history["accuracy"]
val_acc = history.history["val_accuracy"] + history_finetune.history["val_accuracy"]

plt.plot(train_acc, label="train acc")
plt.plot(val_acc, label="val acc")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.title("Training vs Validation Accuracy")
plt.legend()
plt.grid()
plt.savefig("training_accuracy.png")  # save before show() clears the figure
plt.show()
|
|
|
|
|
|
|
|
|
# Persist the label order so inference code can map indices back to names.
with open("class_names.txt", "w") as f:
    f.writelines(f"{name}\n" for name in class_names)
|
|
|
|
|
|
|