Delete models/spabert/notebooks

- models/spabert/notebooks/GAN-SpaBERT_pytorch.ipynb +0 -0
- models/spabert/notebooks/README.md +0 -167
- models/spabert/notebooks/Setup.ipynb +0 -0
- models/spabert/notebooks/SpaBertEmbeddingTest1.ipynb +0 -0
- models/spabert/notebooks/WHGDataset.py +0 -77
- models/spabert/notebooks/Working with SpaBERT Embedding.ipynb +0 -0
- models/spabert/notebooks/__pycache__/WHGDataset.cpython-310.pyc +0 -0
- models/spabert/notebooks/spabert-entity-linking.ipynb +0 -287
- models/spabert/notebooks/spabert-fine-tuning.ipynb +0 -262
- models/spabert/notebooks/tutorial_datasets/mlm_mem_keeppos_ep0_iter06000_0.2936.pth +0 -3
- models/spabert/notebooks/tutorial_datasets/osm_mn.csv +0 -0
- models/spabert/notebooks/tutorial_datasets/output.csv.json +0 -3
- models/spabert/notebooks/tutorial_datasets/spabert-base-uncased-finetuned-osm-mn.pth +0 -3
- models/spabert/notebooks/tutorial_datasets/spabert_osm_mn.json +0 -3
- models/spabert/notebooks/tutorial_datasets/spabert_whg_wikidata.json +0 -3
- models/spabert/notebooks/tutorial_datasets/spabert_wikidata_sampled.json +0 -3
models/spabert/notebooks/GAN-SpaBERT_pytorch.ipynb
DELETED
The diff for this file is too large to render. See raw diff.
models/spabert/notebooks/README.md
DELETED
@@ -1,167 +0,0 @@

# Tutorials for Testing and Fine-Tuning SpaBERT

This repository provides two Jupyter Notebooks: one for testing entity linking (one of SpaBERT's downstream tasks) and one for the fine-tuning procedure used to train on geo-entities from other knowledge bases (e.g., [World Historical Gazetteer](https://whgazetteer.org/)).

1. The first step is cloning the SpaBERT repository onto your machine. Run the following command to do this:

`git clone https://github.com/zekun-li/spabert.git`

2. You will need the IPython kernel for Jupyter installed before running the code in this tutorial. Run the following command to ensure ipykernel is installed:

`pip install ipykernel`

3. Before starting the Jupyter notebooks, run the following command to make sure you have all required packages:

`pip install -r requirements.txt`

The requirements.txt file is located in the spabert directory.

```
- spabert
  | - datasets
  | - experiments
  | - models
  | - notebooks
  | - utils
  | - __init__.py
  | - README.md
  | - requirements.txt
  | - train_mlm.py
```

## Installing Model Weights

Make sure you have git-lfs installed (https://git-lfs.com for Windows & macOS; https://github.com/git-lfs/git-lfs/blob/main/INSTALLING.md for Linux).

Please run the following commands separately, in order, to install the pre-trained and fine-tuned model weights:

`git lfs install`

`git clone https://huggingface.co/knowledge-computing-lab/spabert-base-uncased`

`git clone https://huggingface.co/knowledge-computing-lab/spabert-base-uncased-finetuned-osm-mn`

Once the model weights are installed, you'll see two files: `mlm_mem_keeppos_ep0_iter06000_0.2936.pth` and `spabert-base-uncased-finetuned-osm-mn.pth`.
Move these files to the tutorial_datasets folder. After moving them, the file structure should look like this:

```
- notebooks
  | - tutorial_datasets
  | | - mlm_mem_keeppos_ep0_iter06000_0.2936.pth
  | | - osm_mn.csv
  | | - spabert_osm_mn.json
  | | - spabert_whg_wikidata.json
  | | - spabert_wikidata_sampled.json
  | | - spabert-base-uncased-finetuned-osm-mn.pth
  | - README.md
  | - spabert-entity-linking.ipynb
  | - spabert-fine-tuning.ipynb
  | - WHGDataset.py
```
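
As a quick sanity check (a minimal sketch, assuming the files are already in `tutorial_datasets/`), you can confirm the weights downloaded as real checkpoints rather than as git-lfs pointer stubs:

```python
import torch

# A real LFS download is ~500 MB; a pointer stub is only a few bytes
# and torch.load will fail on it.
state_dict = torch.load(
    'tutorial_datasets/mlm_mem_keeppos_ep0_iter06000_0.2936.pth',
    map_location='cpu')
print(len(state_dict), 'tensors in checkpoint')  # expect a few hundred entries
```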

## Jupyter Notebook Descriptions

### [spabert-fine-tuning.ipynb](https://github.com/Jina-Kim/spabert/blob/main/notebooks/spabert-fine-tuning.ipynb)

This Jupyter Notebook shows how to fine-tune SpaBERT using point data from OpenStreetMap (OSM) in Minnesota. SpaBERT is pre-trained on OSM point data from California and London. Instructions for pre-training your own model can be found on the SpaBERT GitHub.
Here are the steps to run:

1. Define which dataset you want to use (e.g., OSM in New York or Minnesota)
2. Read data from the CSV file and construct a KDTree for computing nearest neighbors (see the sketch after this list)
3. Create the dataset for fine-tuning SpaBERT from the KDTree, using the dataset you chose
4. Load the pre-trained model
5. Load the dataset using the SpaBERT data loader
6. Train the model for 1 epoch using the fine-tuning model and save it
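
As a minimal sketch of step 2, using a few hypothetical coordinates (the notebook itself builds the tree from `osm_mn.csv`):

```python
import scipy.spatial as scp

# (lng, lat) pairs for a few hypothetical places
coordinate_list = [[-92.1215, 46.7729], [-95.7570, 44.5269], [-93.2650, 44.9778]]
tree = scp.KDTree(coordinate_list)

# query with k=2: the nearest hit of a stored point is the point itself,
# so the second hit is its closest *other* neighbor
distances, indices = tree.query([[-92.12, 46.77]], k=2)
print(indices[0])
```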
### [spabert-entity-linking.ipynb](https://github.com/Jina-Kim/spabert/blob/main/notebooks/spabert-entity-linking.ipynb)

This Jupyter Notebook shows how to create an entity-linking dataset and how to perform entity linking using SpaBERT. The dataset used here is a pre-matched dataset between World Historical Gazetteer (WHG) and Wikidata. The model is evaluated with Hits@K and Mean Reciprocal Rank (MRR).
Here are the steps to run:

1. Load the fine-tuned model from the previous Jupyter notebook
2. Load the datasets using the WHG data loader
3. Calculate embeddings for WHG and Wikidata entities using SpaBERT, ranking candidates by embedding similarity (see the sketch after this list)
4. Calculate Hits@1, Hits@5, Hits@10, and MRR
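
The ranking in steps 3 and 4 boils down to cosine similarity between embeddings; a minimal sketch with random stand-in vectors:

```python
import numpy as np
import scipy.spatial as sp

rng = np.random.default_rng(0)
candidate_embs = rng.normal(size=(5, 768))  # stand-ins for Wikidata embeddings
query_emb = rng.normal(size=(1, 768))       # stand-in for one WHG embedding

# cosine similarity = 1 - cosine distance; higher means closer
sim = 1 - sp.distance.cdist(candidate_embs, query_emb, 'cosine')
ranking = np.argsort(-sim[:, 0])            # candidate indices, best first
print(ranking)
```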
## Dataset Descriptions

There are two types of tutorial datasets used for fine-tuning SpaBERT: CSV and JSON files.

- CSV file - sample taken from OpenStreetMap (OSM)
  - Minnesota State `./tutorial_datasets/osm_mn.csv`

An example data structure:

| row_id | name | latitude | longitude |
| ------ | ---- | -------- | --------- |
| 0 | Duluth | -92.1215 | 46.7729 |
| 1 | Green Valley | -95.757 | 44.5269 |

- JSON files - ready-to-use files for SpaBERT's data loader - [SpatialDataset](../datasets/dataset_loader.py)
  - OSM Minnesota State `./tutorial_datasets/spabert_osm_mn.json`
    - Generated from `./tutorial_datasets/osm_mn.csv` using spabert-fine-tuning.ipynb
  - WHG `./tutorial_datasets/spabert_whg_wikidata.json`
    - Geo-entities from WHG that link to Wikidata
  - Wikidata `./tutorial_datasets/spabert_wikidata_sampled.json`
    - Sampled from entities delivered by WHG. These entities were linked between WHG and Wikidata by WHG prior to being delivered to us.

The file contains one JSON object per line. Each JSON object describes the spatial context of an entity using nearby entities.
A sample JSON object looks like the following:

```json
{
  "info":{
    "name":"Duluth",
    "geometry":{
      "coordinates":[
        46.7729,
        -92.1215
      ]
    }
  },
  "neighbor_info":{
    "name_list":[
      "Duluth",
      "Chinese Peace Belle and Garden",
      ...
    ],
    "geometry_list":[
      {
        "coordinates":[
          46.7729,
          -92.1215
        ]
      },
      {
        "coordinates":[
          46.7770,
          -92.1241
        ]
      },
      ...
    ]
  }
}
```
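
Since each line is a standalone JSON object (JSON Lines), the file can be inspected with a plain loop; a minimal sketch, assuming `spabert_osm_mn.json` is in place:

```python
import json

with open('tutorial_datasets/spabert_osm_mn.json') as f:
    for line in f:
        record = json.loads(line)
        name = record['info']['name']
        n_neighbors = len(record['neighbor_info']['name_list'])
        print(f'{name}: {n_neighbors} neighbors')
        break  # just inspect the first record
```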
To perform entity linking with SpaBERT, you must have a dataset structured similarly to the format above, with the addition of a `qid` field for the ground-truth Wikidata entity.

A sample JSON object looks like the following:

```json
{
  "info":{
    "name":"Duluth",
    "geometry":{
      "coordinates":[
        46.7729,
        -92.1215
      ]
    },
    "qid":"Q485708"
  },
  "neighbor_info":{
    ...
  }
}
```
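
Converting a fine-tuning record into an entity-linking record is just a matter of attaching the QID; a minimal sketch reusing the Duluth example (neighbor lists left empty for brevity):

```python
import json

# fine-tuning-style record, as produced by spabert-fine-tuning.ipynb
place = {"info": {"name": "Duluth",
                  "geometry": {"coordinates": [46.7729, -92.1215]}},
         "neighbor_info": {"name_list": [], "geometry_list": []}}

# attach the ground-truth Wikidata QID to make it an entity-linking record
place["info"]["qid"] = "Q485708"
print(json.dumps(place))
```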
models/spabert/notebooks/Setup.ipynb
DELETED
The diff for this file is too large to render. See raw diff.

models/spabert/notebooks/SpaBertEmbeddingTest1.ipynb
DELETED
The diff for this file is too large to render. See raw diff.
models/spabert/notebooks/WHGDataset.py
DELETED
@@ -1,77 +0,0 @@

```python
import json
import sys

sys.path.append("../")
from datasets.dataset_loader import SpatialDataset
from transformers import BertTokenizer


class WHGDataset(SpatialDataset):
    # initializes the dataset loader and reads the dataset into a Python object
    def __init__(self, data_file_path, tokenizer=None, max_token_len=512,
                 distance_norm_factor=1, spatial_dist_fill=100,
                 sep_between_neighbors=False):
        if tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        else:
            self.tokenizer = tokenizer
        self.read_data(data_file_path)
        self.max_token_len = max_token_len
        self.distance_norm_factor = distance_norm_factor
        self.spatial_dist_fill = spatial_dist_fill
        self.sep_between_neighbors = sep_between_neighbors

    # returns a specific item from the dataset given an index
    def __getitem__(self, idx):
        return self.load_data(idx)

    # returns the length of the loaded dataset
    def __len__(self):
        return self.len_data

    # returns the mean absolute lat/lng offset between an entity and its neighbors
    def get_average_distance(self, idx):
        line = self.data[idx]
        line_data_dict = json.loads(line)
        pivot_pos = line_data_dict['info']['geometry']['coordinates']

        neighbor_geom_list = line_data_dict['neighbor_info']['geometry_list']
        lat_diff = 0
        lng_diff = 0
        for neighbor in neighbor_geom_list:
            coordinates = neighbor['coordinates']
            lat_diff = lat_diff + abs(pivot_pos[0] - coordinates[0])
            lng_diff = lng_diff + abs(pivot_pos[1] - coordinates[1])
        avg_lat_diff = lat_diff / len(neighbor_geom_list)
        avg_lng_diff = lng_diff / len(neighbor_geom_list)
        return (avg_lat_diff, avg_lng_diff)

    # reads the dataset from the given file path; run on initialization
    def read_data(self, data_file_path):
        with open(data_file_path, 'r') as f:
            data = f.readlines()

        self.len_data = len(data)
        self.data = data

    # loads and parses one record of the dataset
    def load_data(self, idx):
        line = self.data[idx]
        line_data_dict = json.loads(line)

        # get pivot info
        pivot_name = str(line_data_dict['info']['name'])
        pivot_pos = line_data_dict['info']['geometry']['coordinates']

        # get neighbor info
        neighbor_info = line_data_dict['neighbor_info']
        neighbor_name_list = neighbor_info['name_list']
        neighbor_geom_list = neighbor_info['geometry_list']

        parsed_data = self.parse_spatial_context(pivot_name, pivot_pos,
                                                 neighbor_name_list,
                                                 neighbor_geom_list,
                                                 self.spatial_dist_fill)
        parsed_data['qid'] = line_data_dict['info']['qid']

        return parsed_data
```
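
A minimal usage sketch (the constructor arguments mirror the ones the entity-linking notebook below passes for the WHG dataset):

```python
from transformers import BertTokenizer
from WHGDataset import WHGDataset

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
whg_dataset = WHGDataset(
    data_file_path='tutorial_datasets/spabert_whg_wikidata.json',
    tokenizer=tokenizer,
    max_token_len=512,
    distance_norm_factor=25,
    spatial_dist_fill=100)

print(len(whg_dataset))   # number of records
sample = whg_dataset[0]   # parsed spatial context for the first entity
print(sample['pivot_name'], sample['qid'])
```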
models/spabert/notebooks/Working with SpaBERT Embedding.ipynb
DELETED
The diff for this file is too large to render. See raw diff.

models/spabert/notebooks/__pycache__/WHGDataset.cpython-310.pyc
DELETED
Binary file (2.49 kB)
models/spabert/notebooks/spabert-entity-linking.ipynb
DELETED
@@ -1,287 +0,0 @@

Cell 1 (code): load the SpaBERT model and the fine-tuned weights.

```python
import sys
from transformers import BertTokenizer
from transformers.models.bert.modeling_bert import BertForMaskedLM
import torch
from WHGDataset import WHGDataset

sys.path.append("../")
from datasets.usgs_os_sample_loader import USGS_MapDataset
from datasets.wikidata_sample_loader import Wikidata_Geocoord_Dataset, Wikidata_Random_Dataset
from models.spatial_bert_model import SpatialBertModel
from models.spatial_bert_model import SpatialBertConfig
from models.spatial_bert_model import SpatialBertForMaskedLM
from utils.find_closest import find_ref_closest_match, sort_ref_closest_match
from utils.common_utils import load_spatial_bert_pretrained_weights, get_spatialbert_embedding, get_bert_embedding, write_to_csv
from utils.baseline_utils import get_baseline_model


# load our SpaBERT model
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

config = SpatialBertConfig()
model = SpatialBertModel(config)

model.to(device)
model.eval()

# load the weights produced by spabert-fine-tuning.ipynb; keys in the
# checkpoint carry a 'bert.' prefix, so match them against the bare model keys
pre_trained_model = torch.load('tutorial_datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth')
cnt_layers = 0
model_keys = model.state_dict()
for key in model_keys:
    if 'bert.' + key in pre_trained_model:
        model_keys[key] = pre_trained_model['bert.' + key]
        cnt_layers += 1
    else:
        print("No weight for", key)
print(cnt_layers, 'layers loaded')

model.load_state_dict(model_keys)
```
Cell 2 (code): load the entity-linking datasets.

```python
# load entity-linking datasets

sep_between_neighbors = False
wikidata_dict_per_map = {}
wikidata_dict_per_map['wikidata_emb_list'] = []
wikidata_dict_per_map['wikidata_qid_list'] = []
wikidata_dict_per_map['names'] = []


whg_dataset = WHGDataset(
    data_file_path='tutorial_datasets/spabert_whg_wikidata.json',
    tokenizer=tokenizer,
    max_token_len=512,
    distance_norm_factor=25,
    spatial_dist_fill=100,
    sep_between_neighbors=sep_between_neighbors)

wikidata_dataset = WHGDataset(
    data_file_path='tutorial_datasets/spabert_wikidata_sampled.json',
    tokenizer=tokenizer,
    max_token_len=512,
    distance_norm_factor=50000,
    spatial_dist_fill=20,
    sep_between_neighbors=sep_between_neighbors)


matched_wikid_dataset = []
for i in range(len(wikidata_dataset)):
    emb = wikidata_dataset[i]
    matched_wikid_dataset.append(emb)
    max_dist_lng = max(emb['norm_lng_list'])
    max_dist_lat = max(emb['norm_lat_list'])
```
Cell 3 (code): compute embeddings and rank candidates (entity linking).

```python
import sys
sys.path.append('../')
from experiments.entity_matching.data_processing import request_wrapper
import scipy.spatial as sp
import numpy as np
## ENTITY LINKING ##


# rank the candidate set for each query entity by embedding similarity
def disambiguify(model, model_name, usgs_dataset, wikidata_dict_list,
                 candset_mode='all_map', if_use_distance=True, select_indices=None):

    if select_indices is None:
        select_indices = range(0, len(wikidata_dict_list))

    assert candset_mode in ['all_map', 'per_map']
    wikidata_emb_list = wikidata_dict_list['wikidata_emb_list']
    wikidata_qid_list = wikidata_dict_list['wikidata_qid_list']
    ret_list = []
    for i in range(len(usgs_dataset)):
        if (i % 1000) == 0:
            print("disambiguify at " + str((i / len(usgs_dataset)) * 100) + "%")
        if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
            usgs_emb = get_spatialbert_embedding(usgs_dataset[i], model, use_distance=if_use_distance)
        else:
            usgs_emb = get_bert_embedding(usgs_dataset[i], model)
        # cosine similarity between the query embedding and every candidate embedding
        sim_matrix = 1 - sp.distance.cdist(np.array(wikidata_emb_list), np.array([usgs_emb]), 'cosine')
        closest_match_qid = sort_ref_closest_match(sim_matrix, wikidata_qid_list)

        sorted_sim_matrix = np.sort(sim_matrix, axis=0)[::-1]  # descending order

        ret_dict = dict()
        ret_dict['pivot_name'] = usgs_dataset[i]['pivot_name']

        ret_dict['sorted_match_qid'] = [a[0] for a in closest_match_qid]
        ret_dict['sorted_sim_matrix'] = [a[0] for a in sorted_sim_matrix]

        ret_list.append(ret_dict)

    return ret_list


candset_mode = 'all_map'
for i in range(0, len(matched_wikid_dataset)):
    if (i % 1000) == 0:
        print("processing at: " + str(i / len(matched_wikid_dataset) * 100) + "%")
    entity = matched_wikid_dataset[i]
    wikidata_emb = get_spatialbert_embedding(matched_wikid_dataset[i], model)
    wikidata_dict_per_map['wikidata_emb_list'].append(wikidata_emb)
    wikidata_dict_per_map['wikidata_qid_list'].append(matched_wikid_dataset[i]['qid'])
    wikidata_dict_per_map['names'].append(wikidata_dataset[i]['pivot_name'])

ret_list = disambiguify(model, 'spatial_bert-base', whg_dataset, wikidata_dict_per_map,
                        candset_mode=candset_mode, if_use_distance=True, select_indices=None)
write_to_csv('tutorial_datasets/', "output.csv", ret_list)
```
Cell 4 (code): evaluate entity linking with Hits@K.

```python
# Evaluate entity linking
import os
import pandas as pd
import json

# define the ground-truth file for evaluation
gt_dir = os.path.abspath("tutorial_datasets/spabert_wikidata_sampled.json")

# define the file where we wrote out predictions
prediction_path = os.path.abspath('tutorial_datasets/output.csv.json')

# build the ground-truth dictionary: entity name -> Wikidata QID
gt_dict = dict()

with open(gt_dir) as f:
    data = f.readlines()
    for line in data:
        d = json.loads(line)
        gt_dict[d['info']['name']] = d['info']['qid']


rank_list = []
hits_at_1 = 0
hits_at_5 = 0
hits_at_10 = 0
out_dict = {'title': [], 'rank': []}

with open(prediction_path) as f:
    data = f.readlines()
    for line in data:
        pred_dict = json.loads(line)
        pivot_name = pred_dict['pivot_name']
        sorted_matched_uri = pred_dict['sorted_match_qid']
        sorted_sim_matrix = pred_dict['sorted_sim_matrix']
        if pivot_name in gt_dict:
            gt_uri = gt_dict[pivot_name]
            rank = sorted_matched_uri.index(gt_uri) + 1
            if rank == 1:
                hits_at_1 += 1
            if rank <= 5:
                hits_at_5 += 1
            if rank <= 10:
                hits_at_10 += 1
            rank_list.append(rank)
            out_dict['title'].append(pivot_name)
            out_dict['rank'].append(rank)

# normalize hit counts into Hits@K rates
hits_at_1 = hits_at_1 / len(rank_list)
hits_at_5 = hits_at_5 / len(rank_list)
hits_at_10 = hits_at_10 / len(rank_list)

print(hits_at_1)
print(hits_at_5)
print(hits_at_10)

out_df = pd.DataFrame(out_dict)
out_df
```
Cell 5 (markdown):

Mean Reciprocal Rank is a statistical measure for evaluating processes that produce a list of possible responses to a query, ordered by probability of correctness.

First we obtain the rank from the ranked list shown above.

Next we calculate the reciprocal rank for each rank. The reciprocal is the inverse of the rank: for a rank of 1 the reciprocal rank is 1/1, and for a rank of 2 the reciprocal rank is 1/2.

The mean reciprocal rank is the average of the reciprocal ranks.

This measure gives us a general sense of how well our model predicts entities based on their embeddings.

An in-depth description of Mean Reciprocal Rank can be found here: https://en.wikipedia.org/wiki/Mean_reciprocal_rank

An important thing to keep in mind when calculating mean reciprocal rank is that it tends to scale inversely with the size of your candidate set.

Our candidate set has a length of 4624.
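
As a small worked example with hypothetical ranks (not taken from this notebook):

```python
# ranks 1, 2, and 10 give reciprocal ranks 1.0, 0.5, and 0.1
ranks = [1, 2, 10]
reciprocals = [1.0 / r for r in ranks]
mrr = sum(reciprocals) / len(reciprocals)
print(mrr)  # (1 + 0.5 + 0.1) / 3 = 0.533...
```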
Cell 6 (code): calculate the mean reciprocal rank (MRR).

```python
# calculating the mean reciprocal rank (MRR)
import numpy as np

reciprocal_list = [1. / rank for rank in rank_list]

MRR = np.mean(reciprocal_list)

print(MRR)
```

(Notebook metadata: kernel "ucgis23workshop", Python 3.11.3, nbformat 4.)
models/spabert/notebooks/spabert-fine-tuning.ipynb
DELETED
@@ -1,262 +0,0 @@

Cell 1 (code): point to the OSM data used for fine-tuning.

```python
import json
import pandas as pd

# LOCATION OF THE OSM DATA FOR FINE-TUNING
data = 'tutorial_datasets/osm_mn.csv'
```
Cell 2 (code): read the CSV and build a KDTree over the coordinates.

```python
## CONSTRUCT DATASET FOR FINE-TUNING ##

# Read data from the .csv data file
state_frame = pd.read_csv(data)


# construct lists of names and coordinates from the data
name_list = []
coordinate_list = []
for i, item in state_frame.iterrows():
    name = item[1]
    lat = item[2]
    lng = item[3]
    name_list.append(name)
    coordinate_list.append([lng, lat])


# construct a KDTree out of the coordinate list for when we make the neighbor lists
import scipy.spatial as scp

ordered_neighbor_coordinate_list = scp.KDTree(coordinate_list)
```
Cell 3 (code): inspect the data frame.

```python
state_frame
```

Cell 4 (code): get the 20 nearest neighbors of each entity and write the fine-tuning JSON.

```python
# Get the top 20 nearest neighbors for each entity in the dataset
with open('tutorial_datasets/SPABERT_finetuning_data.json', 'w') as out_f:
    for i, item in state_frame.iterrows():
        name = item[1]
        lat = item[2]
        lng = item[3]
        coordinates = [lng, lat]

        # k=21 because the nearest hit is the entity itself
        _, nearest_neighbors_idx = ordered_neighbor_coordinate_list.query([coordinates], k=21)

        # we want to store their names and coordinates
        nearest_neighbors_name = []
        nearest_neighbors_coords = []

        # iterate over the nearest-neighbors list
        for idx in nearest_neighbors_idx[0]:
            # get the name and coordinates of the neighbor
            neighbor_name = name_list[idx]
            neighbor_coords = coordinate_list[idx]
            nearest_neighbors_name.append(neighbor_name)
            nearest_neighbors_coords.append({"coordinates": neighbor_coords})

        # construct the neighbor-info dictionary for SpaBERT embedding construction
        neighbor_info = {"name_list": nearest_neighbors_name, "geometry_list": nearest_neighbors_coords}

        # construct the full dictionary object for SpaBERT embedding construction
        place = {"info": {"name": name, "geometry": {"coordinates": coordinates}}, "neighbor_info": neighbor_info}

        out_f.write(json.dumps(place))
        out_f.write('\n')
```
Cell 5 (code): load the pre-trained model components.

```python
### FINE-TUNE SPABERT
import sys
from transformers.models.bert.modeling_bert import BertForMaskedLM
from transformers import BertTokenizer
sys.path.append("../")
from models.spatial_bert_model import SpatialBertConfig
from utils.common_utils import load_spatial_bert_pretrained_weights
from models.spatial_bert_model import SpatialBertForMaskedLM

# load the dataset we just created
dataset = 'tutorial_datasets/SPABERT_finetuning_data.json'

# path to the pre-trained SpaBERT weights
pretrained_model = 'tutorial_datasets/mlm_mem_keeppos_ep0_iter06000_0.2936.pth'


# load the BERT model and tokenizer as well as the SpaBERT config
bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = SpatialBertConfig()
```

Cell 6 (code): initialize SpaBERT from BERT and load the pre-trained SpaBERT weights.

```python
# load the pre-trained SpaBERT model
import torch
model = SpatialBertForMaskedLM(config)

model.load_state_dict(bert_model.state_dict(), strict=False)

pre_trained_model = torch.load(pretrained_model)

model_keys = model.state_dict()
cnt_layers = 0
for key in model_keys:
    if key in pre_trained_model:
        model_keys[key] = pre_trained_model[key]
        cnt_layers += 1
    else:
        print("No weight for", key)
print(cnt_layers, 'layers loaded')

model.load_state_dict(model_keys)
```
Cell 7 (code): load the fine-tuning dataset with the data loader.

```python
from datasets.osm_sample_loader import PbfMapDataset
from torch.utils.data import DataLoader

# load the fine-tuning dataset with the data loader
fine_tune_dataset = PbfMapDataset(data_file_path=dataset,
                                  tokenizer=tokenizer,
                                  max_token_len=300,
                                  distance_norm_factor=0.0001,
                                  spatial_dist_fill=20,
                                  with_type=False,
                                  sep_between_neighbors=False,
                                  label_encoder=None,
                                  mode=None)
# initialize the data loader
train_loader = DataLoader(fine_tune_dataset, batch_size=12, num_workers=5,
                          shuffle=False, pin_memory=True, drop_last=True)
```
Cell 8 (code): move the model to a GPU if one is available.

```python
import torch
# cast our loaded model to a GPU if one is available, otherwise use the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# set the model to training mode
model.train()
```

Cell 9 (code): run the fine-tuning loop for 1 epoch and save the result.

```python
### FINE-TUNING PROCEDURE ###
from tqdm import tqdm
from transformers import AdamW
# initialize the optimizer
optim = AdamW(model.parameters(), lr=5e-5)

# set up the loop with tqdm and the data loader
epoch = tqdm(train_loader, leave=True)
iter = 0
for batch in epoch:
    # reset gradients calculated in the previous step
    optim.zero_grad()

    # pull all tensor batches required for training
    input_ids = batch['masked_input'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    position_list_x = batch['norm_lng_list'].to(device)
    position_list_y = batch['norm_lat_list'].to(device)
    sent_position_ids = batch['sent_position_ids'].to(device)

    labels = batch['pseudo_sentence'].to(device)

    # get the model outputs
    outputs = model(input_ids, attention_mask=attention_mask, sent_position_ids=sent_position_ids,
                    position_list_x=position_list_x, position_list_y=position_list_y, labels=labels)

    # calculate the loss
    loss = outputs.loss

    # perform backpropagation
    loss.backward()

    optim.step()
    epoch.set_postfix({'loss': loss.item()})

    iter += 1
torch.save(model.state_dict(), "tutorial_datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth")
```

(Notebook metadata: kernel "base", Python 3.11.3, nbformat 4.)
models/spabert/notebooks/tutorial_datasets/mlm_mem_keeppos_ep0_iter06000_0.2936.pth
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e591f9d4798bee6d0deb59dcbbaefb31f08fdc5c751b81e4b52b95ddca766b71
size 531897899

models/spabert/notebooks/tutorial_datasets/osm_mn.csv
DELETED
The diff for this file is too large to render. See raw diff.

models/spabert/notebooks/tutorial_datasets/output.csv.json
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5c5b0c617f0cf93c320a8d8a38a3af7f092ad54eebfeca866d2905a03bcae6f8
size 37566110

models/spabert/notebooks/tutorial_datasets/spabert-base-uncased-finetuned-osm-mn.pth
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7046e57275530ee40ec10a80ecb67977e6f6530dc935c6abf77cf1d56c3d0f9a
size 531904817

models/spabert/notebooks/tutorial_datasets/spabert_osm_mn.json
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0665091faef52166a44c6ef253af85c99e30f7a63150b4542d42768793a088f6
size 65595132

models/spabert/notebooks/tutorial_datasets/spabert_whg_wikidata.json
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bba735c975231c42f467285b4f17ce4ba58262557f2769b89c863b7f37302209
size 52811876

models/spabert/notebooks/tutorial_datasets/spabert_wikidata_sampled.json
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d13476f6583e96ebc7272af910f99decc062b4053f92cc927837c49a777e6e86
size 27841961