Skip to content

Commit

Permalink
reconstruct code
Browse files Browse the repository at this point in the history
  • Loading branch information
TideDra committed Dec 20, 2024
1 parent cfa693c commit 3300fc2
Show file tree
Hide file tree
Showing 9 changed files with 356 additions and 226 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ dist/
wheels/
.vscode/
*.egg-info
.env

# Virtual environments
.venv
Expand Down
41 changes: 31 additions & 10 deletions construct_email.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import arxiv
from paper import ArxivPaper
import math
from tqdm import tqdm
from email.header import Header
from email.mime.text import MIMEText
from email.utils import parseaddr, formataddr
import smtplib
import datetime

framework = """
<!DOCTYPE HTML>
<html>
Expand Down Expand Up @@ -106,19 +113,12 @@ def get_stars(score:float):
return '<div class="star-wrapper">'+full_star * full_star_num + half_star * half_star_num + '</div>'


def render_email(papers:list[arxiv.Result]):
def render_email(papers:list[ArxivPaper]):
parts = []
if len(papers) == 0 :
return framework.replace('__CONTENT__', get_empty_html())

for p in papers:
# crop the abstract
'''
summary = p.summary
summary = summary[:min(600, len(summary))]
if len(summary) == 600:
summary += '...'
'''
for p in tqdm(papers,desc='Rendering Email'):
rate = get_stars(p.score)
authors = [a.name for a in p.authors[:5]]
authors = ', '.join(authors)
Expand All @@ -128,3 +128,24 @@ def render_email(papers:list[arxiv.Result]):

content = '<br>' + '</br><br>'.join(parts) + '</br>'
return framework.replace('__CONTENT__', content)

def send_email(sender:str, receiver:str, password:str,smtp_server:str,smtp_port:int, html:str,):
    """Send an HTML message to a single receiver over SMTP.

    A plain SMTP connection upgraded with STARTTLS is attempted first. If the
    server disconnects during that attempt, an SMTP_SSL connection to the same
    host and port is used instead.

    Args:
        sender: Address used for the From header, the SMTP login and the
            envelope sender.
        receiver: Address used for the To header and the envelope recipient.
        password: Password passed to the SMTP login.
        smtp_server: SMTP server host.
        smtp_port: SMTP server port.
        html: HTML document used as the message body.

    Raises:
        smtplib.SMTPException: If connecting, logging in or sending fails.
    """
    def _format_addr(s):
        # Encode the display name so non-ASCII names survive in the header.
        name, addr = parseaddr(s)
        return formataddr((Header(name, 'utf-8').encode(), addr))

    msg = MIMEText(html, 'html', 'utf-8')
    msg['From'] = _format_addr('Github Action <%s>' % sender)
    msg['To'] = _format_addr('You <%s>' % receiver)
    today = datetime.datetime.now().strftime('%Y/%m/%d')
    msg['Subject'] = Header(f'Daily arXiv {today}', 'utf-8').encode()

    server = None
    try:
        server = smtplib.SMTP(smtp_server, smtp_port)
        server.starttls()
    except smtplib.SMTPServerDisconnected:
        # Release the failed plain connection (if it was created at all)
        # before retrying with an SSL connection.
        if server is not None:
            server.close()
        server = smtplib.SMTP_SSL(smtp_server, smtp_port)

    # The context manager issues QUIT and closes the socket even when
    # login or sendmail raises, so the connection is never leaked.
    with server:
        server.login(sender, password)
        server.sendmail(sender, [receiver], msg.as_string())
37 changes: 37 additions & 0 deletions llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from llama_cpp import Llama
from openai import OpenAI
from loguru import logger

GLOBAL_LLM = None

class LLM:
    """Single chat-completion interface over an OpenAI client or a local Llama model.

    When a truthy ``api_key`` is given, an ``OpenAI`` client is created with
    that key and ``base_url``. Otherwise a ``Llama`` model is loaded via
    ``Llama.from_pretrained`` with the repository and file named below.
    """

    def __init__(self, api_key: str = None, base_url: str = None, model: str = None):
        if not api_key:
            backend = Llama.from_pretrained(
                repo_id="Qwen/Qwen2.5-3B-Instruct-GGUF",
                filename="qwen2.5-3b-instruct-q4_k_m.gguf",
                n_ctx=32_000,
                n_threads=4,
                verbose=False,
            )
        else:
            backend = OpenAI(api_key=api_key, base_url=base_url)
        self.llm = backend
        # Only passed along on the OpenAI path in `generate`.
        self.model = model

    def generate(self, messages: list[dict]) -> str:
        """Run one chat completion at temperature 0 and return the first choice's content."""
        if not isinstance(self.llm, OpenAI):
            local_reply = self.llm.create_chat_completion(messages=messages,temperature=0)
            return local_reply["choices"][0]["message"]["content"]

        api_reply = self.llm.chat.completions.create(messages=messages,temperature=0,model=self.model)
        return api_reply.choices[0].message.content

