diff --git a/bsmetadata/preprocessing_scripts/download_wiki_dump.sh b/bsmetadata/preprocessing_scripts/download_wiki_dump.sh
new file mode 100644
index 00000000..db01be0f
--- /dev/null
+++ b/bsmetadata/preprocessing_scripts/download_wiki_dump.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+out_dir=${1:-bsmetadata/preprocessing_data}  # default directory: preprocessing_data
+
+## Clone the huggingface dataset repo containing the wiki dump
+mkdir -p "$out_dir"
+HUB_REPO_NAME=bs-modeling-metadata/wiki_dump
+git clone "https://huggingface.co/datasets/${HUB_REPO_NAME}" "$out_dir/wiki_dump"
+
+
+## Download the nltk punkt model used by the sentence tokenizer
+python -m nltk.downloader 'punkt'
\ No newline at end of file
diff --git a/bsmetadata/preprocessing_tools/website_desc_utils.py b/bsmetadata/preprocessing_tools/website_desc_utils.py
new file mode 100644
index 00000000..942bab71
--- /dev/null
+++ b/bsmetadata/preprocessing_tools/website_desc_utils.py
@@ -0,0 +1,40 @@
+import re
+from typing import Dict, Optional
+
+import nltk
+from wikipedia2vec.dump_db import DumpDB
+
+
+class WebsiteDescUtils:
+    def __init__(self, path_wiki_db) -> None:
+        self.cache: Dict[str, Optional[str]] = {}
+        self.wiki_dump_db = DumpDB(path_wiki_db)
+        self.redirects_map = {
+            key.lower(): value for key, value in self.wiki_dump_db.redirects()
+        }  # load all redirect information upfront: takes ~10s
+
+    def fetch_wikipedia_title_from_keyword(self, keyword: str) -> str:
+        # Fall back to a default title when the domain has no redirect entry: strip the TLD
+        # and capitalize (e.g. rightmove.com -> Rightmove), since Wikipedia titles are
+        # capitalized, and try to hit the db with that directly.
+        title = self.redirects_map.get(keyword.lower(), keyword.split(".")[0].capitalize())
+        return title
+
+    def fetch_wikipedia_description_for_title(self, title: str) -> Optional[str]:
+        try:
+            text = self.wiki_dump_db.get_paragraphs(title)[0].text
+            text = re.sub(r"\((?:[^)(]|\([^)(]*\))*\)", "", text)  # strip parenthesized asides
+            text = nltk.sent_tokenize(text)[0]  # keep only the first sentence
+        except Exception:
+            return None
+        return text
+
+    def extract_wiki_desc(self, keyword: str) -> Optional[str]:
+        title = self.fetch_wikipedia_title_from_keyword(keyword)
+        desc = self.fetch_wikipedia_description_for_title(title)
+        return desc
+
+    def fetch_website_description_from_keyword(self, keyword: str) -> Optional[str]:
+        if keyword not in self.cache:
+            self.cache[keyword] = self.extract_wiki_desc(keyword)
+        return self.cache[keyword]
diff --git a/bsmetadata/preprocessing_utils.py b/bsmetadata/preprocessing_utils.py
index fffb06a0..85fcfd90 100644
--- a/bsmetadata/preprocessing_utils.py
+++ b/bsmetadata/preprocessing_utils.py
@@ -23,6 +23,7 @@
 from REL.ner import load_flair_ner
 from REL.utils import process_results
 
+from bsmetadata.preprocessing_tools.website_desc_utils import WebsiteDescUtils
 from bsmetadata.vendor.dateutil.src.dateutil.parser import ParserError, parse
 
 
@@ -42,6 +43,11 @@ def parse_date(path):
     return None
 
 
+def fetch_keyword_from_url(url: str) -> str:  # e.g. http://www.californialandcan.org/Plumas -> californialandcan.org
+    domain = urlsplit(url).netloc
+    return domain[len("www.") :] if domain.startswith("www.") else domain
+
+
 def remove_improbable_date(x):
     if x is not None and (x.year < 1983 or x.year > 2021):
         return None
@@ -88,6 +94,36 @@ def _extract_timestamp_from_url(self, url: str) -> Optional[str]:
         return date
 
 
+class WebsiteDescPreprocessor(MetadataPreprocessor):
+    """Metadata preprocessor for adding a website description based on URLs."""
+
+    def __init__(self, path_wiki_db: str = "../preprocessing_data/wiki_dump/wiki_en_dump_db") -> None:
+        self.website_utils = WebsiteDescUtils(path_wiki_db)
+        super().__init__()
+
+    def preprocess(self, examples: Dict[str, List]) -> Dict[str, List]:
+        metadata_list = examples["metadata"]
+
+        # Iterate through the metadata associated with all examples in this batch.
+        for metadata in metadata_list:
+            # Get the URLs associated with this example.
+            urls = [md["value"] for md in metadata if md["key"] == "url"]
+
+            if not urls:
+                continue
+
+            # Try to extract a website description from the first URL and add it to the metadata.
+            website_description = self._extract_website_desc_from_url(urls[0])
+
+            if website_description:
+                metadata.append({"key": "website_description", "type": "global", "value": website_description})
+        return examples
+
+    def _extract_website_desc_from_url(self, url: str) -> Optional[str]:
+        keyword = fetch_keyword_from_url(url)
+        return self.website_utils.fetch_website_description_from_keyword(keyword)
+
+
 class EntityPreprocessor(MetadataPreprocessor):
     """Metadata preprocessor for adding entity information."""
