diff --git a/pycorrector/config.py b/pycorrector/config.py
index f604c6c7..af7971a5 100644
--- a/pycorrector/config.py
+++ b/pycorrector/config.py
@@ -49,7 +49,7 @@
 macbert_model_dir = os.path.join(USER_DATA_DIR, 'macbert_models/chinese_finetuned_correction/')
 os.makedirs(macbert_model_dir, exist_ok=True)
 # t5模型文件路径
-t5_model_dir = os.path.join(USER_DATA_DIR, 't5_models/byt5-small-chinese-correction/')
+t5_model_dir = os.path.join(USER_DATA_DIR, 't5_models/mengzi-t5-base-chinese-correction/')
 os.makedirs(t5_model_dir, exist_ok=True)
 # convseq2seq模型文件夹路径
 convseq2seq_model_dir = os.path.join(USER_DATA_DIR, 'seq2seq_models/convseq2seq_correction/')
diff --git a/pycorrector/t5/infer.py b/pycorrector/t5/infer.py
index 76c25ee5..05d52d9c 100644
--- a/pycorrector/t5/infer.py
+++ b/pycorrector/t5/infer.py
@@ -13,7 +13,7 @@
 
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--save_dir', type=str, default='output', help='save dir')
+    parser.add_argument('--save_dir', type=str, default='output/mengzi-t5-base-chinese-correction/', help='save dir')
     args = parser.parse_args()
     return args
 
@@ -26,7 +26,7 @@ def predict():
         "他带了黑色的包,也带了照像机",
     ]
     args = parse_args()
-    model_dir = os.path.join(args.save_dir, './byt5-small-chinese-correction')
+    model_dir = args.save_dir
     tokenizer = AutoTokenizer.from_pretrained(model_dir)
     model = T5ForConditionalGeneration.from_pretrained(model_dir)
     model.to(device)
diff --git a/pycorrector/t5/t5_corrector.py b/pycorrector/t5/t5_corrector.py
index 80941c9c..7910e3e2 100644
--- a/pycorrector/t5/t5_corrector.py
+++ b/pycorrector/t5/t5_corrector.py
@@ -42,17 +42,17 @@ def get_errors(corrected_text, origin_text):
 
 class T5Corrector(object):
     def __init__(self, model_dir=config.t5_model_dir):
-        self.name = 'byt5_corrector'
+        self.name = 't5_corrector'
         t1 = time.time()
         bin_path = os.path.join(model_dir, 'pytorch_model.bin')
         if not os.path.exists(bin_path):
-            model_dir = "shibing624/byt5-small-chinese-correction"
+            model_dir = "shibing624/mengzi-t5-base-chinese-correction"
             logger.warning(f'local model {bin_path} not exists, use default HF model {model_dir}')
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
         self.model = T5ForConditionalGeneration.from_pretrained(model_dir)
         self.model.to(device)
         logger.debug("Use device: {}".format(device))
