diff --git a/A_allencahn/AC.mat b/A_allencahn/AC.mat new file mode 100644 index 0000000..cd18945 Binary files /dev/null and b/A_allencahn/AC.mat differ diff --git a/A_allencahn/ac.py b/A_allencahn/ac.py new file mode 100644 index 0000000..a5ac458 --- /dev/null +++ b/A_allencahn/ac.py @@ -0,0 +1,641 @@ +""" + +""" +import torch +from tools import gradients, MLP, logger +import matplotlib.pyplot as plt +import numpy as np +from ff import error_ff, hammersely, lhs +from scipy.io import loadmat +from parser_pinn import get_parser +import pathlib + +parser_PINN = get_parser() +args = parser_PINN.parse_args() +path = pathlib.Path(args.save_path) +path.mkdir(exist_ok=True, parents=True) +for key, val in vars(args).items(): + print(f"{key} = {val}") +with open(path.joinpath('config'), 'wt') as f: + f.writelines([f"{key} = {val}\n" for key, val in vars(args).items()]) +maxiter: int = int(args.maxiter) +net_seq: list = list(args.net_seq) +sample_num: int = int(args.sample_num) +resample_interval: int = int(args.resample_interval) +freq_draw: int = int(args.freq_draw) +verbose: bool = bool(args.verbose) +resample_N: int = int(args.resample_N) + +resample_num = maxiter // resample_interval +log_interval = maxiter // 10 +rar_interval = maxiter // resample_num + +# todo more careful check +GPU_ENABLED = True +if torch.cuda.is_available(): + try: + _ = torch.Tensor([0., 0.]).cuda() + torch.set_default_tensor_type('torch.cuda.FloatTensor') + print('gpu available') + GPU_ENABLED = True + except: + print('gpu not available') + GPU_ENABLED = False +else: + print('gpu not available') + GPU_ENABLED = False + +_memo = [] + + +def exact_ac(): + if len(_memo) == 0: + data = loadmat('AC.mat') + tt = data['tt'] + x = data['x'] + u = data['uu'].T + t = tt.ravel() + x = x.ravel() + xx, tt = np.meshgrid(x, t) + _memo.append((xx.reshape(-1,1), tt.reshape(-1,1), u.reshape(-1,1))) + return _memo[0] + + +def compute_lhs(u, x, t): + u_pred = u(torch.cat([x, t], dim=1)) + u_xx = gradients(u_pred, x, 2) + u_t = gradients(u_pred, t, 1) + lhs = u_t - 0.0001 * u_xx + 5 * u_pred ** 3 - 5 * u_pred + return lhs + + +def compute_res(u): + resample_N = 500 + xc = torch.linspace(-1, 1, resample_N) + tc = torch.linspace(0, 1, resample_N) + + xx, yy = torch.meshgrid(xc, tc, indexing='xy') + xx = xx.reshape(-1, 1) + yy = yy.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + + x = torch.Tensor(xy[:, 0]).reshape(-1, 1).requires_grad_(True) + t = torch.Tensor(xy[:, 1]).reshape(-1, 1).requires_grad_(True) + + lhs = compute_lhs(u, x, t) + residual = torch.abs(lhs) + error = residual.reshape(resample_N, resample_N).detach().cpu().numpy() + + xc = np.linspace(-1, 1, resample_N) + tc = np.linspace(-1, 1, resample_N) + xx, yy = np.meshgrid(xc, tc, indexing='xy') + + return xx, yy, error + +def ddu_fn(x, y): + rhs = torch.zeros_like(x) + return rhs + + +def l2_loss(u): + xx, tt, u_truth = exact_ac() + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth) + xy = torch.cat([xx, tt], dim=1) + with torch.no_grad(): + u_pred = u(xy) + l2_error = torch.sqrt( + torch.sum((u_pred - u_truth) ** 2) / torch.sum((u_truth**2))) + return l2_error.detach().cpu().numpy() + + +def several_error(u): + with torch.no_grad(): + xx, tt, u_truth = exact_ac() + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth) + xy = torch.cat([xx, tt], dim=1) + u_pred = u(xy) + mse = torch.mean((u_pred - u_truth) ** 2).detach().cpu().numpy() + + return mse + + +torch.random.seed() + + +# 定义区域及其上的采样 +def interior(n=sample_num, 
method='random'): + if method == 'random': + xx = (torch.rand(n, 1) * 2 - 1.) + yy = torch.rand(n, 1) + elif method == 'uniform': + N = int(np.sqrt(n)) + xc = torch.linspace(-1, 1, N) + tc = torch.linspace(0, 1, N) + xx, yy = torch.meshgrid(xc, tc) + xx = xx.ravel().reshape(-1, 1) + yy = yy.ravel().reshape(-1, 1) + elif method == "hammersely": + xy = hammersely(n) + xx, yy = torch.Tensor(xy[:, 0:1] * 2 - 1), torch.Tensor(xy[:, 1:2]) + elif method == "lhs": + xy = lhs(n) + xx, yy = torch.Tensor(xy[:, 0:1] * 2 - 1), torch.Tensor(xy[:, 1:2]) + cond = ddu_fn(xx, yy) + return xx.requires_grad_(True), yy.requires_grad_(True), cond + + +def interior_ff(error, n=sample_num, verbose=False): + xy = error_ff(n, + error=error, + max_min_density_ratio=10, + box=[0, 1, 0, 1]) + if verbose: + logger.info("length of samples is {}".format(len(xy))) + # logger.info(f"FF sample number is {len(xy)}") + x = torch.Tensor(xy[:, 0] * 2 - 1).reshape(-1, 1) + y = torch.Tensor(xy[:, 1]).reshape(-1, 1) + cond = ddu_fn(x, y) + + x = x.requires_grad_(True) + y = y.requires_grad_(True) + return x, y, cond + + +def boundary(n=100): + xl = -torch.ones(n).reshape(-1, 1) + xr = torch.ones(n).reshape(-1, 1) + tl = torch.linspace(0, 1, n).reshape(-1, 1) + tr = torch.linspace(0, 1, n).reshape(-1, 1) + cond = -torch.ones(n).reshape(-1, 1) + # logger.info(f"sampling in interior, method is {method}") + return xr.requires_grad_(True), xl.requires_grad_(True), tr.requires_grad_(True), tl.requires_grad_(True), cond + + +def initial(n=100): + x = torch.linspace(-1., 1., n).reshape(-1, 1) + t = torch.zeros(n).reshape(-1, 1) + cond = x ** 2 * torch.cos(torch.pi * x) + # logger.info(f"sampling in interior, method is {method}") + return x.requires_grad_(True), t.requires_grad_(True), cond + +loss = torch.nn.MSELoss() + +import math + + +def l_interior(u, method='random', resample=False): + if resample or ('interior' not in collocations): + x, t, cond = interior(method=method) + collocations['interior'] = (x, t, cond) + x, t, cond = collocations['interior'] + lhs = compute_lhs(u, x, t) + l = loss(lhs, cond) + return l + + +def l_boundary(u): + if 'boundary' not in collocations: + xr, xl, tr, tl, cond = boundary() + collocations['boundary'] = (xr, xl, tr, tl, cond) + xr, xl, tr, tl, cond = collocations['boundary'] + return loss(u(torch.cat([xr, tr], dim=1)) , cond)+loss(u(torch.cat([xl, tl], dim=1)) , cond) + + +def l_initial(u): + if 'initial' not in collocations: + x, t, cond = initial() + collocations['initial'] = (x, t, cond) + x, t, cond = collocations['initial'] + return loss(u(torch.cat([x, t], dim=1)), cond) + +def visualize_error(ax, u): + xx, tt, u_truth = exact_ac() + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth).reshape(201, 512) + + xx = xx.reshape(-1, 1) + yy = tt.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + with torch.no_grad(): + u_pred = u(xy).reshape(201, 512) + error = torch.abs(u_pred - u_truth) + xx, tt, _ = exact_ac() + ax.pcolormesh(xx.reshape(201, 512), tt.reshape(201, 512), error.detach().cpu().numpy(), vmin=0, vmax=0.3) + ax.set_title('abs error') + + +def visualize(ax, u, verbose=True): + xc = torch.linspace(-1, 1, 100) + tc = torch.linspace(0, 1, 100) + xx, yy = torch.meshgrid(xc, tc, indexing='xy') + xx = xx.reshape(-1, 1) + yy = yy.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + u_pred = u(xy) + if verbose: + logger.info("L2 error is: {}".format(float(l2_loss(u)))) + + u_pred = u_pred.detach().cpu().numpy().reshape(100, 100) + + xx = 
xx.detach().cpu().numpy().reshape(100, 100) + yy = yy.detach().cpu().numpy().reshape(100, 100) + + ax.pcolormesh(xx, yy, u_pred, vmin=-1., vmax=1.) + ax.set_title('Prediction') + + +def visualize_scatter(ax, collocation): + x, y, _ = collocation['interior'] + x = x.detach().cpu().numpy().ravel() + y = y.detach().cpu().numpy().ravel() + ax.scatter(x, y, s=1) + + +def compose_loss(l_interior_val, l_boundary_val, l_init_val): + return 2 * l_init_val + 2 * l_boundary_val + 0.2 * l_interior_val + + +def write_res(mse_list): + with open(path.joinpath('result.csv'), "a+") as f: + f.write(', '.join(mse_list)) + f.write('\n') + + +def eval_u(u, mse_list): + l2_rel = several_error(u) + mse_list.append(str(l2_rel)) + logger.info(f'mse: {l2_rel}') + + +def ff(): + """ + FF + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("ff") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + collocations['interior'] = interior_ff(np.ones((100, 100))) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + opt.zero_grad() + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff.pth')) + write_res(mse_list) + + +def ff_resample(): + """ + FF-R + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_resample_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("ff resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + collocations['interior'] = interior_ff(np.ones((100, 100))) + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + opt.zero_grad() + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff_re.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff_re.pth')) + write_res(mse_list) + +def ff_rar(mem=0.9): + """ + RANG-m + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_rar_{mem:.2f}_mse'] + + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // 
freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info(f"ff_rar_{mem:.2f}") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters()) + collocations['interior'] = interior_ff(np.ones((100, 100))) + + for i in range(maxiter): + opt.zero_grad() + + if i > 0 and i % rar_interval == 0: + xx, yy, new_error = compute_res(u) + min_v = np.min(new_error) + max_v = np.max(new_error) + new_error = (new_error - min_v) / (max_v - min_v + 1e-8) + try: + error = np.maximum(mem * error, new_error) + except: + error = new_error + collocations['interior'] = interior_ff(error, sample_num) + + if verbose: + logger.info("length of samples is {}".format(len(collocations['interior'][0]))) + + if sample_idx % freq_draw == 0: + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}, point num is {len(collocations["interior"][0])}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff_rar_{mem:.2f}.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff_rar_{mem:.2f}.pth')) + write_res(mse_list) + + +def hammersely_sample(): + """ + Hammersley + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'hammersley_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("hammersely sampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + opt.zero_grad() + + l_interior_val = l_interior(u, method='hammersely') + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'hammersely_evo_ac.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'hammersely_evo_ac.pth')) + write_res(mse_list) + + +def lhs_sample(): + """ + LHS + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'lhs_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("lhs resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 
2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u, method='lhs') + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + opt.zero_grad() + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'lhs.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'lhs.pth')) + write_res(mse_list) + + +def lhs_resample(): + """ + LHS-R + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'lhs_resample_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("lhs resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + l_interior_val = l_interior(u, method='lhs', resample=True) + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + else: + l_interior_val = l_interior(u, method='lhs') + opt.zero_grad() + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'lhs_re.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'lhs_re.pth')) + write_res(mse_list) + + +def random(): + """ + Random + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'random_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("random resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u, method='random') + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + + opt.zero_grad() + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'random.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'random.pth')) + write_res(mse_list) + + +def random_resample(): + """ + Random-R + """ + global exp_id + global collocations + + mse_list = [str(exp_id), f'random_resample_mse'] + + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("random resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = 
torch.optim.Adam(params=u.parameters(), lr=0.001)
+    for i in range(maxiter):
+        if i > 0 and i % rar_interval == 0:
+            l_interior_val = l_interior(u, method='random', resample=True)
+            if sample_idx % freq_draw == 0:
+                xx, yy, error = compute_res(u)
+                visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose)
+                visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations)
+                ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error)
+                eval_u(u, mse_list=mse_list)
+            sample_idx += 1
+        else:
+            l_interior_val = l_interior(u, method='random')
+        l_boundary_val = l_boundary(u)
+        l_initial_val = l_initial(u)
+        l = compose_loss(l_interior_val, l_boundary_val, l_initial_val)
+        opt.zero_grad()
+        l.backward()
+        opt.step()
+        if i % log_interval == 0:
+            logger.info(f'iteration {i}: loss is {float(l)}')
+    eval_u(u, mse_list=mse_list)
+    fig2.savefig(path.joinpath(f'random_re.png'))
+    plt.close(fig2)
+    torch.save(u.state_dict(), path.joinpath(f'random_re.pth'))
+    write_res(mse_list)
+
+
+if __name__ == '__main__':
+    exp_id = int(args.start_epoch)
+    for i in range(int(args.repeat)):
+        ff()  # FF
+        exp_id += 1
+
+        ff_resample()  # FF-R
+        exp_id += 1
+
+        ff_rar(0.9)  # RANG-m
+        exp_id += 1
+
+        ff_rar(0.0)  # RANG
+        exp_id += 1
+
+        hammersely_sample()  # Hammersley
+        exp_id += 1
+
+        lhs_sample()  # LHS
+        exp_id += 1
+
+        lhs_resample()  # LHS-R
+        exp_id += 1
+
+        random()  # Random
+        exp_id += 1
+
+        random_resample()  # Random-R
+        exp_id += 1
diff --git a/A_allencahn/parser_pinn.py b/A_allencahn/parser_pinn.py
new file mode 100644
index 0000000..316a989
--- /dev/null
+++ b/A_allencahn/parser_pinn.py
@@ -0,0 +1,40 @@
+import argparse
+from datetime import datetime
+
+
+def get_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--maxiter', default=50000, type=int
+    )
+    parser.add_argument(
+        '--resample_N', default=100, type=int
+    )
+    parser.add_argument(
+        '--sample_num', default=1000, type=int
+    )
+    parser.add_argument(
+        '--net_seq', default=[2, 64, 64, 64, 64, 1], nargs='+', type=int
+    )
+    parser.add_argument(
+        '--save_path', default=f'./data/{datetime.now().strftime("%Y_%m_%d_%H_%M_%S")}'
+    )
+    parser.add_argument(
+        '--resample_interval', default=1000, type=int
+    )
+    parser.add_argument(
+        '--freq_draw', default=5, type=int
+    )
+    parser.add_argument(
+        '--verbose', action='store_true'
+    )
+    parser.add_argument(
+        '--sigma', default=0.1, type=float
+    )
+    parser.add_argument(
+        '--repeat', default=30, type=int
+    )
+    parser.add_argument(
+        '--start_epoch', default=0, type=int
+    )
+    return parser
diff --git a/B_wave/parser_pinn.py b/B_wave/parser_pinn.py
new file mode 100644
index 0000000..eb9c703
--- /dev/null
+++ b/B_wave/parser_pinn.py
@@ -0,0 +1,49 @@
+import argparse
+from datetime import datetime
+import math
+
+def get_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--maxiter', default=10000, type=int
+    )
+    parser.add_argument(
+        '--resample_interval', default=1000, type=int
+    )
+    parser.add_argument(
+        '--sample_num', default=1000, type=int
+    )
+    parser.add_argument(
+        '--freq_draw', default=5, type=int
+    )
+    parser.add_argument(
+        '--resample_N', default=100, type=int
+    )
+    parser.add_argument(
+        '--net_seq', default=[2, 64, 64, 64, 64, 1], nargs='+', type=int
+    )
+    parser.add_argument(
+        '--save_path', default=f'./data/{datetime.now().strftime("%Y_%m_%d_%H_%M_%S")}'
+    )
+    parser.add_argument(
+        '--verbose', action='store_true'
+    )
+    parser.add_argument(
+        '--repeat', default=50, type=int
+    )
+    parser.add_argument(
+        '--start_epoch', default=0, type=int
+    )
+    parser.add_argument(
+        '--a', default=2., type=float
+ ) + parser.add_argument( + '--c', default=math.sqrt(3) + ) + parser.add_argument( + '--Tmax', default=6.0 + ) + parser.add_argument( + '--half_L', default=4 + ) + return parser diff --git a/B_wave/wave.py b/B_wave/wave.py new file mode 100644 index 0000000..1ee8f0f --- /dev/null +++ b/B_wave/wave.py @@ -0,0 +1,659 @@ +""" + +""" +import torch +from tools import gradients, MLP, logger +import matplotlib.pyplot as plt +import numpy as np +from ff import error_ff, hammersely, lhs +from parser_pinn import get_parser +import pathlib + +parser_PINN = get_parser() +args = parser_PINN.parse_args() +path = pathlib.Path(args.save_path) +path.mkdir(exist_ok=True, parents=True) +for key, val in vars(args).items(): + print(f"{key} = {val}") +with open(path.joinpath('config'), 'wt') as f: + f.writelines([f"{key} = {val}\n" for key, val in vars(args).items()]) +maxiter: int = int(args.maxiter) +net_seq: list = list(args.net_seq) +sample_num: int = int(args.sample_num) +resample_interval: int = int(args.resample_interval) +freq_draw: int = int(args.freq_draw) +verbose: bool = bool(args.verbose) +resample_N: int = int(args.resample_N) +c: float = float(args.c) +a: float = float(args.a) +Tmax: float = float(args.Tmax) +half_L: float = float(args.half_L) + +# r_ff = 0.1 / np.sqrt(sample_num / 100) +resample_num = maxiter // resample_interval +log_interval = maxiter // 10 +rar_interval = maxiter // resample_num + +# todo more careful check +GPU_ENABLED = True +if torch.cuda.is_available(): + try: + _ = torch.Tensor([0., 0.]).cuda() + torch.set_default_tensor_type('torch.cuda.FloatTensor') + print('gpu available') + GPU_ENABLED = True + except: + print('gpu not available') + GPU_ENABLED = False +else: + print('gpu not available') + GPU_ENABLED = False + +_memo = [] + + +def exact_wave(): + if len(_memo) == 0: + x = np.linspace(-half_L, half_L, 512, endpoint=False) + t = np.linspace(0, 1, 201) * Tmax + xx, tt = np.meshgrid(x, t) + u = 0.5 / np.cosh(a*(xx - c * tt)) \ + - 0.5 / np.cosh(a*(-half_L*2+xx + c * tt)) \ + + 0.5 / np.cosh(a*(xx + c * tt))\ + -0.5 / np.cosh(a*(xx+half_L*2 - c * tt)) + _memo.append((xx.reshape(-1, 1), tt.reshape(-1, 1), u.reshape(-1, 1))) + return _memo[0] + + +def compute_lhs(u, x, t): + u_pred = u(torch.cat([x, t], dim=1)) + u_xx = gradients(u_pred, x, 2) + u_tt = gradients(u_pred, t, 2) + lhs = u_tt-c**2*u_xx + return lhs + + +def compute_res(u): + resample_N = 500 + xc = torch.linspace(-1, 1, resample_N)*half_L + tc = torch.linspace(0, Tmax, resample_N) + + xx, yy = torch.meshgrid(xc, tc, indexing='xy') + xx = xx.reshape(-1, 1) + yy = yy.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + + x = torch.Tensor(xy[:, 0]).reshape(-1, 1).requires_grad_(True) + t = torch.Tensor(xy[:, 1]).reshape(-1, 1).requires_grad_(True) + + lhs = compute_lhs(u, x, t) + residual = torch.abs(lhs) + error = residual.reshape(resample_N, resample_N).detach().cpu().numpy() + + xc = np.linspace(-1, 1, resample_N) + tc = np.linspace(-1, 1, resample_N) + xx, yy = np.meshgrid(xc, tc, indexing='xy') + + return xx, yy, error + + +def ddu_fn(x, y): + return torch.zeros_like(x) + + +def l2_loss(u): + xx, tt, u_truth = exact_wave() + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth) + xy = torch.cat([xx, tt], dim=1) + with torch.no_grad(): + u_pred = u(xy) + l2_error = torch.sqrt( + torch.sum((u_pred - u_truth) ** 2) / torch.sum((u_truth ** 2))) + return l2_error.detach().cpu().numpy() + + +def several_error(u): + with torch.no_grad(): + xx, tt, u_truth = exact_wave() + xx = 
torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth) + xy = torch.cat([xx, tt], dim=1) + u_pred = u(xy) + mse = torch.mean((u_pred - u_truth) ** 2).detach().cpu().numpy() + + return mse + + +torch.random.seed() + + +# 定义区域及其上的采样 +def interior(n=sample_num, method='random'): + if method == 'random': + xx = (torch.rand(n, 1) * 2 - 1.)*half_L + yy = torch.rand(n, 1)*Tmax + elif method == 'uniform': + N = int(np.sqrt(n)) + xc = torch.linspace(-1, 1, N)*half_L + tc = torch.linspace(0, 1, N)*Tmax + xx, yy = torch.meshgrid(xc, tc) + xx = xx.ravel().reshape(-1, 1) + yy = yy.ravel().reshape(-1, 1) + elif method == "hammersely": + xy = hammersely(n) + xx, yy = torch.Tensor(xy[:, 0:1] * 2 - 1)*half_L, torch.Tensor(xy[:, 1:2])*Tmax + elif method == "lhs": + xy = lhs(n) + xx, yy = torch.Tensor(xy[:, 0:1] * 2 - 1)*half_L, torch.Tensor(xy[:, 1:2])*Tmax + cond = ddu_fn(xx, yy) + return xx.requires_grad_(True), yy.requires_grad_(True), cond + + +def interior_ff(error, n=sample_num, verbose=False): + xy = error_ff(n, + error=error, + max_min_density_ratio=10, + box=[0, 1, 0, 1]) + if verbose: + logger.info("length of samples is {}".format(len(xy))) + # logger.info(f"FF sample number is {len(xy)}") + x = torch.Tensor(xy[:, 0] * 2 - 1).reshape(-1, 1)*half_L + y = torch.Tensor(xy[:, 1]).reshape(-1, 1)*Tmax + cond = ddu_fn(x, y) + + x = x.requires_grad_(True) + y = y.requires_grad_(True) + return x, y, cond + + +def boundary(n=100): + xx = torch.cat([-torch.ones(n), torch.ones(n)]).reshape(-1, 1)*half_L + tt = torch.cat([torch.linspace(0, Tmax, n), torch.linspace(0, Tmax, n)]).reshape(-1, 1) + cond = torch.zeros_like(xx) + # logger.info(f"sampling in interior, method is {method}") + return xx.requires_grad_(True), tt.requires_grad_(True), cond + + +def initial(n=1000): + xx = torch.linspace(-1., 1., n).reshape(-1, 1)*half_L + tt = torch.zeros(n).reshape(-1, 1)*Tmax + cond = 0.5 / torch.cosh(a*(xx - c * tt)) \ + - 0.5 / torch.cosh(a*(-half_L*2+xx + c * tt)) \ + + 0.5 / torch.cosh(a*(xx + c * tt)) \ + -0.5 / torch.cosh(a*(xx+half_L*2 - c * tt)) + # logger.info(f"sampling in interior, method is {method}") + return xx.requires_grad_(True), tt.requires_grad_(True), cond + + +# 定义损失 +loss = torch.nn.MSELoss() + + +def l_interior(u, method='random', resample=False): + if resample or ('interior' not in collocations): + x, t, cond = interior(method=method) + collocations['interior'] = (x, t, cond) + x, t, cond = collocations['interior'] + lhs = compute_lhs(u, x, t) + l = loss(lhs, cond) + return l + + +def l_boundary(u): + if 'boundary' not in collocations: + x, t, cond = boundary() + collocations['boundary'] = (x, t, cond) + x, t, cond = collocations['boundary'] + return loss(u(torch.cat([x, t], dim=1)), cond) + + +def l_initial(u): + if 'initial' not in collocations: + x, t, cond = initial() + collocations['initial'] = (x, t, cond) + x, t, cond = collocations['initial'] + pred_u = u(torch.cat([x, t], dim=1)) + u_t = gradients(pred_u,t) + return loss(pred_u, cond)+loss(u_t, torch.zeros_like(u_t)) + + +def visualize_error(ax, u): + xx, tt, u_truth = exact_wave() + shape = u_truth.shape + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth).reshape(*shape) + + xx = xx.reshape(-1, 1) + yy = tt.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + with torch.no_grad(): + u_pred = u(xy).reshape(*shape) + error = torch.abs(u_pred - u_truth) + xx, tt, _ = exact_wave() + ax.pcolormesh(xx.reshape(*shape), tt.reshape(*shape), error.detach().cpu().numpy(), vmin=0, vmax=0.03) + 
ax.set_title('abs error') + + +def visualize(ax, u, verbose=True): + xc = torch.linspace(-1, 1, 100)*half_L + tc = torch.linspace(0, Tmax, 100) + xx, yy = torch.meshgrid(xc, tc, indexing='xy') + xx = xx.reshape(-1, 1) + yy = yy.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + u_pred = u(xy) + if verbose: + logger.info("L2 error is: {}".format(float(l2_loss(u)))) + + u_pred = u_pred.detach().cpu().numpy().reshape(100, 100) + + xx = xx.detach().cpu().numpy().reshape(100, 100) + yy = yy.detach().cpu().numpy().reshape(100, 100) + + # ax.pcolormesh(xx, yy, np.abs(u_pred-u_truth), cmap='hot', vmin=0, vmax=1) + ax.pcolormesh(xx, yy, u_pred, vmin=-c/2, vmax=c/2, cmap='bwr') + ax.set_title('Prediction') + + +def visualize_scatter(ax, collocation): + x, y, _ = collocation['interior'] + x = x.detach().cpu().numpy().ravel() + y = y.detach().cpu().numpy().ravel() + ax.scatter(x, y, s=1) + + +def compose_loss(l_interior_val, l_boundary_val, l_init_val): + return 2 * l_init_val + 2 * l_boundary_val + 0.2 * l_interior_val + + +def write_res(mse_list): + with open(path.joinpath('result.csv'), "a+") as f: + f.write(', '.join(mse_list)) + f.write('\n') + + +def eval_u(u, mse_list): + l2_rel = several_error(u) + mse_list.append(str(l2_rel)) + logger.info(f'mse: {l2_rel}') + +def ff(): + """ + FF + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("ff") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + collocations['interior'] = interior_ff(np.ones((100, 100))) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + opt.zero_grad() + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff.pth')) + write_res(mse_list) + return mse_list + + +def ff_resample(): + """ + FF-R + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_resample_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("ff resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + collocations['interior'] = interior_ff(np.ones((100, 100))) + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, 
l_boundary_val, l_initial_val) + opt.zero_grad() + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff_re.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff_re.pth')) + write_res(mse_list) + return mse_list + + +def ff_rar(mem=0.9): + """ + RANG-m + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_rar_{mem:.2f}_mse'] + + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info(f"ff_rar_{mem:.2f}") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters()) + collocations['interior'] = interior_ff(np.ones((100, 100))) + + for i in range(maxiter): + opt.zero_grad() + + if i > 0 and i % rar_interval == 0: + xx, yy, new_error = compute_res(u) + min_v = np.min(new_error) + max_v = np.max(new_error) + new_error = (new_error - min_v) / (max_v - min_v + 1e-8) + try: + error = np.maximum(mem * error, new_error) + except: + error = new_error + collocations['interior'] = interior_ff(error, sample_num) + + if verbose: + logger.info("length of samples is {}".format(len(collocations['interior'][0]))) + + if sample_idx % freq_draw == 0: + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}, point num is {len(collocations["interior"][0])}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff_rar_{mem:.2f}.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff_rar_{mem:.2f}.pth')) + write_res(mse_list) + + return mse_list + + +def hammersely_sample(): + """ + Hammersley + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'hammersely_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("hammersely sampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + opt.zero_grad() + + l_interior_val = l_interior(u, method='hammersely') + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'hammersely_evo_ac.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'hammersely_evo_ac.pth')) + write_res(mse_list) + return mse_list + + +def lhs_sample(): + """ + LHS + """ + global exp_id + global collocations + mse_list = 
[str(exp_id), f'lhs_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("lhs sampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u, method='lhs') + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + opt.zero_grad() + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'lhs.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'lhs.pth')) + write_res(mse_list) + return mse_list + + +def lhs_resample(): + """ + LHS-R + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'lhs_resample_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("lhs resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + l_interior_val = l_interior(u, method='lhs', resample=True) + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + else: + l_interior_val = l_interior(u, method='lhs') + opt.zero_grad() + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'lhs_re.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'lhs_re.pth')) + write_res(mse_list) + return mse_list + + +def random(): + """ + Random + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'random_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("random resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u, method='random') + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + + opt.zero_grad() + l.backward() + opt.step() + if i % 
log_interval == 0:
+            logger.info(f'iteration {i}: loss is {float(l)}')
+    eval_u(u, mse_list=mse_list)
+    fig2.savefig(path.joinpath(f'random.png'))
+    plt.close(fig2)
+    torch.save(u.state_dict(), path.joinpath(f'random.pth'))
+    write_res(mse_list)
+    return mse_list
+
+
+def random_resample():
+    """
+    Random-R
+    """
+    global exp_id
+    global collocations
+    mse_list = [str(exp_id), f'random_resample_mse']
+    fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4))
+    fig2.set_tight_layout(True)
+    sample_idx = 0
+    logger.info("random resampling")
+    collocations = dict()
+    u = MLP(seq=net_seq)
+    opt = torch.optim.Adam(params=u.parameters(), lr=0.001)
+    for i in range(maxiter):
+        if i > 0 and i % rar_interval == 0:
+            l_interior_val = l_interior(u, method='random', resample=True)
+            if sample_idx % freq_draw == 0:
+                xx, yy, error = compute_res(u)
+                visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose)
+                visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations)
+                ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error)
+                eval_u(u, mse_list=mse_list)
+            sample_idx += 1
+        else:
+            l_interior_val = l_interior(u, method='random')
+        l_boundary_val = l_boundary(u)
+        l_initial_val = l_initial(u)
+        l = compose_loss(l_interior_val, l_boundary_val, l_initial_val)
+        opt.zero_grad()
+        l.backward()
+        opt.step()
+        if i % log_interval == 0:
+            logger.info(f'iteration {i}: loss is {float(l)}')
+    eval_u(u, mse_list=mse_list)
+    fig2.savefig(path.joinpath(f'random_re.png'))
+    plt.close(fig2)
+    torch.save(u.state_dict(), path.joinpath(f'random_re.pth'))
+    write_res(mse_list)
+    return mse_list
+
+
+if __name__ == '__main__':
+    exp_id = int(args.start_epoch)
+    for i in range(int(args.repeat)):
+
+        ff_rar(0.9)  # RANG-m
+        exp_id += 1
+
+        ff_rar(0.0)  # RANG
+        exp_id += 1
+
+        ff()  # FF
+        exp_id += 1
+
+        ff_resample()  # FF-R
+        exp_id += 1
+
+        hammersely_sample()  # Hammersley
+        exp_id += 1
+
+        lhs_sample()  # LHS
+        exp_id += 1
+
+        lhs_resample()  # LHS-R
+        exp_id += 1
+
+        random()  # Random
+        exp_id += 1
+
+        random_resample()  # Random-R
+        exp_id += 1
\ No newline at end of file
diff --git a/C_schodinger/NLS.mat b/C_schodinger/NLS.mat
new file mode 100644
index 0000000..3767e62
Binary files /dev/null and b/C_schodinger/NLS.mat differ
diff --git a/C_schodinger/parser_pinn.py b/C_schodinger/parser_pinn.py
new file mode 100644
index 0000000..5ec30aa
--- /dev/null
+++ b/C_schodinger/parser_pinn.py
@@ -0,0 +1,37 @@
+import argparse
+from datetime import datetime
+import math
+
+def get_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--maxiter', default=50000, type=int
+    )
+    parser.add_argument(
+        '--resample_interval', default=1000, type=int
+    )
+    parser.add_argument(
+        '--sample_num', default=1000, type=int
+    )
+    parser.add_argument(
+        '--freq_draw', default=5, type=int
+    )
+    parser.add_argument(
+        '--resample_N', default=100, type=int
+    )
+    parser.add_argument(
+        '--net_seq', default=[2, 64, 64, 64, 64, 2], nargs='+', type=int
+    )
+    parser.add_argument(
+        '--save_path', default=f'./data/{datetime.now().strftime("%Y_%m_%d_%H_%M_%S")}'
+    )
+    parser.add_argument(
+        '--verbose', action='store_true'
+    )
+    parser.add_argument(
+        '--repeat', default=30, type=int
+    )
+    parser.add_argument(
+        '--start_epoch', default=0, type=int
+    )
+    return parser
diff --git a/C_schodinger/schodinger.py b/C_schodinger/schodinger.py
new file mode 100644
index 0000000..130a488
--- /dev/null
+++ b/C_schodinger/schodinger.py
@@ -0,0 +1,669 @@
+"""
+PINN sampling-strategy comparison on the nonlinear Schrodinger benchmark.
+"""
+import torch
+from tools import gradients, MLP, logger
+import matplotlib.pyplot as plt
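+# --- Added editorial sketch (hypothetical helper, not in the original run) ---
+# This script targets the standard nonlinear Schrodinger benchmark
+#     i*h_t + 0.5*h_xx + |h|^2 * h = 0,   h(0, x) = 2*sech(x),
+# with the complex field written as h = u + i*v, so the network has two
+# outputs and compute_lhs (further down) stacks the real- and imaginary-part
+# residuals. A minimal sketch, assuming only torch autograd, that checks those
+# residual conventions against the exact bright soliton
+# h(x, t) = a * sech(a*x) * exp(i * a^2 * t / 2); the bounds mirror the
+# script's half_L = 5 and Tmax = pi/2, which are set further below.
+def _check_nls_soliton(a=1.0, n=8):
+    x = (torch.rand(n, 1) * 10 - 5).requires_grad_(True)
+    t = (torch.rand(n, 1) * torch.pi / 2).requires_grad_(True)
+    envelope = a / torch.cosh(a * x)
+    u = envelope * torch.cos(a ** 2 * t / 2)
+    v = envelope * torch.sin(a ** 2 * t / 2)
+
+    def g(y, w):
+        return torch.autograd.grad(y, w, torch.ones_like(y), create_graph=True)[0]
+
+    # Same sign conventions as compute_lhs below.
+    res_v = g(u, t) + 0.5 * g(g(v, x), x) + (u ** 2 + v ** 2) * v
+    res_u = g(v, t) - 0.5 * g(g(u, x), x) - (u ** 2 + v ** 2) * u
+    # Both maxima should be near zero, up to float32 round-off.
+    return float(res_v.abs().max()), float(res_u.abs().max())
+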
+import numpy as np +from ff import error_ff, hammersely, lhs +from scipy.io import loadmat +from parser_pinn import get_parser +import pathlib + +parser_PINN = get_parser() +args = parser_PINN.parse_args() +path = pathlib.Path(args.save_path) +path.mkdir(exist_ok=True, parents=True) +for key, val in vars(args).items(): + print(f"{key} = {val}") +with open(path.joinpath('config'), 'wt') as f: + f.writelines([f"{key} = {val}\n" for key, val in vars(args).items()]) +maxiter: int = int(args.maxiter) +net_seq: list = list(args.net_seq) +sample_num: int = int(args.sample_num) +resample_interval: int = int(args.resample_interval) +freq_draw: int = int(args.freq_draw) +verbose: bool = bool(args.verbose) +resample_N: int = int(args.resample_N) +half_L=5 +Tmax=np.pi / 2 + +# r_ff = 0.1 / np.sqrt(sample_num / 100) +resample_num = maxiter // resample_interval +log_interval = maxiter // 10 +rar_interval = maxiter // resample_num + +# todo more careful check +GPU_ENABLED = True +if torch.cuda.is_available(): + try: + _ = torch.Tensor([0., 0.]).cuda() + torch.set_default_tensor_type('torch.cuda.FloatTensor') + print('gpu available') + GPU_ENABLED = True + except: + print('gpu not available') + GPU_ENABLED = False +else: + print('gpu not available') + GPU_ENABLED = False + +_memo = [] + + +def exact_schodinger(): + if len(_memo) == 0: + data = loadmat('NLS.mat') + tt = data['tt'] + x = data['x'] + u = data['uu'].T + t = tt.ravel() + x = x.ravel() + xx, tt = np.meshgrid(x, t) + u = np.concatenate([np.real(u)[:, :, np.newaxis], np.imag(u)[:, :, np.newaxis]], axis=-1) + _memo.append((xx.reshape(-1, 1), tt.reshape(-1, 1), u.reshape(-1, 2))) + return _memo[0] + + +def compute_lhs(u_net, x, t): + uv = u_net(torch.cat([x, t], dim=1)) + u = uv[:, 0:1] + v = uv[:, 1:2] + u_x = gradients(u, x, 1) + v_x = gradients(v, x, 1) + u_t = gradients(u, t, 1) + u_xx = gradients(u_x, x) + v_t = gradients(v, t) + v_xx = gradients(v_x, x) + res_v = u_t + 0.5 * v_xx + (u ** 2 + v ** 2) * v + res_u = v_t - 0.5 * u_xx - (u ** 2 + v ** 2) * u + lhs = torch.concat([res_v, res_u], dim=-1) + return lhs + + +def compute_res(u): + resample_N = 500 + xc = torch.linspace(-1, 1, resample_N)*half_L + tc = torch.linspace(0, Tmax, resample_N) + + xx, yy = torch.meshgrid(xc, tc, indexing='xy') + xx = xx.reshape(-1, 1) + yy = yy.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + + x = torch.Tensor(xy[:, 0]).reshape(-1, 1).requires_grad_(True) + t = torch.Tensor(xy[:, 1]).reshape(-1, 1).requires_grad_(True) + + lhs = compute_lhs(u, x, t) + residual = compute_mod(lhs) + error = residual.reshape(resample_N, resample_N).detach().cpu().numpy() + + xc = np.linspace(-1, 1, resample_N) + tc = np.linspace(-1, 1, resample_N) + xx, yy = np.meshgrid(xc, tc, indexing='xy') + + return xx, yy, error + + + + +def ddu_fn(x, y): + return torch.zeros(len(x), 2) + + +def l2_loss(u): + xx, tt, u_truth = exact_schodinger() + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth) + xy = torch.cat([xx, tt], dim=1) + with torch.no_grad(): + u_pred = u(xy) + l2_error = torch.sqrt( + torch.sum(compute_mod(u_pred - u_truth) ** 2) / torch.sum((compute_mod(u_truth) ** 2))) + return l2_error.detach().cpu().numpy() + + +def several_error(u): + with torch.no_grad(): + xx, tt, u_truth = exact_schodinger() + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth) + xy = torch.cat([xx, tt], dim=1) + u_pred = u(xy) + mod_difference =compute_mod(u_pred-u_truth) + mse = torch.mean(mod_difference ** 2).detach().cpu().numpy() + + 
return mse + + +torch.random.seed() + + +def interior(n=sample_num, method='random'): + if method == 'random': + xx = (torch.rand(n, 1) * 2 - 1.)*half_L + yy = torch.rand(n, 1)*Tmax + elif method == 'uniform': + N = int(np.sqrt(n)) + xc = torch.linspace(-1, 1, N)*half_L + tc = torch.linspace(0, 1, N)*Tmax + xx, yy = torch.meshgrid(xc, tc) + xx = xx.ravel().reshape(-1, 1) + yy = yy.ravel().reshape(-1, 1) + elif method == "hammersely": + xy = hammersely(n) + xx, yy = torch.Tensor(xy[:, 0:1] * 2 - 1)*half_L, torch.Tensor(xy[:, 1:2])*Tmax + elif method == "lhs": + xy = lhs(n) + xx, yy = torch.Tensor(xy[:, 0:1] * 2 - 1)*half_L, torch.Tensor(xy[:, 1:2])*Tmax + cond = ddu_fn(xx, yy) + return xx.requires_grad_(True), yy.requires_grad_(True), cond +def compute_mod(x): + return torch.sqrt(x[:,0:1]**2+x[:,1:2]**2) + +def interior_ff(error, n=sample_num, verbose=False): + xy = error_ff(n, + error=error, + max_min_density_ratio=10, + box=[0, 1, 0, 1]) + if verbose: + logger.info("length of samples is {}".format(len(xy))) + # logger.info(f"FF sample number is {len(xy)}") + x = torch.Tensor(xy[:, 0] * 2 - 1).reshape(-1, 1)*half_L + y = torch.Tensor(xy[:, 1]).reshape(-1, 1)*Tmax + cond = ddu_fn(x, y) + + x = x.requires_grad_(True) + y = y.requires_grad_(True) + return x, y, cond + + +def boundary(n=100): + xx = torch.cat([-torch.ones(n), torch.ones(n)]).reshape(-1, 1)*half_L + tt = torch.cat([torch.linspace(0, Tmax, n), torch.linspace(0, Tmax, n)]).reshape(-1, 1) + cond = torch.zeros(2*n, 2) + return xx.requires_grad_(True), tt.requires_grad_(True), cond + +def initial(n=1000): + xx = torch.linspace(-1., 1., n).reshape(-1, 1)*half_L + tt = torch.zeros(n).reshape(-1, 1)*Tmax + cond_real = 2 / torch.cosh(xx) + cond_img = torch.zeros(n,1) + cond = torch.concat([cond_real, cond_img], dim=-1) + return xx.requires_grad_(True), tt.requires_grad_(True), cond + +loss = torch.nn.MSELoss() + + +def l_interior(u, method='random', resample=False): + if resample or ('interior' not in collocations): + x, t, cond = interior(method=method) + collocations['interior'] = (x, t, cond) + x, t, cond = collocations['interior'] + lhs = compute_lhs(u, x, t) + l = loss(lhs, cond) + return l + + +def l_boundary(u): + if 'boundary' not in collocations: + x, t, cond = boundary() + collocations['boundary'] = (x, t, cond) + x, t, cond = collocations['boundary'] + u_pred = u(torch.cat([x, t], dim=1)) + u_pred_u, u_pred_v = u_pred[:, :1], u_pred[:, 1:] + u_pred_u_d = gradients(u_pred_u, x) + u_pred_v_d = gradients(u_pred_v, x) + u_pred_d = torch.concat([u_pred_u_d, u_pred_v_d], dim=-1) + return loss(u_pred, cond)+loss(u_pred_d,cond) + +def l_initial(u): + if 'initial' not in collocations: + x, t, cond = initial() + collocations['initial'] = (x, t, cond) + x, t, cond = collocations['initial'] + pred_u = u(torch.cat([x, t], dim=1)) + return loss(pred_u, cond) + + +def visualize_error(ax, u): + xx, tt, u_truth = exact_schodinger() + shape = u_truth.shape + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth).reshape(*shape) + + xx = xx.reshape(-1, 1) + yy = tt.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + with torch.no_grad(): + u_pred = u(xy).reshape(*shape) + error = compute_mod(u_pred-u_truth) + xx, tt, _ = exact_schodinger() + ax.pcolormesh(xx.reshape(*shape), tt.reshape(*shape), error.detach().cpu().numpy(), vmin=0, vmax=0.03) + ax.set_title('abs error') + + +def visualize(ax, u, verbose=True): + xc = torch.linspace(-1, 1, 100)*half_L + tc = torch.linspace(0, Tmax, 100) + xx, yy = torch.meshgrid(xc, tc, 
indexing='xy') + xx = xx.reshape(-1, 1) + yy = yy.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + u_pred = u(xy) + if verbose: + logger.info("L2 error is: {}".format(float(l2_loss(u)))) + u_pred_mod = compute_mod(u_pred) + + u_pred_mod = u_pred_mod.detach().cpu().numpy().reshape(100, 100) + + xx = xx.detach().cpu().numpy().reshape(100, 100) + yy = yy.detach().cpu().numpy().reshape(100, 100) + + # ax.pcolormesh(xx, yy, np.abs(u_pred-u_truth), cmap='hot', vmin=0, vmax=1) + ax.pcolormesh(xx, yy, u_pred_mod, vmin=0, vmax=4., cmap='YlGnBu') + ax.set_title('Prediction') + + +def visualize_scatter(ax, collocation): + x, y, _ = collocation['interior'] + x = x.detach().cpu().numpy().ravel() + y = y.detach().cpu().numpy().ravel() + ax.scatter(x, y, s=1) + + +def compose_loss(l_interior_val, l_boundary_val, l_init_val): + return 2 * l_init_val + 2 * l_boundary_val + 0.2 * l_interior_val + + +def write_res(mse_list): + with open(path.joinpath('result.csv'), "a+") as f: + f.write(', '.join(mse_list)) + f.write('\n') + + +def eval_u(u, mse_list): + l2_rel = several_error(u) + mse_list.append(str(l2_rel)) + logger.info(f'l2_rel: {l2_rel}') + + +def ff(): + """ + FF + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("ff") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + collocations['interior'] = interior_ff(np.ones((100, 100))) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + opt.zero_grad() + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff.pth')) + write_res(mse_list) + return mse_list + + +def ff_resample(): + """ + FF-R + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_resample_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("ff resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + collocations['interior'] = interior_ff(np.ones((100, 100))) + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + opt.zero_grad() + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: 
loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff_re.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff_re.pth')) + write_res(mse_list) + return mse_list + + +def ff_rar(mem=0.9): + """ + RANG-R + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_rar_{mem:.2f}_mse'] + + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info(f"ff_rar_{mem:.2f}") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters()) + collocations['interior'] = interior_ff(np.ones((100, 100))) + + for i in range(maxiter): + opt.zero_grad() + + if i > 0 and i % rar_interval == 0: + xx, yy, new_error = compute_res(u) + min_v = np.min(new_error) + max_v = np.max(new_error) + new_error = (new_error - min_v) / (max_v - min_v + 1e-8) + try: + error = np.maximum(mem * error, new_error) + except: + error = new_error + collocations['interior'] = interior_ff(error, sample_num) + + if verbose: + logger.info("length of samples is {}".format(len(collocations['interior'][0]))) + + if sample_idx % freq_draw == 0: + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}, point num is {len(collocations["interior"][0])}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff_rar_{mem:.2f}.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff_rar_{mem:.2f}.pth')) + write_res(mse_list) + + return mse_list + + +def hammersely_sample(): + """ + Hammersley + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'hammersely_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("hammersely sampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + opt.zero_grad() + + l_interior_val = l_interior(u, method='hammersely') + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'hammersely_evo_ac.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'hammersely_evo_ac.pth')) + write_res(mse_list) + return mse_list + + +def lhs_sample(): + """ + LHS + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'lhs_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + 
fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("lhs sampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u, method='lhs') + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + opt.zero_grad() + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'lhs.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'lhs.pth')) + write_res(mse_list) + return mse_list + + +def lhs_resample(): + """ + LHS_R + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'lhs_resample_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("lhs resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + l_interior_val = l_interior(u, method='lhs', resample=True) + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + else: + l_interior_val = l_interior(u, method='lhs') + opt.zero_grad() + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'lhs_re.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'lhs_re.pth')) + write_res(mse_list) + return mse_list + + +def random(): + """ + Random + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'random_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("random resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u, method='random') + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + + opt.zero_grad() + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + 
fig2.savefig(path.joinpath(f'random.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'random.pth')) + write_res(mse_list) + return mse_list + + +def random_resample(): + """ + Random-R + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'random_resample_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("random resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + l_interior_val = l_interior(u, method='random', resample=True) + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + else: + l_interior_val = l_interior(u, method='random') + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + opt.zero_grad() + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'random_re.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'random_re.pth')) + write_res(mse_list) + return mse_list + + +if __name__ == '__main__': + exp_id = int(args.start_epoch) + for i in range(int(args.repeat)): + + ff_rar(0.9) # RANG-m + exp_id += 1 + + ff_rar(0.0) # RANG + exp_id += 1 + + ff() # FF + exp_id += 1 + + ff_resample() # FF-R + exp_id += 1 + + hammersely_sample() # Hammersley + exp_id += 1 + + lhs_sample() # LHS + exp_id += 1 + + lhs_resample() # LHS-R + exp_id += 1 + + random() # Random + exp_id += 1 + + random_resample()# Random-R + exp_id += 1 \ No newline at end of file diff --git a/D_kdv/kdv.py b/D_kdv/kdv.py new file mode 100644 index 0000000..f9a1caa --- /dev/null +++ b/D_kdv/kdv.py @@ -0,0 +1,648 @@ +""" + +""" +import torch +from tools import gradients, MLP, logger +import matplotlib.pyplot as plt +import numpy as np +from ff import error_ff, hammersely, lhs +from parser_pinn import get_parser +import pathlib + +parser_PINN = get_parser() +args = parser_PINN.parse_args() +path = pathlib.Path(args.save_path) +path.mkdir(exist_ok=True, parents=True) +for key, val in vars(args).items(): + print(f"{key} = {val}") +with open(path.joinpath('config'), 'wt') as f: + f.writelines([f"{key} = {val}\n" for key, val in vars(args).items()]) +maxiter: int = int(args.maxiter) +net_seq: list = list(args.net_seq) +sample_num: int = int(args.sample_num) +resample_interval: int = int(args.resample_interval) +freq_draw: int = int(args.freq_draw) +verbose: bool = bool(args.verbose) +resample_N: int = int(args.resample_N) +c: float = float(args.c) +Tmax: float = float(args.Tmax) +half_L: float = float(args.half_L) + +resample_num = maxiter // resample_interval +log_interval = maxiter // 10 +rar_interval = maxiter // resample_num + +# todo more careful check +GPU_ENABLED = True +if torch.cuda.is_available(): + try: + _ = torch.Tensor([0., 0.]).cuda() + torch.set_default_tensor_type('torch.cuda.FloatTensor') + print('gpu available') + GPU_ENABLED = True + except: + print('gpu not available') + GPU_ENABLED = False +else: + print('gpu not available') + GPU_ENABLED = False + +_memo = [] + + 
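+# Editor's sketch (helper name is hypothetical, not part of the original
+# code): the closed form used in exact_kdv below,
+#     u(x, t) = c/2 * sech(sqrt(c)/2 * (x + c - c*t))**2,
+# is the single soliton of u_t + 6*u*u_x + u_xxx = 0 launched at x0 = -c and
+# travelling with speed c.  A quick central-difference check of the PDE
+# residual, evaluated near the soliton core where the terms are O(10):
+def _kdv_soliton_residual(c=7.0, x=-3.9, t=0.3, h=1e-3):
+    f = lambda x_, t_: c / 2 / np.cosh(np.sqrt(c) / 2 * (x_ + c - c * t_)) ** 2
+    u_t = (f(x, t + h) - f(x, t - h)) / (2 * h)
+    u_x = (f(x + h, t) - f(x - h, t)) / (2 * h)
+    u_xxx = (f(x + 2 * h, t) - 2 * f(x + h, t)
+             + 2 * f(x - h, t) - f(x - 2 * h, t)) / (2 * h ** 3)
+    # should be ~0 (truncation error is O(h**2), typically < 1e-3 here)
+    return u_t + 6 * f(x, t) * u_x + u_xxx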
+def exact_kdv(): + if len(_memo) == 0: + x = np.linspace(-half_L, half_L, 512, endpoint=False) + t = np.linspace(0, 1, 201) * Tmax + xx, tt = np.meshgrid(x, t) + u = c / 2 / np.cosh(np.sqrt(c) / 2 * (xx+c - c * tt)) ** 2 + _memo.append((xx.reshape(-1, 1), tt.reshape(-1, 1), u.reshape(-1, 1))) + return _memo[0] + + +def compute_lhs(u, x, t): + u_pred = u(torch.cat([x, t], dim=1)) + u_x = gradients(u_pred, x, 1) + u_xxx = gradients(u_pred, x, 3) + u_t = gradients(u_pred, t, 1) + lhs = u_t + 6 * u_pred * u_x + u_xxx + return lhs + + +def compute_res(u): + resample_N = 500 + xc = torch.linspace(-1, 1, resample_N)*half_L + tc = torch.linspace(0, Tmax, resample_N) + + xx, yy = torch.meshgrid(xc, tc, indexing='xy') + xx = xx.reshape(-1, 1) + yy = yy.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + + x = torch.Tensor(xy[:, 0]).reshape(-1, 1).requires_grad_(True) + t = torch.Tensor(xy[:, 1]).reshape(-1, 1).requires_grad_(True) + + lhs = compute_lhs(u, x, t) + residual = torch.abs(lhs) + error = residual.reshape(resample_N, resample_N).detach().cpu().numpy() + + xc = np.linspace(-1, 1, resample_N) + tc = np.linspace(-1, 1, resample_N) + xx, yy = np.meshgrid(xc, tc, indexing='xy') + + return xx, yy, error + +def ddu_fn(x, y): + return torch.zeros_like(x) + + +def l2_loss(u): + xx, tt, u_truth = exact_kdv() + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth) + xy = torch.cat([xx, tt], dim=1) + with torch.no_grad(): + u_pred = u(xy) + l2_error = torch.sqrt( + torch.sum((u_pred - u_truth) ** 2) / torch.sum((u_truth ** 2))) + return l2_error.detach().cpu().numpy() + + +def several_error(u): + with torch.no_grad(): + xx, tt, u_truth = exact_kdv() + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth) + xy = torch.cat([xx, tt], dim=1) + u_pred = u(xy) + mse = torch.mean((u_pred - u_truth) ** 2).detach().cpu().numpy() + + return mse + + +torch.random.seed() + + +def interior(n=sample_num, method='random'): + if method == 'random': + xx = (torch.rand(n, 1) * 2 - 1.)*half_L + yy = torch.rand(n, 1)*Tmax + elif method == 'uniform': + N = int(np.sqrt(n)) + xc = torch.linspace(-1, 1, N)*half_L + tc = torch.linspace(0, 1, N)*Tmax + xx, yy = torch.meshgrid(xc, tc) + xx = xx.ravel().reshape(-1, 1) + yy = yy.ravel().reshape(-1, 1) + elif method == "hammersely": + xy = hammersely(n) + xx, yy = torch.Tensor(xy[:, 0:1] * 2 - 1)*half_L, torch.Tensor(xy[:, 1:2])*Tmax + elif method == "lhs": + xy = lhs(n) + xx, yy = torch.Tensor(xy[:, 0:1] * 2 - 1)*half_L, torch.Tensor(xy[:, 1:2])*Tmax + cond = ddu_fn(xx, yy) + return xx.requires_grad_(True), yy.requires_grad_(True), cond + + +def interior_ff(error, n=sample_num, verbose=False): + xy = error_ff(n, + error=error, + max_min_density_ratio=10, + box=[0, 1, 0, 1]) + if verbose: + logger.info("length of samples is {}".format(len(xy))) + # logger.info(f"FF sample number is {len(xy)}") + x = torch.Tensor(xy[:, 0] * 2 - 1).reshape(-1, 1)*half_L + y = torch.Tensor(xy[:, 1]).reshape(-1, 1)*Tmax + cond = ddu_fn(x, y) + + x = x.requires_grad_(True) + y = y.requires_grad_(True) + return x, y, cond + + +def boundary(n=100): + xx = torch.cat([-torch.ones(n), torch.ones(n)]).reshape(-1, 1)*half_L + tt = torch.cat([torch.linspace(0, Tmax, n), torch.linspace(0, Tmax, n)]).reshape(-1, 1) + cond = c / 2 / torch.cosh(np.sqrt(c) / 2 * (xx+c - c * tt)) ** 2 + # logger.info(f"sampling in interior, method is {method}") + return xx.requires_grad_(True), tt.requires_grad_(True), cond + + +def initial(n=1000): + xx = torch.linspace(-1., 1., 
n).reshape(-1, 1)*half_L + tt = torch.zeros(n).reshape(-1, 1)*Tmax + cond = c / 2 / torch.cosh(np.sqrt(c) / 2 * (xx+c - c * tt)) ** 2 + # logger.info(f"sampling in interior, method is {method}") + return xx.requires_grad_(True), tt.requires_grad_(True), cond + + +loss = torch.nn.MSELoss() + + +def l_interior(u, method='random', resample=False): + if resample or ('interior' not in collocations): + x, t, cond = interior(method=method) + collocations['interior'] = (x, t, cond) + x, t, cond = collocations['interior'] + lhs = compute_lhs(u, x, t) + l = loss(lhs, cond) + return l + + +def l_boundary(u): + if 'boundary' not in collocations: + x, t, cond = boundary() + collocations['boundary'] = (x, t, cond) + x, t, cond = collocations['boundary'] + return loss(u(torch.cat([x, t], dim=1)), cond) + + +def l_initial(u): + if 'initial' not in collocations: + x, t, cond = initial() + collocations['initial'] = (x, t, cond) + x, t, cond = collocations['initial'] + return loss(u(torch.cat([x, t], dim=1)), cond) + + +def visualize_error(ax, u): + xx, tt, u_truth = exact_kdv() + shape = u_truth.shape + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth).reshape(*shape) + + xx = xx.reshape(-1, 1) + yy = tt.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + with torch.no_grad(): + u_pred = u(xy).reshape(*shape) + error = torch.abs(u_pred - u_truth) + xx, tt, _ = exact_kdv() + ax.pcolormesh(xx.reshape(*shape), tt.reshape(*shape), error.detach().cpu().numpy(), vmin=0, vmax=0.03) + ax.set_title('abs error') + + +def visualize(ax, u, verbose=True): + xc = torch.linspace(-1, 1, 100)*half_L + tc = torch.linspace(0, Tmax, 100) + xx, yy = torch.meshgrid(xc, tc, indexing='xy') + xx = xx.reshape(-1, 1) + yy = yy.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + u_pred = u(xy) + if verbose: + logger.info("L2 error is: {}".format(float(l2_loss(u)))) + + u_pred = u_pred.detach().cpu().numpy().reshape(100, 100) + + xx = xx.detach().cpu().numpy().reshape(100, 100) + yy = yy.detach().cpu().numpy().reshape(100, 100) + + # ax.pcolormesh(xx, yy, np.abs(u_pred-u_truth), cmap='hot', vmin=0, vmax=1) + ax.pcolormesh(xx, yy, u_pred, vmin=0., vmax=c/2) + ax.set_title('Prediction') + + +def visualize_scatter(ax, collocation): + x, y, _ = collocation['interior'] + x = x.detach().cpu().numpy().ravel() + y = y.detach().cpu().numpy().ravel() + ax.scatter(x, y, s=1) + + +def compose_loss(l_interior_val, l_boundary_val, l_init_val): + return 2 * l_init_val + 2 * l_boundary_val + 0.2 * l_interior_val + + +def write_res(mse_list): + with open(path.joinpath('result.csv'), "a+") as f: + f.write(', '.join(mse_list)) + f.write('\n') + + +def eval_u(u, mse_list): + l2_rel = several_error(u) + mse_list.append(str(l2_rel)) + logger.info(f'mse: {l2_rel}') + + +def ff(): + """ + FF + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("ff") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + collocations['interior'] = interior_ff(np.ones((100, 100))) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, 
mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + opt.zero_grad() + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff.pth')) + write_res(mse_list) + return mse_list + + +def ff_resample(): + """ + FF-R + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_resample_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("ff resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + collocations['interior'] = interior_ff(np.ones((100, 100))) + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + opt.zero_grad() + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff_re.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff_re.pth')) + write_res(mse_list) + return mse_list + + +def ff_rar(mem=0.9): + """ + RANG-R + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_rar_{mem:.2f}_mse'] + + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info(f"ff_rar_{mem:.2f}") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters()) + collocations['interior'] = interior_ff(np.ones((100, 100))) + + for i in range(maxiter): + opt.zero_grad() + + if i > 0 and i % rar_interval == 0: + xx, yy, new_error = compute_res(u) + min_v = np.min(new_error) + max_v = np.max(new_error) + new_error = (new_error - min_v) / (max_v - min_v + 1e-8) + try: + error = np.maximum(mem * error, new_error) + except: + error = new_error + collocations['interior'] = interior_ff(error, sample_num) + + if verbose: + logger.info("length of samples is {}".format(len(collocations['interior'][0]))) + + if sample_idx % freq_draw == 0: + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}, point num is {len(collocations["interior"][0])}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff_rar_{mem:.2f}.png')) + plt.close(fig2) + torch.save(u.state_dict(), 
path.joinpath(f'ff_rar_{mem:.2f}.pth')) + write_res(mse_list) + + return mse_list + + +def hammersely_sample(): + """ + Hammersley + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'hammersely_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("hammersely sampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + opt.zero_grad() + + l_interior_val = l_interior(u, method='hammersely') + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'hammersely_evo_ac.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'hammersely_evo_ac.pth')) + write_res(mse_list) + return mse_list + + +def lhs_sample(): + """ + LHS + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'lhs_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("lhs resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u, method='lhs') + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + opt.zero_grad() + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'lhs.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'lhs.pth')) + write_res(mse_list) + return mse_list + + +def lhs_resample(): + """ + LHS-R + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'lhs_resample_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("lhs resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + l_interior_val = l_interior(u, method='lhs', resample=True) + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + else: + 
l_interior_val = l_interior(u, method='lhs') + opt.zero_grad() + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'lhs_re.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'lhs_re.pth')) + write_res(mse_list) + return mse_list + + +def random(): + """ + Random + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'random_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("random resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u, method='random') + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + + opt.zero_grad() + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'random.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'random.pth')) + write_res(mse_list) + return mse_list + + +def random_resample(): + """ + Random-R + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'random_resample_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("random resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + l_interior_val = l_interior(u, method='random', resample=True) + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + else: + l_interior_val = l_interior(u, method='random') + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + opt.zero_grad() + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'random_re.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'random_re.pth')) + write_res(mse_list) + return mse_list + + +if __name__ == '__main__': + exp_id = int(args.start_epoch) + for i in range(int(args.repeat)): + + ff_rar(0.9) # RANG-m + exp_id += 1 + + ff_rar(0.0) # RANG + exp_id += 1 + + ff() # FF + exp_id += 1 + + ff_resample() # FF-R + exp_id += 1 + + hammersely_sample() # Hammersley + exp_id += 1 + + lhs_sample() # LHS + exp_id += 1 + + lhs_resample() # LHS-R + exp_id += 1 + + random() # Random + 
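+        # Bookkeeping: each call above appends one row keyed by exp_id to
+        # result.csv, so args.repeat passes yield args.repeat rows per
+        # strategy; exp_id is incremented after every run to keep keys unique.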
exp_id += 1
+
+        random_resample()  # Random-R
+        exp_id += 1
\ No newline at end of file
diff --git a/D_kdv/parser_pinn.py b/D_kdv/parser_pinn.py
new file mode 100644
index 0000000..5f3d2f0
--- /dev/null
+++ b/D_kdv/parser_pinn.py
@@ -0,0 +1,51 @@
+import argparse
+from datetime import datetime
+import math
+
+def get_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--maxiter', default=50000, type=int
+    )
+    parser.add_argument(
+        '--resample_interval', default=1000, type=int
+    )
+    parser.add_argument(
+        '--sample_num', default=1000, type=int
+    )
+    parser.add_argument(
+        '--freq_draw', default=5, type=int
+    )
+    parser.add_argument(
+        '--resample_N', default=100, type=int
+    )
+    parser.add_argument(
+        '--net_seq', default=[2, 64, 64, 64, 64, 1], nargs='+', type=int
+    )
+    parser.add_argument(
+        '--save_path', default=f'./data/{datetime.now().strftime("%Y_%m_%d_%H_%M_%S")}'
+    )
+
+
+    parser.add_argument(
+        '--verbose', default=False, action='store_true'
+    )
+    parser.add_argument(
+        '--sigma', default=0.1, type=float
+    )
+    parser.add_argument(
+        '--repeat', default=50, type=int
+    )
+    parser.add_argument(
+        '--start_epoch', default=0, type=int
+    )
+    parser.add_argument(
+        '--c', default=7.0, type=float
+    )
+    parser.add_argument(
+        '--Tmax', default=2.0, type=float
+    )
+    parser.add_argument(
+        '--half_L', default=4*math.pi, type=float
+    )
+    return parser
diff --git a/E_poisson/parser_pinn.py b/E_poisson/parser_pinn.py
new file mode 100644
index 0000000..d8fb8f6
--- /dev/null
+++ b/E_poisson/parser_pinn.py
@@ -0,0 +1,40 @@
+import argparse
+from datetime import datetime
+
+
+def get_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--maxiter', default=1000, type=int
+    )
+    parser.add_argument(
+        '--resample_N', default=100, type=int
+    )
+    parser.add_argument(
+        '--sample_num', default=400, type=int
+    )
+    parser.add_argument(
+        '--net_seq', default=[2, 64, 64, 64, 64, 1], nargs='+', type=int
+    )
+    parser.add_argument(
+        '--save_path', default=f'./data/{datetime.now().strftime("%Y_%m_%d_%H_%M_%S")}'
+    )
+    parser.add_argument(
+        '--resample_interval', default=100, type=int
+    )
+    parser.add_argument(
+        '--freq_draw', default=1, type=int
+    )
+    parser.add_argument(
+        '--verbose', default=False, action='store_true'
+    )
+    parser.add_argument(
+        '--sigma', default=0.1, type=float
+    )
+    parser.add_argument(
+        '--repeat', default=100, type=int
+    )
+    parser.add_argument(
+        '--start_epoch', default=0, type=int
+    )
+    return parser
diff --git a/E_poisson/poisson.py b/E_poisson/poisson.py
new file mode 100644
index 0000000..39b2978
--- /dev/null
+++ b/E_poisson/poisson.py
@@ -0,0 +1,630 @@
+"""
+PINN benchmark: 2-D Poisson equation with two Gaussian sources, comparing FF/RANG/Hammersley/LHS/random collocation sampling.
+"""
+import torch
+from tools import gradients, MLP, logger
+import matplotlib.pyplot as plt
+import numpy as np
+from ff import error_ff, hammersely, lhs
+from parser_pinn import get_parser
+import pathlib
+
+parser_PINN = get_parser()
+args = parser_PINN.parse_args()
+path = pathlib.Path(args.save_path)
+path.mkdir(exist_ok=True, parents=True)
+for key, val in vars(args).items():
+    print(f"{key} = {val}")
+with open(path.joinpath('config'), 'wt') as f:
+    f.writelines([f"{key} = {val}\n" for key, val in vars(args).items()])
+maxiter: int = int(args.maxiter)
+net_seq: list = list(args.net_seq)
+sample_num: int = int(args.sample_num)
+resample_interval: int = int(args.resample_interval)
+freq_draw: int = int(args.freq_draw)
+verbose: bool = bool(args.verbose)
+sigma: float = float(args.sigma)
+resample_N: int = int(args.resample_N)
+
+# r_ff = 0.1 / np.sqrt(sample_num / 100)
+resample_num = maxiter // resample_interval
+log_interval = maxiter // 10
+rar_interval = maxiter // resample_num
+
+# todo more careful check
+GPU_ENABLED = True
+if torch.cuda.is_available():
+    try:
+        _ = torch.Tensor([0., 0.]).cuda()
torch.set_default_tensor_type('torch.cuda.FloatTensor') + print('gpu available') + GPU_ENABLED = True + except: + print('gpu not available') + GPU_ENABLED = False +else: + print('gpu not available') + GPU_ENABLED = False + +_memo = [] + + +def exact_poisson(): + if len(_memo) == 0: + xc = np.linspace(-1, 1, 401) + yc = np.linspace(-1, 1, 401) + xx, yy = np.meshgrid(xc, yc) + u = np.exp(-((xx - 0.3) ** 2 + (yy - 0.3) ** 2) / 2 / sigma ** 2) \ + - np.exp(-((xx + 0.3) ** 2 + (yy + 0.3) ** 2) / 2 / sigma ** 2) + _memo.append((xx.reshape(-1, 1), yy.reshape(-1, 1), u.reshape(-1, 1))) + return _memo[0] + + +def compute_lhs(u, x, t): + u_pred = u(torch.cat([x, t], dim=1)) + u_xx = gradients(u_pred, x, 2) + u_yy = gradients(u_pred, t, 2) + lhs = u_xx + u_yy + return lhs + + +def compute_res(u): + xc = torch.linspace(-1, 1, resample_N) + tc = torch.linspace(-1, 1, resample_N) + + xx, yy = torch.meshgrid(xc, tc, indexing='xy') + xx = xx.reshape(-1, 1) + yy = yy.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + + x = torch.Tensor(xy[:, 0]).reshape(-1, 1).requires_grad_(True) + t = torch.Tensor(xy[:, 1]).reshape(-1, 1).requires_grad_(True) + + lhs = compute_lhs(u, x, t) + residual = torch.abs(lhs) + error = residual.reshape(resample_N, resample_N).detach().cpu().numpy() + + xc = np.linspace(-1, 1, resample_N) + tc = np.linspace(-1, 1, resample_N) + xx, yy = np.meshgrid(xc, tc, indexing='xy') + + return xx, yy, error + +def ddu_fn(x, y): + rhs = (-1 + (x - 0.3) ** 2 / sigma ** 2) * torch.exp( + -((x - 0.3) ** 2 + (y - 0.3) ** 2) / (2 * sigma ** 2)) / sigma ** 2 \ + + (-1 + (y - 0.3) ** 2 / sigma ** 2) * torch.exp( + -((x - 0.3) ** 2 + (y - 0.3) ** 2) / (2 * sigma ** 2)) / sigma ** 2 \ + - (-1 + (x + 0.3) ** 2 / sigma ** 2) * torch.exp( + -((x + 0.3) ** 2 + (y + 0.3) ** 2) / (2 * sigma ** 2)) / sigma ** 2 \ + - (-1 + (y + 0.3) ** 2 / sigma ** 2) * torch.exp( + -((x + 0.3) ** 2 + (y + 0.3) ** 2) / (2 * sigma ** 2)) / sigma ** 2 + return rhs + + +def l2_loss(u): + xx, tt, u_truth = exact_poisson() + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth) + xy = torch.cat([xx, tt], dim=1) + with torch.no_grad(): + u_pred = u(xy) + l2_error = torch.sqrt( + torch.sum((u_pred - u_truth) ** 2) / torch.sum((u_truth ** 2))) + return l2_error.detach().cpu().numpy() + + +def several_error(u): + with torch.no_grad(): + xx, tt, u_truth = exact_poisson() + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth) + xy = torch.cat([xx, tt], dim=1) + u_pred = u(xy) + mse = torch.mean((u_pred - u_truth) ** 2).detach().cpu().numpy() + + return mse + + +torch.random.seed() + +def interior(n=sample_num, method='random'): + if method == 'random': + xx = (torch.rand(n, 1) * 2 - 1.) + yy = (torch.rand(n, 1) * 2 - 1.) 
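+    # The quasi-random branches below sample the unit square [0, 1]^2 and map
+    # it onto the domain [-1, 1]^2 via xy * 2 - 1; 'uniform' instead builds a
+    # tensor-product grid with roughly sqrt(n) points per axis.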
+ elif method == 'uniform': + N = int(np.sqrt(n)) + xc = torch.linspace(-1, 1, N) + tc = torch.linspace(-1, 1, N) + xx, yy = torch.meshgrid(xc, tc) + xx = xx.ravel().reshape(-1, 1) + yy = yy.ravel().reshape(-1, 1) + elif method == "hammersely": + xy = hammersely(n) + xx, yy = torch.Tensor(xy[:, 0:1] * 2 - 1), torch.Tensor(xy[:, 1:2] * 2 - 1) + elif method == "lhs": + xy = lhs(n) + xx, yy = torch.Tensor(xy[:, 0:1] * 2 - 1), torch.Tensor(xy[:, 1:2] * 2 - 1) + cond = ddu_fn(xx, yy) + return xx.requires_grad_(True), yy.requires_grad_(True), cond + + +def interior_ff(error, n=sample_num, verbose=False): + xy = error_ff(n, + error=error, + max_min_density_ratio=10, + box=[0, 1, 0, 1]) + if verbose: + logger.info("length of samples is {}".format(len(xy))) + x = torch.Tensor(xy[:, 0] * 2 - 1).reshape(-1, 1) + y = torch.Tensor(xy[:, 1] * 2 - 1).reshape(-1, 1) + cond = ddu_fn(x, y) + + x = x.requires_grad_(True) + y = y.requires_grad_(True) + return x, y, cond + + +def boundary(n=1000, method='random'): + x = torch.cat([-torch.ones(n), + torch.linspace(-1, 1, n), + torch.ones(n), + torch.linspace(-1, 1, n)]).reshape(-1, 1) + y = torch.cat([torch.linspace(-1, 1, n), + -torch.ones(n), + torch.linspace(-1, 1, n), + torch.ones(n)]).reshape(-1, 1) + + rhs = torch.exp(-((x - 0.3) ** 2 + (y - 0.3) ** 2) / 2 / sigma ** 2) \ + - torch.exp(-((x + 0.3) ** 2 + (y + 0.3) ** 2) / 2 / sigma ** 2) + return x.requires_grad_(True), y.requires_grad_(True), rhs + + +loss = torch.nn.MSELoss() + + + +def l_interior(u, method='random', resample=False): + if resample or ('interior' not in collocations): + x, t, cond = interior(method=method) + collocations['interior'] = (x, t, cond) + x, t, cond = collocations['interior'] + lhs = compute_lhs(u, x, t) + l = loss(lhs, cond) + return l + + +def l_boundary(u): + if 'boundary' not in collocations: + x, y, cond = boundary() + collocations['boundary'] = (x, y, cond) + x, y, cond = collocations['boundary'] + return loss(u(torch.cat([x, y], dim=1)), cond) + +def visualize_error(ax, u): + xx, tt, u_truth = exact_poisson() + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth).reshape(401, 401) + + xx = xx.reshape(-1, 1) + yy = tt.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + with torch.no_grad(): + u_pred = u(xy).reshape(401, 401) + error = torch.abs(u_pred - u_truth) + xx, tt, _ = exact_poisson() + ax.pcolormesh(xx.reshape(401, 401), tt.reshape(401, 401), error.detach().cpu().numpy(), vmin=0, vmax=0.3) + ax.set_title('abs error') + + +def visualize(ax, u, verbose=True): + xc = torch.linspace(-1, 1, 100) + tc = torch.linspace(-1, 1, 100) + xx, yy = torch.meshgrid(xc, tc, indexing='xy') + xx = xx.reshape(-1, 1) + yy = yy.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + u_pred = u(xy) + if verbose: + logger.info("L2 error is: {}".format(float(l2_loss(u)))) + + u_pred = u_pred.detach().cpu().numpy().reshape(100, 100) + + xx = xx.detach().cpu().numpy().reshape(100, 100) + yy = yy.detach().cpu().numpy().reshape(100, 100) + + # ax.pcolormesh(xx, yy, np.abs(u_pred-u_truth), cmap='hot', vmin=0, vmax=1) + ax.pcolormesh(xx, yy, u_pred, vmin=-1., vmax=1.) 
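+    # The color scale is pinned to [-1, 1]: the exact solution is a pair of
+    # Gaussian bumps with peak values of about +1 and -1 (see exact_poisson),
+    # so a fixed scale keeps the snapshot rows comparable across training.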
+ ax.set_title('Prediction') + + +def visualize_scatter(ax, collocation): + x, y, _ = collocation['interior'] + x = x.detach().cpu().numpy().ravel() + y = y.detach().cpu().numpy().ravel() + ax.scatter(x, y, s=1) + + +def compose_loss(l_interior_val, l_boundary_val): + return 2 * l_boundary_val + 0.2 * l_interior_val + + +def write_res(mse_list): + with open(path.joinpath('result.csv'), "a+") as f: + f.write(', '.join(mse_list)) + f.write('\n') + + +def eval_u(u, mse_list): + l2_rel = several_error(u) + mse_list.append(str(l2_rel)) + logger.info(f'mse: {l2_rel}') + + +def ff(): + """ + FF + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("ff") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + collocations['interior'] = interior_ff(np.ones((100, 100))) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + opt.zero_grad() + l = compose_loss(l_interior_val, l_boundary_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff.pth')) + write_res(mse_list) + return mse_list + + +def ff_resample(): + """ + FF-R + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_resample_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("ff resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + collocations['interior'] = interior_ff(np.ones((100, 100))) + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + opt.zero_grad() + l = compose_loss(l_interior_val, l_boundary_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff_re.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff_re.pth')) + write_res(mse_list) + return mse_list + +def ff_rar(mem=0.9): + """ + RANG-m + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_rar_{mem:.2f}_mse'] + + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("FF rar") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters()) + collocations['interior'] = 
interior_ff(np.ones((100, 100))) + + for i in range(maxiter): + opt.zero_grad() + + if i > 0 and i % rar_interval == 0: + xx, yy, new_error = compute_res(u) + min_v = np.min(new_error) + max_v = np.max(new_error) + new_error = (new_error - min_v) / (max_v - min_v + 1e-8) + try: + error = np.maximum(mem * error, new_error) + except: + error = new_error + collocations['interior'] = interior_ff(error, sample_num) + + if verbose: + logger.info("length of samples is {}".format(len(collocations['interior'][0]))) + + if sample_idx % freq_draw == 0: + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + l = compose_loss(l_interior_val, l_boundary_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}, point num is {len(collocations["interior"][0])}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff_rar_{mem:.2f}.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff_rar_{mem:.2f}.pth')) + write_res(mse_list) + + return mse_list + + +def hammersely_sample(): + """ + Hammersely + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'hammersely_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("hammersely sampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + opt.zero_grad() + l = compose_loss(l_interior(u, method='hammersely'), l_boundary(u)) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'hammersely_evo_ac.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'hammersely_evo_ac.pth')) + write_res(mse_list) + return mse_list + + +def lhs_sample(): + """ + LHS + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'lhs_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("lhs resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u, method='lhs') + l_boundary_val = l_boundary(u) + opt.zero_grad() + l = compose_loss(l_interior_val, l_boundary_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, 
mse_list=mse_list) + fig2.savefig(path.joinpath(f'lhs.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'lhs.pth')) + write_res(mse_list) + return mse_list + + +def lhs_resample(): + """ + LHS-R + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'lhs_resample_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("lhs resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + l_interior_val = l_interior(u, method='lhs', resample=True) + l_boundary_val = l_boundary(u) + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + else: + l_interior_val = l_interior(u, method='lhs') + l_boundary_val = l_boundary(u) + opt.zero_grad() + l = compose_loss(l_interior_val, l_boundary_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'lhs_re.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'lhs_re.pth')) + write_res(mse_list) + return mse_list + + +def random(): + """ + Random + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'random_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("random resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u, method='random') + l_boundary_val = l_boundary(u) + opt.zero_grad() + l = compose_loss(l_interior_val, l_boundary_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'random.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'random.pth')) + write_res(mse_list) + return mse_list + + +def random_resample(): + """ + Random-R + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'random_resample_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("random resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + l_interior_val = l_interior(u, method='random', resample=True) + l_boundary_val = l_boundary(u) + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + 
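+                # The residual panel is diagnostic only: unlike ff_rar, the
+                # fresh points are drawn uniformly at random and `error` is
+                # never used to bias where the new collocation set is placed.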
ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + else: + l_interior_val = l_interior(u, method='random') + l_boundary_val = l_boundary(u) + opt.zero_grad() + l = compose_loss(l_interior_val, l_boundary_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'random_re.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'random_re.pth')) + write_res(mse_list) + return mse_list + + +if __name__ == '__main__': + exp_id = int(args.start_epoch) + for i in range(int(args.repeat)): + ff() # FF + exp_id += 1 + + ff_resample() # FF-R + exp_id += 1 + + ff_rar(0.9) # RANG-m + exp_id += 1 + + ff_rar(0.0) # RANG + exp_id += 1 + + hammersely_sample() # Hammersley + exp_id += 1 + + lhs_sample() # LHS + exp_id += 1 + + lhs_resample() # LHS-R + exp_id += 1 + + random() # Random + exp_id += 1 + + random_resample() # Random-R + exp_id += 1 \ No newline at end of file diff --git a/F_adv_diffuse/adv_diffuse.py b/F_adv_diffuse/adv_diffuse.py new file mode 100644 index 0000000..68e45cd --- /dev/null +++ b/F_adv_diffuse/adv_diffuse.py @@ -0,0 +1,652 @@ +""" + +""" +import torch +from tools import gradients, MLP, logger +import matplotlib.pyplot as plt +import numpy as np +from ff import error_ff, hammersely, lhs +from parser_pinn import get_parser +import pathlib + +parser_PINN = get_parser() +args = parser_PINN.parse_args() +path = pathlib.Path(args.save_path) +path.mkdir(exist_ok=True, parents=True) +for key, val in vars(args).items(): + print(f"{key} = {val}") +with open(path.joinpath('config'), 'wt') as f: + f.writelines([f"{key} = {val}\n" for key, val in vars(args).items()]) +maxiter: int = int(args.maxiter) +net_seq: list = list(args.net_seq) +sample_num: int = int(args.sample_num) +resample_interval: int = int(args.resample_interval) +freq_draw: int = int(args.freq_draw) +verbose: bool = bool(args.verbose) +resample_N: int = int(args.resample_N) +t_start = float(args.t_start) +alpha = float(args.alpha) +Tmax= float(args.Tmax) +half_x = float(args.half_L) + +resample_num = maxiter // resample_interval +log_interval = maxiter // 10 +rar_interval = maxiter // resample_num + +# todo more careful check +GPU_ENABLED = True +if torch.cuda.is_available(): + try: + _ = torch.Tensor([0., 0.]).cuda() + torch.set_default_tensor_type('torch.cuda.FloatTensor') + print('gpu available') + GPU_ENABLED = True + except: + print('gpu not available') + GPU_ENABLED = False +else: + print('gpu not available') + GPU_ENABLED = False + +_memo = [] + + +def exact_heat(): + if len(_memo) == 0: + xc = np.linspace(-1, 1, 501)*half_x + tc = np.linspace(0, 1, 501)*Tmax + xx, tt = np.meshgrid(xc, tc) + u = 0.1 / np.sqrt(alpha * (tt + t_start)) * np.exp(-(xx+2-4*tt) ** 2 / (4 * alpha * (tt + t_start))) + _memo.append((xx.reshape(-1, 1), tt.reshape(-1, 1), u.reshape(-1, 1))) + return _memo[0] + + +def compute_lhs(u, x, t): + u_pred = u(torch.cat([x, t], dim=1)) + u_x = gradients(u_pred, x, 1) + u_xx = gradients(u_x, x, 1) + u_t = gradients(u_pred, t, 1) + lhs = u_t + 4*u_x - alpha*u_xx + return lhs + + +def compute_res(u): + resample_N = 500 + xc = torch.linspace(-1, 1, resample_N)*half_x + tc = torch.linspace(0, Tmax, resample_N) + + xx, yy = torch.meshgrid(xc, tc, indexing='xy') + xx = xx.reshape(-1, 1) + yy = yy.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + + x = torch.Tensor(xy[:, 0]).reshape(-1, 1).requires_grad_(True) + t = 
torch.Tensor(xy[:, 1]).reshape(-1, 1).requires_grad_(True) + + lhs = compute_lhs(u, x, t) + residual = torch.abs(lhs) + error = residual.reshape(resample_N, resample_N).detach().cpu().numpy() + + xc = np.linspace(-1, 1, resample_N) + tc = np.linspace(-1, 1, resample_N) + xx, yy = np.meshgrid(xc, tc, indexing='xy') + + return xx, yy, error + + +def ddu_fn(x, y): + return torch.zeros_like(x) + + +def l2_loss(u): + xx, tt, u_truth = exact_heat() + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth) + xy = torch.cat([xx, tt], dim=1) + with torch.no_grad(): + u_pred = u(xy) + l2_error = torch.sqrt( + torch.sum((u_pred - u_truth) ** 2) / torch.sum((u_truth ** 2))) + return l2_error.detach().cpu().numpy() + + +def several_error(u): + with torch.no_grad(): + xx, tt, u_truth = exact_heat() + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth) + xy = torch.cat([xx, tt], dim=1) + u_pred = u(xy) + mse = torch.mean((u_pred - u_truth) ** 2).detach().cpu().numpy() + + return mse + + +torch.random.seed() + + +def interior(n=sample_num, method='random'): + if method == 'random': + xx = (torch.rand(n, 1) * 2 - 1.)*half_x + yy = torch.rand(n, 1)*Tmax + elif method == 'uniform': + N = int(np.sqrt(n)) + xc = torch.linspace(-1, 1, N)*half_x + tc = torch.linspace(0, 1, N)*Tmax + xx, yy = torch.meshgrid(xc, tc) + xx = xx.ravel().reshape(-1, 1) + yy = yy.ravel().reshape(-1, 1) + elif method == "hammersely": + xy = hammersely(n) + xx, yy = torch.Tensor(xy[:, 0:1] * 2 - 1)*half_x, torch.Tensor(xy[:, 1:2])*Tmax + elif method == "lhs": + xy = lhs(n) + xx, yy = torch.Tensor(xy[:, 0:1] * 2 - 1)*half_x, torch.Tensor(xy[:, 1:2])*Tmax + cond = ddu_fn(xx, yy) + return xx.requires_grad_(True), yy.requires_grad_(True), cond + + +def interior_ff(error, n=sample_num, verbose=False): + xy = error_ff(n, + error=error, + max_min_density_ratio=10, + box=[0, 1, 0, 1]) + if verbose: + logger.info("length of samples is {}".format(len(xy))) + # logger.info(f"FF sample number is {len(xy)}") + x = torch.Tensor(xy[:, 0] * 2 - 1).reshape(-1, 1)*half_x + y = torch.Tensor(xy[:, 1]).reshape(-1, 1)*Tmax + cond = ddu_fn(x, y) + + x = x.requires_grad_(True) + y = y.requires_grad_(True) + return x, y, cond + + +def boundary(n=100): + xl = -torch.ones(n).reshape(-1, 1)*half_x + xr = torch.ones(n).reshape(-1, 1)*half_x + tl = torch.linspace(0, Tmax, n).reshape(-1, 1) + tr = torch.linspace(0, Tmax, n).reshape(-1, 1) + cond = torch.zeros_like(xl) + # logger.info(f"sampling in interior, method is {method}") + return xr.requires_grad_(True), xl.requires_grad_(True), tr.requires_grad_(True), tl.requires_grad_(True), cond + +def initial(n=1000): + xx = torch.linspace(-1., 1., n).reshape(-1, 1)*half_x + tt = torch.zeros(n).reshape(-1, 1) + cond = 0.1 / torch.sqrt(alpha * (tt + t_start)) * torch.exp(-(xx+2-4*tt) ** 2 / (4 * alpha * (tt + t_start))) + # logger.info(f"sampling in interior, method is {method}") + return xx.requires_grad_(True), tt.requires_grad_(True), cond + + +loss = torch.nn.MSELoss() + +import math + + +def l_interior(u, method='random', resample=False): + if resample or ('interior' not in collocations): + x, t, cond = interior(method=method) + collocations['interior'] = (x, t, cond) + x, t, cond = collocations['interior'] + lhs = compute_lhs(u, x, t) + l = loss(lhs, cond) + return l + + +def l_boundary(u): + if 'boundary' not in collocations: + xr, xl, tr, tl, cond = boundary() + collocations['boundary'] = (xr, xl, tr, tl, cond) + xr, xl, tr, tl, cond = collocations['boundary'] + 
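+    # Homogeneous Dirichlet data is a safe approximation here: with the
+    # defaults (t_start=0.1, alpha=0.05, half_L=4.0), the exact advected
+    # Gaussian stays below ~1e-8 at x = +/-4 for all t in [0, Tmax]; e.g. at
+    # x = -4, t = 0 its value is 0.1/sqrt(0.005) * exp(-200), which underflows
+    # float32.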
return loss(u(torch.cat([xr, tr], dim=1)), cond) + loss(u(torch.cat([xl, tl], dim=1)), cond) + + +def l_initial(u): + if 'initial' not in collocations: + x, t, cond = initial() + collocations['initial'] = (x, t, cond) + x, t, cond = collocations['initial'] + return loss(u(torch.cat([x, t], dim=1)), cond) + + +def visualize_error(ax, u): + xx, tt, u_truth = exact_heat() + shape = u_truth.shape + xx = torch.Tensor(xx) + tt = torch.Tensor(tt) + u_truth = torch.Tensor(u_truth).reshape(*shape) + + xx = xx.reshape(-1, 1) + yy = tt.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + with torch.no_grad(): + u_pred = u(xy).reshape(*shape) + error = torch.abs(u_pred - u_truth) + xx, tt, _ = exact_heat() + ax.pcolormesh(xx.reshape(*shape), tt.reshape(*shape), error.detach().cpu().numpy(), vmin=0, vmax=0.03) + ax.set_title('abs error') + + +def visualize(ax, u, verbose=True): + xc = torch.linspace(-1, 1, 100)*half_x + tc = torch.linspace(0, Tmax, 100) + xx, yy = torch.meshgrid(xc, tc, indexing='xy') + xx = xx.reshape(-1, 1) + yy = yy.reshape(-1, 1) + xy = torch.cat([xx, yy], dim=1) + u_pred = u(xy) + if verbose: + logger.info("L2 error is: {}".format(float(l2_loss(u)))) + + u_pred = u_pred.detach().cpu().numpy().reshape(100, 100) + + xx = xx.detach().cpu().numpy().reshape(100, 100) + yy = yy.detach().cpu().numpy().reshape(100, 100) + + # ax.pcolormesh(xx, yy, np.abs(u_pred-u_truth), cmap='hot', vmin=0, vmax=1) + ax.pcolormesh(xx, yy, u_pred, vmin=0., vmax=2., cmap='hot') + ax.set_title('Prediction') + + +def visualize_scatter(ax, collocation): + x, y, _ = collocation['interior'] + x = x.detach().cpu().numpy().ravel() + y = y.detach().cpu().numpy().ravel() + ax.scatter(x, y, s=1) + + +def compose_loss(l_interior_val, l_boundary_val, l_init_val): + return 2 * l_init_val + 2 * l_boundary_val + 0.2 * l_interior_val + + +def write_res(mse_list): + with open(path.joinpath('result.csv'), "a+") as f: + f.write(', '.join(mse_list)) + f.write('\n') + + +def eval_u(u, mse_list): + l2_rel = several_error(u) + mse_list.append(str(l2_rel)) + logger.info(f'mse: {l2_rel}') + + +def ff(): + """ + FF + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("ff") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + collocations['interior'] = interior_ff(np.ones((100, 100))) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + opt.zero_grad() + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff.pth')) + write_res(mse_list) + return mse_list + + +def ff_resample(): + """ + FF-R + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_resample_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 
3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("ff resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + collocations['interior'] = interior_ff(np.ones((100, 100))) + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + opt.zero_grad() + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff_re.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff_re.pth')) + write_res(mse_list) + return mse_list + + +def ff_rar(mem=0.9): + """ + RANG-m + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'ff_rar_{mem:.2f}_mse'] + + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info(f"ff_rar_{mem:.2f}") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters()) + collocations['interior'] = interior_ff(np.ones((100, 100))) + + for i in range(maxiter): + opt.zero_grad() + + if i > 0 and i % rar_interval == 0: + xx, yy, new_error = compute_res(u) + min_v = np.min(new_error) + max_v = np.max(new_error) + new_error = (new_error - min_v) / (max_v - min_v + 1e-8) + try: + error = np.maximum(mem * error, new_error) + except: + error = new_error + collocations['interior'] = interior_ff(error, sample_num) + + if verbose: + logger.info("length of samples is {}".format(len(collocations['interior'][0]))) + + if sample_idx % freq_draw == 0: + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u) + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}, point num is {len(collocations["interior"][0])}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'ff_rar_{mem:.2f}.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'ff_rar_{mem:.2f}.pth')) + write_res(mse_list) + + return mse_list + + +def hammersely_sample(): + """ + Hammersley + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'hammersely_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("hammersely sampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + 
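+                # Hammersley points are deterministic and sampled only once
+                # (cached in `collocations` by l_interior), so this scatter
+                # panel is identical in every row; only the prediction and
+                # residual panels evolve.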
visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + opt.zero_grad() + + l_interior_val = l_interior(u, method='hammersely') + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'hammersely_evo_ac.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'hammersely_evo_ac.pth')) + write_res(mse_list) + return mse_list + + +def lhs_sample(): + """ + LHS + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'lhs_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("lhs sampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + l_interior_val = l_interior(u, method='lhs') + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + opt.zero_grad() + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'lhs.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'lhs.pth')) + write_res(mse_list) + return mse_list + + +def lhs_resample(): + """ + LHS + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'lhs_resample_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + fig2.set_tight_layout(True) + sample_idx = 0 + logger.info("lhs resampling") + collocations = dict() + u = MLP(seq=net_seq) + opt = torch.optim.Adam(params=u.parameters(), lr=0.001) + for i in range(maxiter): + if i > 0 and i % rar_interval == 0: + l_interior_val = l_interior(u, method='lhs', resample=True) + if sample_idx % freq_draw == 0: + xx, yy, error = compute_res(u) + visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose) + visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations) + ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error) + eval_u(u, mse_list=mse_list) + sample_idx += 1 + else: + l_interior_val = l_interior(u, method='lhs') + opt.zero_grad() + l_boundary_val = l_boundary(u) + l_initial_val = l_initial(u) + l = compose_loss(l_interior_val, l_boundary_val, l_initial_val) + l.backward() + opt.step() + if i % log_interval == 0: + logger.info(f'iteration {i}: loss is {float(l)}') + eval_u(u, mse_list=mse_list) + fig2.savefig(path.joinpath(f'lhs_re.png')) + plt.close(fig2) + torch.save(u.state_dict(), path.joinpath(f'lhs_re.pth')) + write_res(mse_list) + return mse_list + + +def random(): + """ + Random + """ + global exp_id + global collocations + mse_list = [str(exp_id), f'random_mse'] + fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4)) + 
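+    # Note: despite the "random resampling" log message below, this baseline
+    # draws its interior points once and keeps them fixed for the whole run;
+    # only random_resample() redraws them every rar_interval iterations.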
+
+
+def random_resample():
+    """
+    Random-R
+    """
+    global exp_id
+    global collocations
+    mse_list = [str(exp_id), 'random_resample_mse']
+    fig2, ax2 = plt.subplots(resample_num // freq_draw, 3, figsize=(12, resample_num // freq_draw * 4))
+    fig2.set_tight_layout(True)
+    sample_idx = 0
+    logger.info("random resampling")
+    collocations = dict()
+    u = MLP(seq=net_seq)
+    opt = torch.optim.Adam(params=u.parameters(), lr=0.001)
+    for i in range(maxiter):
+        if i > 0 and i % rar_interval == 0:
+            l_interior_val = l_interior(u, method='random', resample=True)
+            if sample_idx % freq_draw == 0:
+                xx, yy, error = compute_res(u)
+                visualize(ax2[sample_idx // freq_draw, 0], u, verbose=verbose)
+                visualize_scatter(ax2[sample_idx // freq_draw, 1], collocations)
+                ax2[sample_idx // freq_draw, 2].pcolormesh(xx, yy, error)
+                eval_u(u, mse_list=mse_list)
+            sample_idx += 1
+        else:
+            l_interior_val = l_interior(u, method='random')
+        l_boundary_val = l_boundary(u)
+        l_initial_val = l_initial(u)
+        l = compose_loss(l_interior_val, l_boundary_val, l_initial_val)
+        opt.zero_grad()
+        l.backward()
+        opt.step()
+        if i % log_interval == 0:
+            logger.info(f'iteration {i}: loss is {float(l)}')
+    eval_u(u, mse_list=mse_list)
+    fig2.savefig(path.joinpath('random_re.png'))
+    plt.close(fig2)
+    torch.save(u.state_dict(), path.joinpath('random_re.pth'))
+    write_res(mse_list)
+    return mse_list
+
+
+if __name__ == '__main__':
+    exp_id = int(args.start_epoch)
+    for i in range(int(args.repeat)):
+        ff()  # FF
+        exp_id += 1
+
+        ff_resample()  # FF-R
+        exp_id += 1
+
+        ff_rar(0.9)  # RANG-m
+        exp_id += 1
+
+        ff_rar(0.0)  # RANG
+        exp_id += 1
+
+        hammersely_sample()  # Hammersley
+        exp_id += 1
+
+        lhs_sample()  # LHS
+        exp_id += 1
+
+        lhs_resample()  # LHS-R
+        exp_id += 1
+
+        random()  # Random
+        exp_id += 1
+
+        random_resample()  # Random-R
+        exp_id += 1
\ No newline at end of file
diff --git a/F_adv_diffuse/parser_pinn.py b/F_adv_diffuse/parser_pinn.py
new file mode 100644
index 0000000..d5777e3
--- /dev/null
+++ b/F_adv_diffuse/parser_pinn.py
@@ -0,0 +1,52 @@
+import argparse
+from datetime import datetime
+
+
+def get_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--maxiter', default=10000, type=int
+    )
+    parser.add_argument(
+        '--resample_N', default=100, type=int
+    )
+    parser.add_argument(
+        '--sample_num', default=1000, type=int
+    )
+    parser.add_argument(
+        '--net_seq', default=[2, 64, 64, 64, 64, 1]
+    )
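+    # Example invocation (hypothetical values; flags map 1:1 to the options here):
+    #   python ac.py --maxiter 20000 --sample_num 2000 --repeat 10
+    # NOTE: --net_seq keeps its list default; accepting it from the command line
+    # would require nargs='+' with type=int.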
+    parser.add_argument(
+        '--save_path',
+        default=f'./data/{datetime.now().strftime("%Y_%m_%d_%H_%M_%S")}'
+    )
+    parser.add_argument(
+        '--resample_interval', default=1000, type=int
+    )
+    parser.add_argument(
+        '--freq_draw', default=1, type=int
+    )
+    parser.add_argument(
+        '--verbose', default=False, type=int  # pass 0 or 1; the scripts coerce it via bool(...)
+    )
+    parser.add_argument(
+        '--sigma', default=0.1, type=float
+    )
+    parser.add_argument(
+        '--repeat', default=60, type=int
+    )
+    parser.add_argument(
+        '--start_epoch', default=0, type=int
+    )
+    parser.add_argument(
+        '--t_start', default=0.1, type=float
+    )
+    parser.add_argument(
+        '--alpha', default=0.05, type=float
+    )
+    parser.add_argument(
+        '--Tmax', default=1.0, type=float
+    )
+    parser.add_argument(
+        '--half_L', default=4.0, type=float
+    )
+    return parser
diff --git a/ff.py b/ff.py
new file mode 100644
index 0000000..c7e5bf7
--- /dev/null
+++ b/ff.py
@@ -0,0 +1,143 @@
+import numpy as np
+import math
+from pyDOE import lhs as _lhs
+
+
+def scatter_halftone(box, ninit, dotmax, radius):
+    # Advancing-front halftone scatter: sweeps a front of candidate points upward
+    # through the box, places a node at the lowest front point, and fans out new
+    # candidates at distance radius(xy) around it.
+    lb = box[0]
+    rb = box[1]
+    db = box[2]
+    ub = box[3]
+    dotnr = -1
+    N_PDP_MAX = 100000
+    pdp_x = np.zeros(N_PDP_MAX)
+    pdp_y = np.zeros(N_PDP_MAX)
+
+    pdp_x[:ninit] = np.linspace(lb, rb, ninit)
+    pdp_y[:ninit] = np.random.rand(ninit) * 1e-4 + db
+
+    pdp_num = ninit
+    xy = np.zeros((dotmax, 2))
+    i = np.argmin(pdp_y[:ninit])
+    ym = pdp_y[i]
+    fan = np.linspace(0.1, 0.9, 5)
+    while ym <= ub and dotnr < dotmax - 1:  # -1 so xy[dotnr] stays in bounds
+        dotnr += 1
+        xy[dotnr, 0] = pdp_x[i]
+        xy[dotnr, 1] = pdp_y[i]
+        r = radius(xy[dotnr, :])
+        dist2 = (pdp_x[:pdp_num] - pdp_x[i]) ** 2 + (pdp_y[:pdp_num] - pdp_y[i]) ** 2
+
+        ileft = np.where(dist2[:i] > r ** 2)
+        if len(ileft[0]) == 0:
+            ileft = -1
+            ang_left = np.pi
+        else:
+            ileft = max(ileft[0])
+            ang_left = np.arctan2(pdp_y[ileft] - pdp_y[i], pdp_x[ileft] - pdp_x[i])
+
+        iright = np.where(dist2[i:pdp_num] > r ** 2)
+        if len(iright[0]) == 0:
+            iright = -1
+            ang_right = 0
+        else:
+            iright = min(iright[0])
+            ang_right = np.arctan2(pdp_y[i + iright] - pdp_y[i], pdp_x[i + iright] - pdp_x[i])
+        ang = ang_left - fan * (ang_left - ang_right)
+        pdp_new_x = pdp_x[i] + r * np.cos(ang)
+        pdp_new_y = pdp_y[i] + r * np.sin(ang)
+        ind = np.logical_and(pdp_new_x <= rb, pdp_new_x >= lb)
+        pdp_new_x = pdp_new_x[ind]
+        pdp_new_y = pdp_new_y[ind]
+        new_add = len(pdp_new_x)
+        if iright == -1 and ileft == -1:
+            removed = pdp_num
+        elif iright == -1:
+            removed = pdp_num - ileft - 1
+        elif ileft == -1:
+            removed = iright - 1 + i - ileft
+        else:
+            removed = i - ileft + iright - 1
+        if iright != -1:
+            pdp_x[iright + i + new_add - removed:pdp_num + new_add - removed] = pdp_x[iright + i:pdp_num]
+            pdp_y[iright + i + new_add - removed:pdp_num + new_add - removed] = pdp_y[iright + i:pdp_num]
+
+        pdp_x[ileft + 1:ileft + 1 + new_add] = pdp_new_x
+        pdp_y[ileft + 1:ileft + 1 + new_add] = pdp_new_y
+
+        pdp_num = pdp_num + new_add - removed
+        i = np.argmin(pdp_y[:pdp_num])
+        ym = pdp_y[i]
+
+    xy = xy[:dotnr, :]
+    return xy
+
+
+def error_ff(target_num, error, max_min_density_ratio=20, box=None, sdf=None):
+    # Adaptive node generation: node density follows the normalized error map;
+    # larger error -> smaller dot radius -> denser nodes, with the densest-to-
+    # sparsest ratio capped by max_min_density_ratio.
+    _N = len(error)
+    if box is None:
+        box = [0, 1, 0, 1]
+    if sdf is None:
+        sdf = lambda x: np.ones((len(x), 1))
+    error = -error
+    error_min = np.min(error)
+    error_max = np.max(error)
+
+    error = ((error - error_min) / ((error_max - error_min) + 1e-10))
+
+    min_scale = 0.02
+    max_scale = 1.
+    scale = (min_scale + max_scale) / 2
+
+    def r(xy):
+        # error map is indexed as [row=y, col=x]; r closes over scale, which the
+        # bisection below adjusts between calls
+        ixy = np.asarray(np.round(xy * (_N - 1)), dtype=int)
+        return (error[ixy[1], ixy[0]] * (1 - 1 / math.sqrt(max_min_density_ratio)) + 1 / math.sqrt(
+            max_min_density_ratio)) * scale
+
+    xy = scatter_halftone(box, 100, 10000, r)
+    xy = xy[sdf(xy).ravel() > 0, :]
+    len_xy = len(xy)
+    # bisect the global radius scale until the node count is within 5% of target_num
+    while np.abs(len_xy - target_num) / target_num > 0.05 and max_scale - min_scale > 0.003:
+        if target_num > len_xy:
+            max_scale = scale
+        else:
+            min_scale = scale
+        scale = (max_scale + min_scale) / 2
+        xy = scatter_halftone(box, 100, 10000, r)
+        xy = xy[sdf(xy).ravel() > 0, :]
+        len_xy = len(xy)
+    return xy
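+
+# Usage sketch (illustrative values): draw roughly 1000 nodes on the unit square
+# whose density tracks a residual map `res` of shape (N, N):
+#   res = np.abs(np.random.randn(100, 100))   # stand-in for a PINN residual
+#   pts = error_ff(1000, error=res, max_min_density_ratio=10, box=[0, 1, 0, 1])
+#   # pts is an (~1000, 2) array, denser where res is large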
+
+
+def halton(b):
+    """Generator for the Halton (van der Corput) sequence in base b."""
+    n, d = 0, 1
+    while True:
+        x = d - n
+        if x == 1:
+            n = 1
+            d *= b
+        else:
+            y = d // b
+            while x <= y:
+                y //= b
+            n = (b + 1) * y - x
+        yield n / d
+
+
+def hammersely(Nsize, p=2):
+    # 2-D Hammersley set: first coordinate is i / Nsize, second is the base-p
+    # Halton sequence
+    y = []
+    for i, num in enumerate(halton(p)):
+        if i >= Nsize:
+            break
+        y.append(num)
+    x = np.arange(0, Nsize) / Nsize
+    return np.array([x, y]).T
+
+
+def lhs(Nsize):
+    xy = _lhs(2, Nsize)
+    return xy
\ No newline at end of file
diff --git a/readme.md b/readme.md
new file mode 100644
index 0000000..1511929
--- /dev/null
+++ b/readme.md
@@ -0,0 +1,11 @@
+# Residual-based Adaptive Node Generation (RANG) for PINN
+Requirements: `pytorch`, `pyDOE`, `scipy`
+
+Usage:
+```bash
+git clone xxxxxx
+cd rang_pinn
+export PYTHONPATH=$PWD:$PYTHONPATH
+cd A_allencahn
+python ac.py --repeat=30
+```
\ No newline at end of file
diff --git a/tools.py b/tools.py
new file mode 100644
index 0000000..f9baa4f
--- /dev/null
+++ b/tools.py
@@ -0,0 +1,36 @@
+import torch
+import logging
+
+
+def gradients(u, x, order=1):
+    if order == 1:
+        return torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u),
+                                   create_graph=True,
+                                   only_inputs=True, )[0]
+    else:
+        return gradients(gradients(u, x), x, order=order - 1)
+
+
+# network definition
+class MLP(torch.nn.Module):
+    def __init__(self, seq=None):
+        super(MLP, self).__init__()
+        if seq is None:
+            seq = [1, 50, 50, 50, 50, 1]
+        seq = [(seq[i], seq[i + 1]) for i in range(len(seq) - 1)]
+        mod_seq = []
+        for s in seq[:-1]:
+            mod_seq.append(torch.nn.Linear(s[0], s[1]))
+            mod_seq.append(torch.nn.Tanh())
+        s = seq[-1]
+        mod_seq.append(torch.nn.Linear(s[0], s[1]))
+        self.net = torch.nn.Sequential(*mod_seq)
+
+    def forward(self, x):
+        return self.net(x)
+
+
+log_format = '[%(asctime)s] [%(levelname)s] %(message)s'
+handlers = [logging.FileHandler('train.log', mode='a'), logging.StreamHandler()]
+logging.basicConfig(format=log_format, level=logging.INFO, datefmt='%d-%b-%y %H:%M:%S', handlers=handlers)
+logger = logging.getLogger(__name__)
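+
+
+# Minimal self-check (a hedged sketch; run `python tools.py`): verifies that
+# gradients() reproduces the analytic derivatives of u = x ** 2, i.e.
+# du/dx = 2x and d2u/dx2 = 2.
+if __name__ == '__main__':
+    x = torch.linspace(-1., 1., 5).reshape(-1, 1).requires_grad_(True)
+    u = x ** 2
+    assert torch.allclose(gradients(u, x, order=1), 2 * x)
+    assert torch.allclose(gradients(u, x, order=2), torch.full_like(x, 2.))
+    logger.info('gradients() self-check passed')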