Skip to content

Commit

Permalink
improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jramcast committed May 17, 2021
1 parent ced282d commit 4d62222
Show file tree
Hide file tree
Showing 18 changed files with 715 additions and 109 deletions.
1 change: 1 addition & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.model/checkpoint-*
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,5 @@ node_modules/

# Training
.output
.model
runs
13 changes: 13 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
FROM registry.access.redhat.com/ubi8/python-38

COPY requirements.prod.txt .
RUN pip install --upgrade pip && \
pip install -r requirements.prod.txt && \
pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html

COPY .model .model
COPY serve.py .

EXPOSE 8000

CMD ["python", "serve.py"]
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# RHT Text Generator

Tool to assist developers when writing courses.

## Usage

1. Build the model server image:

podman build . -t rht-text-generator

2. Run the container:

podman run --rm -ti -p 8482:8000 rht-text-generator

3. Install the extension:




## Retrain the model

1. Build the dataset from courses:

COURSE_DIR=... python build_dataset.py

2. Train:

./train
5 changes: 4 additions & 1 deletion build_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def parse_sections(f):

# Find adoc files
home = str(Path.home())
coursedir = os.path.join(home, "Desarrollo")
coursedir = os.environ.get(
"COURSE_DIR",
os.path.join(home, "Desarrollo", "courses"))


for dirpath, dnames, fnames in os.walk(coursedir):
for f in fnames:
Expand Down
115 changes: 115 additions & 0 deletions build_dataset_by_section.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import os
import re
import random
import statistics
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from transformers import GPT2Tokenizer


TRAIN_PATH = "data/train/"
VALIDATION_PATH = "data/validation/"

# Find adoc files
home = str(Path.home())
coursedir = os.environ.get(
"COURSE_DIR",
os.path.join(home, "Desarrollo", "courses"))


lecture_pattern = re.compile(r"== \w+")
lab_pattern = re.compile(r"(^\d\) \w+)|(^== Outcomes)")


def parse_sections(filehandler, pattern):
section = ""
sections = []
ignore_lines = True

for line in filehandler:
if (line.startswith("//")
or line.startswith("ifndef")
or line.startswith(":experiment")):
continue

if pattern.match(line):
ignore_lines = False
if section:
sections.append(section)
section = ""

if not ignore_lines:
section += line.rstrip(" ")

return sections


def get_block_sizes(sections):
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
total = 0
return [tokenizer(section, return_length=True)["length"]
for section in sections]


if __name__ == "__main__":

sections = []
for dirpath, dnames, fnames in os.walk(coursedir):
for f in fnames:
if (f.endswith("content.adoc") and
"guides" in dirpath and
"en-US" in dirpath):
filepath = os.path.join(dirpath, f)

if "zzz" in filepath:
continue

print(filepath)

with open(filepath, "r") as f:

if "lab-content" in filepath or "ge-content" in filepath:
print(filepath)
sections += parse_sections(f, lab_pattern)
else:
sections += parse_sections(f, lecture_pattern)

sizes = get_block_sizes(sections)
print("Mean block size:", statistics.mean(sizes))
print("Median block size:", statistics.median(sizes))

def pdf(x):
mean = np.mean(x)
std = np.std(x)
y_out = 1/(std * np.sqrt(2 * np.pi)) * np.exp( - (x - mean)**2 / (2 * std**2))
return y_out

plt.style.use('seaborn')
y = pdf(sizes)
plt.figure(figsize=(6, 6))
# plt.plot(sizes, y, color='black',
# linestyle='dashed')

plt.scatter(sizes, y, marker='o', s=25, color='red')
plt.show()

random.Random(42).shuffle(sections)
num_sections = len(sections)
train_size = int(num_sections * 0.8)
train_sections = sections[:train_size]
validation_sections = sections[train_size:]

import pandas as pd
train_df = pd.DataFrame(train_sections)
train_df.to_csv(TRAIN_PATH + "train.csv", index=False)
valid_df = pd.DataFrame(validation_sections)
valid_df.to_csv(VALIDATION_PATH + "validation.csv", index=False)

for key, section in enumerate(train_sections):
with open(TRAIN_PATH + f"section_{key}.txt", "w") as f:
f.write(section)

