Add MetaData LLM call
Related to CatchTheTornado#16

Add an optional LLM call for generating tags and a summary of the file (a usage sketch follows the change list below).

* **app/main.py**
  - Add a new endpoint `/llm_tags_summary` to generate tags and a summary using the LLM.
  - Update the `OllamaGenerateRequest` class to include a new field `generate_tags_summary`.
  - Update the `generate_llama` function to handle the new `generate_tags_summary` field.

* **app/tasks.py**
  - Add a new function `generate_tags_summary` to generate tags and a summary using the LLM.
  - Update the `ocr_task` function to include an optional call to `generate_tags_summary` after extracting text.

* **client/cli.py**
  - Add a new command `llm_tags_summary` for generating tags and a summary.
  - Update the `main` function to handle the new `llm_tags_summary` command.

* **.env.example**
  - Add a new environment variable `LLM_TAGS_SUMMARY_API_URL`.
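A minimal request sketch against the new endpoint, assuming the FastAPI app is running locally on port 8000 as in `.env.example`; the prompt text is illustrative:

    import requests

    # Hedged sketch: POST to the new /llm_tags_summary endpoint added in app/main.py.
    payload = {
        "model": "llama3.1",
        "prompt": "Generate 5 tags and a short summary for the following text: ...",
    }
    resp = requests.post("http://localhost:8000/llm_tags_summary", json=payload)
    print(resp.json().get("generated_text"))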
chavan-arvind committed Nov 5, 2024
1 parent 7583f09 commit 7eeac00
Showing 4 changed files with 60 additions and 2 deletions.
1 change: 1 addition & 0 deletions .env.example
@@ -8,3 +8,4 @@
RESULT_URL=http://localhost:8000/ocr/result/{task_id}
CLEAR_CACHE_URL=http://localhost:8000/ocr/clear_cach
LLM_PULL_API_URL=http://localhost:8000/llm_pull
LLM_GENEREATE_API_URL=http://localhost:8000/llm_generate
LLM_TAGS_SUMMARY_API_URL=http://localhost:8000/llm_tags_summary
23 changes: 23 additions & 0 deletions app/main.py
@@ -77,6 +77,7 @@ async def clear_ocr_cache():
class OllamaGenerateRequest(BaseModel):
    model: str
    prompt: str
    generate_tags_summary: bool = False

class OllamaPullRequest(BaseModel):
    model: str
@@ -116,3 +117,25 @@ async def generate_llama(request: OllamaGenerateRequest):

    generated_text = response.get("response", "")
    return {"generated_text": generated_text}

@app.post("/llm_tags_summary")
async def generate_tags_summary(request: OllamaGenerateRequest):
    """
    Endpoint to generate tags and a summary using the Llama 3.1 model (and other models) via the Ollama API.
    """
    print(request)
    if not request.prompt:
        raise HTTPException(status_code=400, detail="No prompt provided")

    try:
        response = ollama.generate(request.model, request.prompt)
    except ollama.ResponseError as e:
        print('Error:', e.error)
        if e.status_code == 404:
            ollama.pull(request.model)

        raise HTTPException(status_code=500, detail="Failed to generate tags and summary with Ollama API")

    generated_text = response.get("response", "")
    return {"generated_text": generated_text}
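Because the endpoint pulls a missing model on a 404 from Ollama and then still returns a 500, a client may want to retry once after a short delay. A hedged client-side sketch (the helper name and timing are illustrative, not part of this commit):

    import time
    import requests

    def tags_summary_with_retry(url, model, prompt, retries=1):
        # Retry once: the first call may fail with 500 while the server
        # pulls the missing model via ollama.pull().
        for _ in range(retries + 1):
            resp = requests.post(url, json={"model": model, "prompt": prompt})
            if resp.status_code == 200:
                return resp.json().get("generated_text")
            time.sleep(5)  # illustrative delay; a model pull can take much longer
        raise RuntimeError(f"llm_tags_summary failed: {resp.text}")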
22 changes: 21 additions & 1 deletion app/tasks.py
@@ -59,7 +59,27 @@ def ocr_task(self, pdf_bytes, strategy_name, pdf_hash, ocr_cache, prompt, model)
        num_chunk += 1
        extracted_text += chunk['response']

    # Optional call to generate tags and summary
    if prompt and model:
        tags_summary = generate_tags_summary(prompt, model)
        extracted_text += "\n\nTags and Summary:\n" + tags_summary

    self.update_state(state='DONE', meta={'progress': 100, 'status': 'Processing done!', 'start_time': start_time, 'elapsed_time': time.time() - start_time})  # Example progress update

    return extracted_text

def generate_tags_summary(prompt, model):
    """
    Generate tags and a summary using the LLM via the Ollama API.
    """
    try:
        response = ollama.generate(model, prompt)
    except ollama.ResponseError as e:
        print('Error:', e.error)
        if e.status_code == 404:
            ollama.pull(model)
        raise Exception("Failed to generate tags and summary with Ollama API")

    generated_text = response.get("response", "")
    return generated_text
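The helper can also be exercised outside Celery. A direct-call sketch, assuming the `ollama` client configured in `app/tasks.py` can reach a local Ollama server; the prompt text is illustrative:

    from app.tasks import generate_tags_summary

    # Hypothetical prompt; ocr_task passes its own `prompt` argument here.
    text = generate_tags_summary(
        prompt="List 5 tags and a 2-sentence summary for: <extracted text>",
        model="llama3.1",
    )
    print(text)

Note that `ocr_task` forwards its OCR `prompt` argument rather than the extracted text itself, so a caller who wants the tags and summary to cover the document should embed that text in the prompt.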
16 changes: 15 additions & 1 deletion client/cli.py
@@ -82,6 +82,14 @@ def llm_generate(prompt, model = 'llama3.1'):
    else:
        print(f"Failed to generate text: {response.text}")

def llm_tags_summary(prompt, model = 'llama3.1'):
    ollama_tags_summary_url = os.getenv('LLM_TAGS_SUMMARY_API_URL', 'http://localhost:8000/llm_tags_summary')
    response = requests.post(ollama_tags_summary_url, json={"model": model, "prompt": prompt})
    if response.status_code == 200:
        print(response.json().get('generated_text'))
    else:
        print(f"Failed to generate tags and summary: {response.text}")

def main():
    parser = argparse.ArgumentParser(description="CLI for OCR and Ollama operations.")
    subparsers = parser.add_subparsers(dest='command', help='Sub-command help')
@@ -114,6 +122,10 @@ def main():
    ollama_pull_parser = subparsers.add_parser('llm_pull', help='Pull the latest Llama model from the Ollama API')
    ollama_pull_parser.add_argument('--model', type=str, default='llama3.1', help='Model to pull from the Ollama API')

    # Sub-command for generating tags and summary
    ollama_tags_summary_parser = subparsers.add_parser('llm_tags_summary', help='Generate tags and summary using the Ollama endpoint')
    ollama_tags_summary_parser.add_argument('--prompt', type=str, required=True, help='Prompt for the Ollama model')
    ollama_tags_summary_parser.add_argument('--model', type=str, default='llama3.1', help='Model to use for the Ollama endpoint')

    args = parser.parse_args()

@@ -140,8 +152,10 @@ def main():
        llm_generate(args.prompt, args.model)
    elif args.command == 'llm_pull':
        llm_pull(args.model)
    elif args.command == 'llm_tags_summary':
        llm_tags_summary(args.prompt, args.model)
    else:
        parser.print_help()

if __name__ == "__main__":
    main()
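With the server running and the environment variables configured, the new sub-command can be invoked from the repository root like so (prompt text illustrative):

    python client/cli.py llm_tags_summary --prompt "Generate tags and a summary for: ..." --model llama3.1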