-        logger.debug('Loaded byt5 correction model: %s, spend: %.3f s.' % (model_dir, time.time() - t1))
+        logger.debug('Loaded t5 correction model: %s, spend: %.3f s.' % (model_dir, time.time() - t1))
 
     def t5_correct(self, text: str, max_length: int = 128):
         """
@@ -107,7 +107,7 @@ def batch_t5_correct(self, texts: List[str], max_length: int = 128):
 
 
 if __name__ == "__main__":
-    m = T5Corrector('./output/byt5-small-chinese-correction/')
+    m = T5Corrector('./output/mengzi-t5-base-chinese-correction/')
     error_sentences = [
         '少先队员因该为老人让坐',
         '少 先 队 员 因 该 为 老人让坐',
diff --git a/pycorrector/t5/train.py b/pycorrector/t5/train.py
index 7d4cd75f..278bcaea 100644
--- a/pycorrector/t5/train.py
+++ b/pycorrector/t5/train.py
@@ -11,7 +11,9 @@
 from transformers import AutoTokenizer, T5ForConditionalGeneration
 from transformers import HfArgumentParser, TrainingArguments, Trainer, set_seed
 from datasets import load_dataset, Dataset
+from loguru import logger
 
+os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
 pwd_path = os.path.abspath(os.path.dirname(__file__))
 
 
@@ -123,13 +125,21 @@ class ModelArguments:
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--train_path', type=str,
-                        default=os.path.join(pwd_path, '../data/cn/sighan_2015/train.tsv'),
-                        help='SIGHAN dataset')
-    parser.add_argument('--test_path', type=str,
-                        default=os.path.join(pwd_path, '../data/cn/sighan_2015/test.tsv'),
-                        help='SIGHAN dataset')
-    parser.add_argument('--save_dir', type=str, default='output', help='save dir')
+    parser.add_argument('--train_path', type=str, default=os.path.join(pwd_path, '../data/cn/sighan_2015/train.tsv'),
+                        help='train dataset')
+    parser.add_argument('--test_path', type=str, default=os.path.join(pwd_path, '../data/cn/sighan_2015/test.tsv'),
+                        help='test dataset')
+    parser.add_argument('--save_dir', type=str, default='./output/mengzi-t5-base-chinese-correction/', help='save dir')
+    parser.add_argument('--model_name_or_path', type=str, default='Langboat/mengzi-t5-base', help='pretrained model')
+    parser.add_argument('--max_len', type=int, default=128, help='max length')
+    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
+    parser.add_argument('--logging_steps', type=int, default=100, help='logging steps num')
+    parser.add_argument('--warmup_steps', type=int, default=200, help='warmup steps num')
+    parser.add_argument('--eval_steps', type=int, default=250, help='eval steps num')
+    parser.add_argument('--epochs', type=int, default=10, help='train epochs num')
+    parser.add_argument('--max_steps', type=int, default=5000, help='train max steps')
+    parser.add_argument("--do_train", action="store_true", help="whether to run training")
+    parser.add_argument("--do_eval", action="store_true", help="whether to run evaluation")
     args = parser.parse_args()
     return args
@@ -149,24 +159,24 @@ def load(self):
 def train():
     args = parse_args()
     args_dict = {
-        "model_name_or_path": 'Langboat/mengzi-t5-base',
-        "max_len": 128,
-        "output_dir": os.path.join(args.save_dir, './mengzi-t5-base-chinese-correction'),
+        "model_name_or_path": args.model_name_or_path,
+        "max_len": args.max_len,
+        "output_dir": args.save_dir,
         "overwrite_output_dir": True,
-        "per_device_train_batch_size": 64,
-        "per_device_eval_batch_size": 64,
+        "per_device_train_batch_size": args.batch_size,
+        "per_device_eval_batch_size": args.batch_size,
         "gradient_accumulation_steps": 4,
         "learning_rate": 5e-4,
-        "warmup_steps": 250,
-        "logging_steps": 100,
+        "warmup_steps": args.warmup_steps,
+        "logging_steps": args.logging_steps,
         "evaluation_strategy": "steps",
-        "eval_steps": 250,
-        "num_train_epochs": 4,
-        "do_train": True,
-        "do_eval": True,
+        "eval_steps": args.eval_steps,
+        "num_train_epochs": args.epochs,
+        "do_train": args.do_train,
+        "do_eval": args.do_eval,
         "fp16": False,
         "use_cache": False,
-        # "max_steps": 5000,
+        "max_steps": args.max_steps,
     }
     parser = HfArgumentParser(
         (ModelArguments, DataTrainingArguments, TrainingArguments))
@@ -174,7 +184,7 @@ def train():
     set_seed(training_args.seed)
     if args.train_path.endswith('.tsv'):
         dataset = load_dataset('text', data_files={'train': [args.train_path], 'test': args.test_path})
-        print(dataset)
+        logger.info(dataset)
         train_dataset = dataset['train']
         valid_dataset = dataset['test']
     elif args.train_path.endswith('.json'):
@@ -185,8 +195,8 @@ def train():
         d = CscDataset(args.test_path)
         data_dict = d.load()
         valid_dataset = Dataset.from_dict(data_dict, split='test')
-        print(train_dataset)
-        print(valid_dataset)
+        logger.info(train_dataset)
+        logger.info(valid_dataset)
     else:
         raise ValueError('train_path must be tsv or json')
@@ -205,7 +215,7 @@ def train():
     tokenizer.model_max_length = 128
     model.config.max_length = 128
 
-    print('train_dataset:', train_dataset[:3])
+    logger.info(f'train_dataset: {train_dataset[:3]}')
 
     def tokenize_dataset(tokenizer, dataset, max_len):
         def convert_to_features(example_batch):
diff --git a/requirements.txt b/requirements.txt
index fdc9509e..c1608b30 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
 jieba>=0.39
 pypinyin
 numpy
-six
\ No newline at end of file
+six
+loguru
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 5a92cbf5..952253ba 100644
--- a/setup.py
+++ b/setup.py
@@ -45,7 +45,8 @@
         "jieba",
         "pypinyin",
         "numpy",
-        "six"
+        "six",
+        "loguru",
     ],
     packages=find_packages(exclude=['tests']),
     package_dir={'pycorrector': 'pycorrector'},