def set_global_llm(api_key: str = None, base_url: str = None, model: str = None):
    """Create an ``LLM`` from the given settings and store it in ``GLOBAL_LLM``."""
    global GLOBAL_LLM
    new_llm = LLM(api_key=api_key, base_url=base_url, model=model)
    GLOBAL_LLM = new_llm

def get_llm() -> LLM:
    """Return the module-wide ``LLM``, creating a default one on first use."""
    if GLOBAL_LLM is not None:
        return GLOBAL_LLM

    # Nothing configured yet: fall back to `set_global_llm` with no arguments.
    logger.info("No global LLM found, creating a default one. Use `set_global_llm` to set a custom one.")
    set_global_llm()
    return GLOBAL_LLM
188 changes: 62 additions & 126 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
import arxiv
import argparse
import os
import sys
from dotenv import load_dotenv
load_dotenv(override=True)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from pyzotero import zotero
from recommender import rerank_paper
from construct_email import render_email
from construct_email import render_email, send_email
import requests
import datetime
import re
from time import sleep
from email.header import Header
from email.mime.text import MIMEText
from email.utils import parseaddr, formataddr
import smtplib
from tldr import get_paper_tldr
from llama_cpp import Llama
from tqdm import tqdm, trange
from tqdm import trange
from loguru import logger
from openai import OpenAI
from gitignore_parser import parse_gitignore
from tempfile import mkstemp
from paper import ArxivPaper
from llm import set_global_llm

def get_zotero_corpus(id:str,key:str) -> list[dict]:
zot = zotero.Zotero(id, 'user', key)
Expand Down Expand Up @@ -49,44 +48,6 @@ def filter_corpus(corpus:list[dict], pattern:str) -> list[dict]:
os.remove(filename)
return new_corpus

def select_corpus(corpus:list[dict], tags: str) -> list[dict]:
    """Keep the corpus entries that have at least one path listed in ``tags``.

    Args:
        corpus: Entries, each a dict with a ``'paths'`` iterable.
        tags: Comma-separated path names to select; compared verbatim
            (no whitespace stripping).

    Returns:
        The matching entries in their original order. Each entry appears at
        most once, even when several of its paths match.
    """
    wanted = tags.split(',')
    # `any` stops at the first matching path, so an entry is never added
    # twice (the previous inner-loop `continue` did not leave the loop).
    return [c for c in corpus if any(p in wanted for p in c['paths'])]

