diff --git a/semanticscholar/AsyncSemanticScholar.py b/semanticscholar/AsyncSemanticScholar.py
index cce4eb7..9170bb3 100644
--- a/semanticscholar/AsyncSemanticScholar.py
+++ b/semanticscholar/AsyncSemanticScholar.py
@@ -228,6 +228,7 @@ def _get_not_found_ids(self, paper_ids, papers):
 
         prefix_mapping = {
             'ARXIV': 'ArXiv',
+            'DOI': 'DOI',
             'MAG': 'MAG',
             'ACL': 'ACL',
             'PMID': 'PubMed',
@@ -241,11 +242,10 @@ def _get_not_found_ids(self, paper_ids, papers):
             found_ids.add(paper.paperId)
             if paper.externalIds:
                 for prefix, value in paper.externalIds.items():
+                    found_ids.add(f'{value}')
                     if prefix.lower() in prefix_mapping:
                         found_ids.add(
                             f'{prefix_mapping[prefix.lower()]}:{value}')
-                    else:
-                        found_ids.add(f'{value}')
         found_ids = {id.lower() for id in found_ids}
         not_found_ids = [id for id in paper_ids
                          if id.lower() not in found_ids]
diff --git a/tests/test_semanticscholar.py b/tests/test_semanticscholar.py
index 55ad8e2..8bc9183 100644
--- a/tests/test_semanticscholar.py
+++ b/tests/test_semanticscholar.py
@@ -284,6 +284,18 @@ def test_get_papers_return_not_found(self):
         self.assertEqual(len(not_found), 1)
         self.assertEqual(not_found[0], 'CorpusId:211530585')
 
+    def test_get_papers_doi_prefix_not_false_positive(self):
+        paper = Paper({
+            'paperId': 'abc123',
+            'externalIds': {
+                'DOI': '10.1145/792550.792552',
+                'CorpusId': 12345
+            }
+        })
+        not_found = self.sch._AsyncSemanticScholar._get_not_found_ids(
+            ['DOI:10.1145/792550.792552'], [paper])
+        self.assertEqual(not_found, [])
+
     @test_vcr.use_cassette
     def test_get_paper_authors(self):
         data = self.sch.get_paper_authors('10.2139/ssrn.2250500')
@@ -829,7 +841,19 @@ async def test_get_papers_return_not_found_async(self):
         not_found = data[1]
         self.assertEqual(len(not_found), 1)
         self.assertEqual(not_found[0], 'CorpusId:211530585')
-    
+
+    def test_get_papers_doi_prefix_not_false_positive_async(self):
+        paper = Paper({
+            'paperId': 'abc123',
+            'externalIds': {
+                'DOI': '10.1145/792550.792552',
+                'CorpusId': 12345
+            }
+        })
+        not_found = self.sch._get_not_found_ids(
+            ['DOI:10.1145/792550.792552'], [paper])
+        self.assertEqual(not_found, [])
+
     @test_vcr.use_cassette
     async def test_get_paper_authors_async(self):
         data = await self.sch.get_paper_authors('10.2139/ssrn.2250500')