diff --git a/requirements.txt b/requirements.txt
index 62be06ed..38f5ad91 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,5 @@ wandb>=0.10.32,<1 # pip will likely update it to 0.12.1, but it is probably ok
 transformers>=4.6.0,<5 # pip will likely update it to 4.10.0, but it is probably ok and good for bugfixes.
 accelerate>=0.4.0,<1 # We may want to use 0.5.0 in the near future
 datasets[streaming]>=1.11.0,<2
+wikipedia2vec==1.0.5
+nltk==3.6.5
diff --git a/setup.py b/setup.py
index 40f60afe..2bec9d37 100644
--- a/setup.py
+++ b/setup.py
@@ -21,5 +21,6 @@ def req_file(filename):
     install_requires=install_requires,
     extras_require={
         "entity_preprocessing": ["REL @ git+https://github.com/manandey/REL.git#egg=REL"],
+        "website_description_preprocessing": ["wikipedia2vec==1.0.5", "nltk==3.6.5"],
     },
 )
diff --git a/tests/mocks/mock_dump_db.py b/tests/mocks/mock_dump_db.py
new file mode 100644
index 00000000..a500c65b
--- /dev/null
+++ b/tests/mocks/mock_dump_db.py
@@ -0,0 +1,32 @@
+from typing import List
+
+
+class MockParagraph:
+    def __init__(self, text):
+        self.text = text
+
+
+class MockDumpDB:
+    def __init__(self, db_file) -> None:
+        self.db_file = db_file
+        self.redirect_info = [("xyz.com", "XYZ"), ("test.com", "Test"), ("test_key", "Test Key")]
+        self.paragraphs_map = {
+            "XYZ": [
+                MockParagraph("XYZ is a U.S. based company."),
+                MockParagraph("Test paragraph for the key XYZ."),
+            ],
+            "Test": [
+                MockParagraph("Test is a U.S. based company."),
+                MockParagraph("Test paragraph for the key Test."),
+            ],
+            "Sometitle": [
based company."), + MockParagraph("Test paragraph for the key SomeTitle."), + ], + } + + def redirects(self) -> List[tuple]: + return self.redirect_info + + def get_paragraphs(self, title: str): + return self.paragraphs_map[title] diff --git a/tests/test_preprocessing_utils.py b/tests/test_preprocessing_utils.py new file mode 100644 index 00000000..2061f4e2 --- /dev/null +++ b/tests/test_preprocessing_utils.py @@ -0,0 +1,55 @@ +import unittest +from unittest import mock + +from datasets import Dataset +from mocks.mock_dump_db import MockDumpDB + +from bsmetadata.preprocessing_utils import WebsiteDescPreprocessor + + +def mock_sent_tokenize(text): + return [text] + + +class WebsiteDescPreprocessorTester(unittest.TestCase): + @mock.patch("bsmetadata.preprocessing_tools.website_desc_utils.DumpDB") + def setUp(self, mock_db) -> None: + mock_db.return_value = MockDumpDB("some/path") + self.website_processor = WebsiteDescPreprocessor() + self.example_ids = [0, 1, 2] + self.example_text = ["test text 1", "test text 2", "test text 3"] + self.example_metadata = [ + [{"key": "url", "type": "global", "value": "https://www.xyz.com"}], + [ + {"key": "url", "type": "global", "value": "http://sometitle.com"}, + {"key": "url", "type": "global", "value": "http://notfound.com"}, + ], + [{"key": "url", "type": "global", "value": "https://www.test.com"}], + ] + + self.example_dict = {"id": self.example_ids, "metadata": self.example_metadata, "text": self.example_text} + + @mock.patch("bsmetadata.preprocessing_tools.website_desc_utils.nltk.sent_tokenize", new=mock_sent_tokenize) + def test_website_metadata_processor(self): + ds = Dataset.from_dict(self.example_dict) + ds = ds.map(lambda ex: self.website_processor.preprocess(ex), batched=True) + target_metadata = [ + [ + {"key": "url", "type": "global", "value": "https://www.xyz.com"}, + {"key": "website_description", "type": "global", "value": "XYZ is a U.S. based company."}, + ], + [ + {"key": "url", "type": "global", "value": "http://sometitle.com"}, + {"key": "url", "type": "global", "value": "http://notfound.com"}, + {"key": "website_description", "type": "global", "value": "SomeTitle is a U.S. based company."}, + ], + [ + {"key": "url", "type": "global", "value": "https://www.test.com"}, + {"key": "website_description", "type": "global", "value": "Test is a U.S. based company."}, + ], + ] + self.assertEqual(ds[:]["metadata"], target_metadata) + + +if __name__ == "__main__": + unittest.main()