def get_paper_code_url(paper:arxiv.Result) -> str:
    """Look up the first code repository URL listed for a paper on paperswithcode.

    Two requests are made: one to find the paper by ``paper.arxiv_id``, then
    one to list the repositories of the first matching paper.

    Args:
        paper: Paper object exposing an ``arxiv_id`` attribute.

    Returns:
        The ``url`` of the first repository, or ``None`` when a request keeps
        failing, the paper is not found, or it has no repositories.
    """
    paper_list = _get_json_with_retry(
        f'https://paperswithcode.com/api/v1/papers/?arxiv_id={paper.arxiv_id}'
    )
    if paper_list is None or paper_list.get('count',0) == 0:
        return None

    paper_id = paper_list['results'][0]['id']
    repo_list = _get_json_with_retry(
        f'https://paperswithcode.com/api/v1/papers/{paper_id}/repositories/'
    )
    if repo_list is None or repo_list.get('count',0) == 0:
        return None
    return repo_list['results'][0]['url']


def _get_json_with_retry(url, retries=5, delay=1, timeout=30):
    """GET ``url`` and decode JSON, retrying on failure; return None if every attempt fails."""
    for _ in range(retries):
        try:
            # A timeout keeps one stalled connection from blocking the whole run.
            return requests.get(url, timeout=timeout).json()
        except (requests.RequestException, ValueError):
            # Network/HTTP errors and undecodable bodies are retried;
            # KeyboardInterrupt and SystemExit are no longer swallowed.
            sleep(delay)
    return None

def get_arxiv_paper_from_web(query:str, start:datetime.datetime, end:datetime.datetime) -> list[arxiv.Result]:
cats = re.findall(r'cat:(\w+)?\.\w+?', query)
Expand Down Expand Up @@ -148,9 +109,7 @@ def is_valid(paper:arxiv.Result):
search = arxiv.Search(id_list=all_paper_ids[i:i+50])
for i in client.results(search):
if is_valid(i):
i.arxiv_id = re.sub(r'v\d+$', '', i.get_short_id())
i.code_url = get_paper_code_url(i)
results.append(i)
results.append(ArxivPaper(i))
return results


