Commit 6c72d49

Bring changes from hackathon

jramcast committed Jan 27, 2025
1 parent 8228f40 commit 6c72d49
Showing 8 changed files with 325 additions and 710 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -151,5 +151,6 @@ node_modules/
 # Training
 .output
 .model
+models
 checkpoint-*
 runs
3 changes: 0 additions & 3 deletions .vscode/settings.json

This file was deleted.

31 changes: 17 additions & 14 deletions build_dataset.py
@@ -8,25 +8,28 @@
 VALIDATION_PATH = "data/dataset_validation.txt"
 
 
-def parse_sections(f):
-    sections = []
+def parse_section(f):
+    clean_lines = []
     for line in f:
-        line = line.rstrip()
+        line = line.rstrip().replace(":gls_prefix:", "")
+
+        if (line.startswith("//")
+            or line.startswith("ifndef")
+            or line.startswith(":experiment")):
+            continue
 
-        if re.match(r"^=+ \w+", line):
-            sections.append(line)
-        else:
-            try:
-                sections[-1] += "\n" + line
-            except IndexError:
-                pass
+        clean_lines.append(line)
+
+        # if re.match(r"^=+ \w+", line):
+        #     sections.append(line)
+        # else:
+        #     try:
+        #         sections[-1] += "\n" + line
+        #     except IndexError:
+        #         pass
 
-    return sections
+    return "\n".join(clean_lines)
 
 
 sections = []
@@ -51,14 +54,14 @@ def parse_sections(f)
         filepath = os.path.join(dirpath, f)
         print(filepath)
         with open(filepath, "r") as f:
-            sections += parse_sections(f)
+            sections.append(parse_section(f))
 
 random.Random(42).shuffle(sections)
 num_sections = len(sections)
 train_size = int(num_sections * 0.8)
 
 with open(TRAIN_PATH, "w") as f:
-    f.write("\n".join(sections[:train_size]))
+    f.write("\n<|endoftext|>\n".join(sections[:train_size]))
 
 with open(VALIDATION_PATH, "w") as f:
-    f.write("\n".join(sections[train_size:]))
+    f.write("\n<|endoftext|>\n".join(sections[train_size:]))
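Why the new separator matters: GPT-2's tokenizer treats "<|endoftext|>" as its end-of-document token, so joining sections with it lets the causal-LM trainer see each input file as an independent document rather than one continuous stream. A minimal sketch to verify this, assuming the stock "gpt2" tokenizer that the train script loads:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # The separator collapses to the single EOS token id (50256)
    # instead of being split into word pieces.
    print(tokenizer("<|endoftext|>")["input_ids"])  # [50256]
    print(tokenizer.eos_token)                      # <|endoftext|>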
115 changes: 0 additions & 115 deletions build_dataset_by_section.py

This file was deleted.

40 changes: 28 additions & 12 deletions predict_clm.py
@@ -1,17 +1,33 @@
+import random
 from transformers import pipeline, set_seed
+from transformers import logging as transformers_logging
 
-TEXT = "=== Identifying the Need"
-wordcount = len(TEXT.split())
+# Suppress warning messages
+transformers_logging.set_verbosity_error()
+
+prompt = "To list all the pods in an OpenShift project"
+max_input_chars = 500
+wordcount = len(prompt.split())
 
-# generator = pipeline('text-generation', model='gpt2')
-generator = pipeline('text-generation', model='.output/')
+generator = pipeline(
+    "text-generation", model=".model/", do_sample=True, temperature=0.8
+)
 set_seed(42)
-predictions = generator("Correct any reported",
-                        max_length=wordcount + 3, num_return_sequences=5)
-
-for p in predictions:
-    print()
-    print("-" * 100)
-    print(p["generated_text"])
-    print("-" * 100)
-    print()
+
+
+def generate(text: str):
+    predictions = generator(text, max_new_tokens=3, num_return_sequences=3)
+    return predictions[random.choice([0, 1, 2])]["generated_text"]
+
+
+print(prompt, end=None)
+
+while True:
+    input = prompt[-max_input_chars:]
+    generated = generate(input)
+    print(generated.replace(input, ""), end="", flush=True)
+    # print(generated)
+    # print("--" * 50)
+    prompt = generated
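The rewrite turns the script into an endless autocomplete loop: it keeps only the last 500 characters of the running prompt as context, asks the model for 3 new tokens at a time, picks one of three sampled continuations at random, and prints just the newly generated suffix. A sketch of a single, non-looping call through the same pipeline, assuming a fine-tuned model already exists in .model/:

    from transformers import pipeline, set_seed

    generator = pipeline(
        "text-generation", model=".model/", do_sample=True, temperature=0.8
    )
    set_seed(42)

    # One step of the loop: request three 3-token continuations
    # and print each sampled variant.
    for p in generator("To list all the pods in an OpenShift project",
                       max_new_tokens=3, num_return_sequences=3):
        print(p["generated_text"])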
6 changes: 4 additions & 2 deletions train
@@ -2,15 +2,17 @@
 
 #https://github.com/huggingface/transformers/tree/master/examples/pytorch/language-modeling
 
+MODEL=gpt2
+
 python train_clm.py \
-    --model_name_or_path gpt2 \
+    --model_name_or_path ${MODEL} \
     --train_file data/dataset_train.txt \
     --validation_file data/dataset_validation.txt \
     --do_train \
     --do_eval \
     --use_fast_tokenizer \
     --overwrite_output_dir \
-    --output_dir .model \
+    --output_dir .models/$(date '+%Y-%m-%d_%H-%M')_${MODEL} \
    --per_device_eval_batch_size 20 \
    --per_device_train_batch_size 5 \
    --num_train_epochs 10 \
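With the timestamped output directory, each training run now lands in its own folder instead of overwriting a single .model directory: given the '+%Y-%m-%d_%H-%M' date format, a run started on Jan 27, 2025 at 09:30 would write to .models/2025-01-27_09-30_gpt2 (an illustrative path; the actual name depends on the clock and on ${MODEL}).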