for key, section in enumerate(validation_sections):
with open(VALIDATION_PATH + f"section_{key}.txt", "w") as f:
f.write(section)
70 changes: 0 additions & 70 deletions extension/rht-text-generator/README.md

This file was deleted.

14 changes: 14 additions & 0 deletions extension/rht-text-generator/README.old.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# RHT Text generator


1. Build the model server image:

podman build ../.. -t rht-text-generator

2. Run the container:

podman run --rm -ti -p 8482:8000 rht-text-generator

3. Install the extension:

code --install-extension rht-text-generator-0.0.1.vsix
23 changes: 23 additions & 0 deletions extension/rht-text-generator/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"name": "rht-text-generator",
"displayName": "rht-text-generator",
"publisher": "red-hat-training",
"description": "",
"version": "0.0.1",
"engines": {
Expand Down Expand Up @@ -36,5 +37,27 @@
},
"dependencies": {
"axios": "^0.21.1"
},
"contributes": {
"configuration": {
"title": "rht-text-generator",
"properties": {
"rht-text-generator.lines": {
"type": "number",
"default": 3,
"description": "Number of lines to pass to the model"
},
"rht-text-generator.length": {
"type": "number",
"default": 3,
"description": "Number of words(tokens) to generate. Higher is slower"
},
"rht-text-generator.server": {
"type": "string",
"default": "localhost:8482",
"description": "Model server"
}
}
}
}
}
Binary file not shown.
36 changes: 16 additions & 20 deletions extension/rht-text-generator/src/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,37 +40,33 @@ async function getCompletionsListItemsFor(
position: vscode.Position
): Promise<vscode.CompletionItem[]> {

const CHAR_LIMIT = 1000;
const offset = document.offsetAt(position);
const beforeStartOffset = Math.max(0, offset - CHAR_LIMIT);

// const afterEndOffset = offset + CHAR_LIMIT;
// const beforeStart = document.positionAt(beforeStartOffset);
// const afterEnd = document.positionAt(afterEndOffset);

// const line = document.lineAt(position.line).text.trim();
const line = document.getText(
new vscode.Range(
document.positionAt(beforeStartOffset),
document.positionAt(offset)
)
);
const config = vscode.workspace.getConfiguration("rht-text-generator");
const MAX_LINES: number = config.get("lines") || 3;
const lines = [];

for (let lineOffset = 0; lineOffset < MAX_LINES; lineOffset++ ) {
const lineNumber = Math.max(0, position.line - lineOffset);
const line = document.lineAt(lineNumber).text.trimEnd();
lines.unshift(line);
}
const text = lines.join("\n");

const predictionLength = 3;
const predictionLength: number = config.get("length") || 3;

let suggestions: string[] = await generateSuggestions(line, predictionLength);
const server: string = config.get("server") || "";
let suggestions: string[] = await generateSuggestions(text, predictionLength, server);

return suggestions.map(suggestion => {
const tail = suggestion.replace(line, "").trim();
const tail = suggestion.replace(text, "").trim();
return new vscode.CompletionItem(tail);
});
}

async function generateSuggestions(line: string, predictionLength: number) {
async function generateSuggestions(line: string, predictionLength: number, server: string) {
let suggestions: string[] = [];
try {
const response = await Axios.get<[string]>(
`http://localhost:8000/?text=${line}&length=${predictionLength}`
`http://${server}/?text=${line}&length=${predictionLength}`
);
suggestions = response.data;
} catch (error) {
Expand Down
8 changes: 2 additions & 6 deletions extension/rht-text-generator/src/triggers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export const COMPLETION_TRIGGERS = [
",",
";",
"-",
"\n",
"(",
")",
"{",
Expand All @@ -24,10 +25,5 @@ export const COMPLETION_TRIGGERS = [
"|",
"&",
"*",
"%",
"=",
"$",
"#",
"@",
"!",
"="
];
1 change: 0 additions & 1 deletion predict_clm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from pprint import pprint
from transformers import pipeline, set_seed

TEXT = "=== Identifying the Need"
Expand Down
2 changes: 2 additions & 0 deletions requirements.prod.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
sanic==20.12.3
transformers==4.6.0
Loading

0 comments on commit 4d62222

Please sign in to comment.