Expand All @@ -165,9 +124,7 @@ def get_arxiv_paper(query:str, start:datetime.datetime, end:datetime.datetime, d
for i in client.results(search):
published_date = i.published
if published_date < end and published_date >= start:
i.arxiv_id = re.sub(r'v\d+$', '', i.get_short_id())
i.code_url = get_paper_code_url(i)
papers.append(i)
papers.append(ArxivPaper(i))
elif published_date < start:
break
break
Expand All @@ -186,9 +143,7 @@ def get_arxiv_paper(query:str, start:datetime.datetime, end:datetime.datetime, d
papers = []
try:
for i in client.results(search):
i.arxiv_id = re.sub(r'v\d+$', '', i.get_short_id())
i.code_url = get_paper_code_url(i)
papers.append(i)
papers.append(ArxivPaper(i))
if len(papers) == 5:
break
break
Expand All @@ -200,85 +155,81 @@ def get_arxiv_paper(query:str, start:datetime.datetime, end:datetime.datetime, d
raise e
return papers

def send_email(sender:str, receiver:str, password:str,smtp_server:str,smtp_port:int, html:str,):
    """Send an HTML message to a single receiver over SMTP.

    A plain SMTP connection upgraded with STARTTLS is attempted first. If the
    server disconnects during that attempt, an SMTP_SSL connection to the same
    host and port is used instead.

    Args:
        sender: Address used for the From header, the SMTP login and the
            envelope sender.
        receiver: Address used for the To header and the envelope recipient.
        password: Password passed to the SMTP login.
        smtp_server: SMTP server host.
        smtp_port: SMTP server port.
        html: HTML document used as the message body.

    Raises:
        smtplib.SMTPException: If connecting, logging in or sending fails.
    """
    def _format_addr(s):
        # Encode the display name so non-ASCII names survive in the header.
        name, addr = parseaddr(s)
        return formataddr((Header(name, 'utf-8').encode(), addr))

    msg = MIMEText(html, 'html', 'utf-8')
    msg['From'] = _format_addr('Github Action <%s>' % sender)
    msg['To'] = _format_addr('You <%s>' % receiver)
    today = datetime.datetime.now().strftime('%Y/%m/%d')
    msg['Subject'] = Header(f'Daily arXiv {today}', 'utf-8').encode()

    server = None
    try:
        server = smtplib.SMTP(smtp_server, smtp_port)
        server.starttls()
    except smtplib.SMTPServerDisconnected:
        # Release the failed plain connection (if it was created at all)
        # before retrying with an SSL connection.
        if server is not None:
            server.close()
        server = smtplib.SMTP_SSL(smtp_server, smtp_port)

    # The context manager issues QUIT and closes the socket even when
    # login or sendmail raises, so the connection is never leaked.
    with server:
        server.login(sender, password)
        server.sendmail(sender, [receiver], msg.as_string())
parser = argparse.ArgumentParser(description='Recommender system for academic papers')


def get_env(key:str,default=None):
    """Read environment variable ``key``, treating an empty value as unset.

    Variables generated at Workflow runtime arrive as '' when they are not
    set, so both a missing variable and an empty string yield ``default``.
    """
    value = os.environ.get(key)
    if not value:
        return default
    return value
def add_argument(*args, **kwargs):
    """Register an argument on the module-level ``parser`` with an env-var default.

    The arguments are forwarded unchanged to ``parser.add_argument``. If the
    environment variable named after the argument's destination (upper-cased)
    is set to a non-empty value, it is converted with the argument's ``type``
    and installed as the parser default, so an explicit command-line value
    still takes precedence.

    Args:
        *args: Option strings, as accepted by ``parser.add_argument``.
        **kwargs: Keyword options, as accepted by ``parser.add_argument``.
    """
    def get_env(key:str,default=None):
        # handle environment variables generated at Workflow runtime
        # Unset environment variables are passed as '', we should treat them as None
        v = os.environ.get(key)
        if v == '' or v is None:
            return default
        return v

    action = parser.add_argument(*args, **kwargs)
    # Use the destination argparse actually chose instead of slicing the
    # flag text, so `dest=`, dashed names and short flags all resolve correctly.
    arg_full_name = action.dest
    env_value = get_env(arg_full_name.upper())
    if env_value is None:
        return

    # convert env_value to the specified type
    arg_type = kwargs.get('type')
    if arg_type is bool:
        # bool('false') would be True, so parse the text explicitly.
        env_value = env_value.lower() in ['true','1']
    elif arg_type is not None:
        env_value = arg_type(env_value)
    # With no `type`, the raw string is kept (previously a TypeError).
    parser.set_defaults(**{arg_full_name:env_value})


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Recommender system for academic papers')
parser.add_argument('--zotero_id', type=str, help='Zotero user ID',default=get_env('ZOTERO_ID'))
parser.add_argument('--zotero_key', type=str, help='Zotero API key',default=get_env('ZOTERO_KEY'))
parser.add_argument('--zotero_ignore',type=str,help='Zotero collection to ignore, using gitignore-style pattern.',default=get_env('ZOTERO_IGNORE'))
parser.add_argument('--send_empty', type=bool, help='If get no arxiv paper, send empty email',default=get_env('SEND_EMPTY',False))
parser.add_argument('--max_paper_num', type=int, help='Maximum number of papers to recommend',default=get_env('MAX_PAPER_NUM',100))
parser.add_argument('--arxiv_query', type=str, help='Arxiv search query',default=get_env('ARXIV_QUERY'))
parser.add_argument('--smtp_server', type=str, help='SMTP server',default=get_env('SMTP_SERVER'))
parser.add_argument('--smtp_port', type=int, help='SMTP port',default=get_env('SMTP_PORT'))
parser.add_argument('--sender', type=str, help='Sender email address',default=get_env('SENDER'))
parser.add_argument('--receiver', type=str, help='Receiver email address',default=get_env('RECEIVER'))
parser.add_argument('--password', type=str, help='Sender email password',default=get_env('SENDER_PASSWORD'))
parser.add_argument(

add_argument('--zotero_id', type=str, help='Zotero user ID')
add_argument('--zotero_key', type=str, help='Zotero API key')
add_argument('--zotero_ignore',type=str,help='Zotero collection to ignore, using gitignore-style pattern.')
add_argument('--send_empty', type=bool, help='If get no arxiv paper, send empty email',default=False)
add_argument('--max_paper_num', type=int, help='Maximum number of papers to recommend',default=100)
add_argument('--arxiv_query', type=str, help='Arxiv search query')
add_argument('--smtp_server', type=str, help='SMTP server')
add_argument('--smtp_port', type=int, help='SMTP port')
add_argument('--sender', type=str, help='Sender email address')
add_argument('--receiver', type=str, help='Receiver email address')
add_argument('--sender_password', type=str, help='Sender email password')
add_argument(
"--use_llm_api",
type=bool,
help="Use OpenAI API to generate TLDR",
default=get_env("USE_LLM_API", False),
default=False,
)
parser.add_argument(
add_argument(
"--openai_api_key",
type=str,
help="OpenAI API key",
default=get_env("OPENAI_API_KEY"),
default=None,
)
parser.add_argument(
add_argument(
"--openai_api_base",
type=str,
help="OpenAI API base URL",
default=get_env("OPENAI_API_BASE", "https://api.openai.com/v1"),
default="https://api.openai.com/v1",
)
parser.add_argument(
add_argument(
"--model_name",
type=str,
help="LLM Model Name",
default=get_env("MODEL_NAME", "gpt-4o"),
default="gpt-4o",
)
parser.add_argument('--debug', action='store_true', help='Debug mode')
args = parser.parse_args()

assert args.zotero_id is not None
assert args.zotero_key is not None
assert args.arxiv_query is not None
assert (
not args.use_llm_api or args.openai_api_key is not None
) # If use_llm_api is True, openai_api_key must be provided
if args.debug:
logger.remove()
logger.add(sys.stdout, level="DEBUG")
logger.debug("Debug mode is on.")
else:
logger.remove()
logger.add(sys.stdout, level="WARNING")

today = datetime.datetime.now(tz=datetime.timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
yesterday = today - datetime.timedelta(days=1)
logger.info("Retrieving Zotero corpus...")
Expand All @@ -299,30 +250,15 @@ def get_env(key:str,default=None):
papers = rerank_paper(papers, corpus)
if args.max_paper_num != -1:
papers = papers[:args.max_paper_num]

logger.info("Generating TLDRs...")
if args.use_llm_api:
logger.info("Using OpenAI API to generate TLDRs...")
llm = OpenAI(
api_key=args.openai_api_key,
base_url=args.openai_api_base,
)
for p in tqdm(papers):
p.tldr = get_paper_tldr(p, llm, model_name=args.model_name)
logger.info("Using OpenAI API as global LLM.")
set_global_llm(api_key=args.openai_api_key, base_url=args.openai_api_base, model=args.model_name)
else:
logger.info("Using Local LLM model to generate TLDRs...")
llm = Llama.from_pretrained(
repo_id="Qwen/Qwen2.5-3B-Instruct-GGUF",
filename="qwen2.5-3b-instruct-q4_k_m.gguf",
n_ctx=4096,
n_threads=4,
verbose=False
)
for p in tqdm(papers):
p.tldr = get_paper_tldr(p, llm)
logger.info("Using Local LLM as global LLM.")
set_global_llm()

html = render_email(papers)
logger.info("Sending email...")
send_email(args.sender, args.receiver, args.password, args.smtp_server, args.smtp_port, html)
send_email(args.sender, args.receiver, args.sender_password, args.smtp_server, args.smtp_port, html)
logger.success("Email sent successfully! If you don't receive the email, please check the configuration and the junk box.")

Loading

0 comments on commit 3300fc2

Please sign in to comment.