Changes from all commits (25 commits):
175dd6a - Bug fix (andrewrightjames, Jun 26, 2025)
4f7e2bb - Merge branch 'main' of https://github.com/BigBalloon8/ml-accelerated-… (andrewrightjames, Jun 26, 2025)
5102249 - Merge branch 'main' of https://github.com/BigBalloon8/ml-accelerated-… (andrewrightjames, Jun 28, 2025)
fdd2134 - KANs files + configs, as well as full model transformer config (andrewrightjames, Jun 28, 2025)
aeab0ba - bug fix (andrewrightjames, Jun 28, 2025)
5833d3a - Merge branch 'main' of https://github.com/BigBalloon8/ml-accelerated-… (andrewrightjames, Jun 28, 2025)
4bed00f - Added the Convolution KAN and normal KAN layers and updated configs (andrewrightjames, Jun 28, 2025)
0035c07 - Added FastKANs config + bug fixes (andrewrightjames, Jun 28, 2025)
8756c73 - Removed layer norm for KANs and transformer + changed tools to includ… (andrewrightjames, Jun 29, 2025)
160afb3 - Added BIG transformer (andrewrightjames, Jun 29, 2025)
cdb4b8d - Bug fix (andrewrightjames, Jun 29, 2025)
d8ba827 - Bug fixes (andrewrightjames, Jun 29, 2025)
0b6e247 - bug fixes + added layer norm to transformers (andrewrightjames, Jun 30, 2025)
308b266 - config changes (andrewrightjames, Jun 30, 2025)
0617127 - Merge branch 'main' of https://github.com/BigBalloon8/ml-accelerated-… (andrewrightjames, Jun 30, 2025)
acc16e0 - Added KANs and transformer to getModel (andrewrightjames, Jun 30, 2025)
fe0c312 - Added weight decay and amsgrad to Adam (andrewrightjames, Jun 30, 2025)
aadc451 - bug fix (andrewrightjames, Jun 30, 2025)
e958faf - Merge branch 'main' of https://github.com/BigBalloon8/ml-accelerated-… (andrewrightjames, Jul 21, 2025)
5a56114 - Merge branch 'main' of https://github.com/BigBalloon8/ml-accelerated-… (andrewrightjames, Jul 26, 2025)
b2d9916 - Added dual train script (andrewrightjames, Jul 26, 2025)
28a6a65 - Bug fix (andrewrightjames, Jul 26, 2025)
299bc93 - bug fix (andrewrightjames, Jul 29, 2025)
6bfc236 - Merge branch 'main' of https://github.com/BigBalloon8/ml-accelerated-… (andrewrightjames, Jul 29, 2025)
4a81114 - fix train_v2 (andrewrightjames, Jul 29, 2025)
112 changes: 112 additions & 0 deletions src/inference.py
@@ -0,0 +1,112 @@
import torch
import torch.nn as nn

from torch_cfd import grids, boundaries
from torch_cfd.initial_conditions import filtered_velocity_field

from torch_cfd.equations import stable_time_step
from torch_cfd.fvm import RKStepper, NavierStokes2DFVMProjection
from torch_cfd.forcings import KolmogorovForcing
import torch_cfd.finite_differences as fdm
import torch_cfd.tensor_utils as tensor_utils
import torch.utils._pytree as pytree

from tqdm import tqdm
import safetensors.torch as st

import argparse
import json
import os
from typing import Tuple

from models import MLP, CNN, Transformer
from models.FastKAN import FastKAN  # keep the import root consistent with the line above

def hash_dict(x:dict):
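    # note: Python's builtin hash() of a str is salted per process (PYTHONHASHSEED),
    # so this filename suffix is only reproducible across runs if the hash seed is fixed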
return str(hash(json.dumps(x, sort_keys=True)))

def get_model(name:str, config_file, checkpoint_path)-> Tuple[nn.Module, dict]:
with open(config_file, "r") as f:
config = json.load(f)

if name.upper() == "MLP":
model_base = MLP(config)
elif name.upper() == "CNN":
model_base = CNN(config)
elif name.upper() == "KAN":
model_base = FastKAN(config)
elif name.upper() == "TRANSFORMER":
model_base = Transformer(config)
else:
raise ValueError(f"Model type [{name}] not supported please select from |MLP|CNN|KAN|TRANSFORMER|")

if f"{name}_{hash_dict(config)}.safetensors" in os.listdir(checkpoint_path):
model_path = os.path.join(checkpoint_path, f"{name}_{hash_dict(config)}.safetensors")
model_weights = st.load_file(model_path)
        metadata = model_weights.pop("__metadata__", None)  # may be missing if metadata was saved only in the safetensors header
model_base.load_state_dict(model_weights)
else:
raise FileNotFoundError("Model weights for the given config are not in the checkpoint path")
return model_base, metadata


def main(model_type, model_config, checkpoint_path):
#--------------Simulation Setup-----------------
density = 1.0
max_velocity = 7.0
peak_wavenumber = 4.0
cfl_safety_factor = 0.5
viscosity = 1e-3
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.set_default_dtype(torch.float64)
diam = 2 * torch.pi
simulation_time = 30

step_fn = RKStepper.from_method(method="classic_rk4", requires_grad=False, dtype=torch.float64)

coarse_grid = grids.Grid((64, 64), domain=((0, diam), (0, diam)), device=device)

dt = stable_time_step(
dx=min(coarse_grid.step),
max_velocity=max_velocity,
max_courant_number=cfl_safety_factor,
viscosity=viscosity,
)


v0 = filtered_velocity_field(
coarse_grid, max_velocity, peak_wavenumber, iterations=16, random_state=42,
device=device, batch_size=1,)
pressure_bc = boundaries.get_pressure_bc_from_velocity(v0)

forcing_fn = KolmogorovForcing(diam=diam, wave_number=int(peak_wavenumber),
grid=coarse_grid, offsets=(v0[0].offset, v0[1].offset))

ns2d = NavierStokes2DFVMProjection(
viscosity=viscosity,
grid=coarse_grid,
bcs=(v0[0].bc, v0[1].bc),
density=density,
drag=0.1,
forcing=forcing_fn,
solver=step_fn,
# set_laplacian=False,
).to(v0.device)

#-----------ML setup------------------
model, _ = get_model(model_type, model_config, checkpoint_path)
model.to(device)


    v = v0  # the loop needs an initial state; start from the filtered initial velocity field
    for t in tqdm(range(round(simulation_time/dt))):
v, p = step_fn.forward(v, dt, equation=ns2d)
v += model(v)


if __name__ == "__main__":
ap = argparse.ArgumentParser()
ap.add_argument("--model_type", default="CNN", help="Model to train: [MLP, CNN, KAN, Transformer]")
ap.add_argument("--model_config", default="./model.config", help="path to model config")
ap.add_argument("--checkpoint_path", default=".", help="path to model config")
with torch.inference_mode():
main(**ap.parse_args().__dict__)
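Usage sketch (not part of the diff): a minimal way this script and get_model could be exercised, assuming src/ is on PYTHONPATH; the config path, checkpoint directory, and input shape below are placeholders.

# hypothetical invocation; paths are placeholders, not files from this PR:
#   python src/inference.py --model_type CNN --model_config ./configs/cnn.json --checkpoint_path ./checkpoints
import torch
from inference import get_model  # assumes src/ is on PYTHONPATH

model, metadata = get_model("CNN", "./configs/cnn.json", "./checkpoints")
model.eval()
with torch.inference_mode():
    correction = model(torch.zeros(1, 2, 64, 64))  # input shape is an assumption, not taken from the configs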
114 changes: 114 additions & 0 deletions src/models/FastKAN.py
@@ -0,0 +1,114 @@
# taken from and based on https://github.com/IvanDrokin/torch-conv-kan/blob/main/kans/kan.py
# and https://github.com/1ssb/torchkan/blob/main/torchkan.py
# and https://github.com/1ssb/torchkan/blob/main/KALnet.py
# and https://github.com/ZiyaoLi/fast-kan/blob/master/fastkan/fastkan.py
# Copyright 2024 Li, Ziyao
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.nn as nn
import torch
from .tools import structureLoader, getAct

class SplineLinear(nn.Linear):
def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
self.init_scale = init_scale
super().__init__(in_features, out_features, bias=False, **kw)

def reset_parameters(self) -> None:
nn.init.trunc_normal_(self.weight, mean=0, std=self.init_scale)

class RadialBasisFunction(nn.Module):
def __init__(
self,
grid_min: float = -2.,
grid_max: float = 2.,
num_grids: int = 8,
denominator: float = None, # larger denominators lead to smoother basis
):
super().__init__()
grid = torch.linspace(grid_min, grid_max, num_grids)
self.grid = torch.nn.Parameter(grid, requires_grad=True)
self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)

def forward(self, x):
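        # Gaussian bump around each grid point g_i with shared width h = self.denominator:
        # exp(-((x - g_i) / h) ** 2); broadcasting adds a trailing num_grids axis to x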
return torch.exp(-((x[..., None] - self.grid) / self.denominator) ** 2)


class FastKANLayer(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
grid_min: float = -2.,
grid_max: float = 2.,
num_grids: int = 8,
use_base_update: bool = True,
base_activation=nn.SiLU,
spline_weight_init_scale: float = 0.1,
) -> None:
super().__init__()
#self.layernorm = nn.LayerNorm(input_dim)
self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids)
self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)
self.use_base_update = use_base_update
if use_base_update:
self.base_activation = base_activation
self.base_linear = nn.Linear(input_dim, output_dim)

    def forward(self, x, time_benchmark=False):
        # spline_basis = self.rbf(self.layernorm(x))  # EDITED: layer norm removed, so the benchmark and standard paths are now identical
        spline_basis = self.rbf(x)
ret = self.spline_linear(spline_basis.view(*spline_basis.shape[:-2], -1))
if self.use_base_update:
base = self.base_linear(self.base_activation(x))
ret = ret + base
return ret

class FastKAN(nn.Module):
def __init__(self, config):
super(FastKAN, self).__init__()

self.structure = structureLoader(config["structures"])
self.dim = config.get("dimension", [1, 1])
self.base_activation = getAct(config.get("base_activation", "silu"))

self.grid_min, self.grid_max = config["grid_range"]
self.spline_weight_init_scale = config["spline_weight_init_scale"]
self.use_base_update = config["use_base_update"]
self.num_grids = config["num_grids"]

self.dropout = nn.Dropout(config.get("dropout", 0.1))

self.layers = nn.ModuleList([FastKANLayer(input_dim = self.structure[i]*self.dim[0]*self.dim[1],
output_dim = self.structure[i+1]*self.dim[0]*self.dim[1],
grid_min = self.grid_min,
grid_max = self.grid_max,
num_grids = self.num_grids,
use_base_update = self.use_base_update,
base_activation= self.base_activation,
spline_weight_init_scale = self.spline_weight_init_scale) for i in range(len(self.structure)-1)])

def forward(self, x):
input_shape = x.shape
x = x.flatten(1, -1)
for i, layer in enumerate(self.layers):
x = layer(x)
if i+1 != len(self.layers):
x = self.dropout(x)
x = x.reshape(input_shape)
return x
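Construction sketch (not part of the diff): how FastKAN might be instantiated from a config, assuming structureLoader accepts a plain list of per-layer widths and getAct("silu") returns an nn.SiLU() instance.

import torch
from models.FastKAN import FastKAN  # import path is an assumption

config = {
    "structures": [2, 4, 2],          # assumed format: widths per layer
    "dimension": [8, 8],              # spatial size; each layer works on width*H*W features
    "base_activation": "silu",
    "grid_range": [-2.0, 2.0],
    "spline_weight_init_scale": 0.1,
    "use_base_update": True,
    "num_grids": 8,
    "dropout": 0.1,
}

model = FastKAN(config)
x = torch.randn(4, 2, 8, 8)           # (batch, C, H, W) -> flattened internally to (batch, C*H*W)
y = model(x)                          # reshaped back to the input shape at the end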

150 changes: 150 additions & 0 deletions src/models/FastKANConv.py
@@ -0,0 +1,150 @@
# taken from and based on https://github.com/IvanDrokin/torch-conv-kan/blob/main/kan_convs/fast_kan_conv.py
import torch.nn as nn
import torch

from .tools import structureLoader, getAct

class RadialBasisFunction(nn.Module):
def __init__(
self,
grid_min: float = -2.,
grid_max: float = 2.,
num_grids: int = 8,
denominator: float = None, # larger denominators lead to smoother basis
):
super().__init__()
grid = torch.linspace(grid_min, grid_max, num_grids)
self.grid = torch.nn.Parameter(grid, requires_grad=False)
self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)

def forward(self, x):
return torch.exp(-((x[..., None] - self.grid) / self.denominator) ** 2)


class FastKANConvNDLayer(nn.Module):
def __init__(self, conv_class, # norm_class, # EDITED
input_dim, output_dim, kernel_size,
groups=1, padding=0, stride=1, dilation=1,
ndim: int = 2, grid_size=8, base_activation=nn.SiLU, grid_range=[-2, 2], dropout=0.0): # EDITED: removed **norm_kwargs
super(FastKANConvNDLayer, self).__init__()
self.inputdim = input_dim
self.outdim = output_dim
self.kernel_size = kernel_size
self.padding = padding
self.stride = stride
self.dilation = dilation
self.groups = groups
self.ndim = ndim
self.grid_size = grid_size
self.base_activation = base_activation
self.grid_range = grid_range
# self.norm_kwargs = norm_kwargs # EDITED

if groups <= 0:
raise ValueError('groups must be a positive integer')
if input_dim % groups != 0:
raise ValueError('input_dim must be divisible by groups')
if output_dim % groups != 0:
raise ValueError('output_dim must be divisible by groups')

self.base_conv = nn.ModuleList([conv_class(input_dim // groups,
output_dim // groups,
kernel_size,
stride,
padding,
dilation,
groups=1,
bias=False) for _ in range(groups)])

self.spline_conv = nn.ModuleList([conv_class(grid_size * input_dim // groups,
output_dim // groups,
kernel_size,
stride,
padding,
dilation,
groups=1,
bias=False) for _ in range(groups)])

# self.layer_norm = nn.ModuleList([norm_class(input_dim // groups, **norm_kwargs) for _ in range(groups)]) # EDITED

self.rbf = RadialBasisFunction(grid_range[0], grid_range[1], grid_size)

self.dropout = None
if dropout > 0:
if ndim == 1:
self.dropout = nn.Dropout1d(p=dropout)
if ndim == 2:
self.dropout = nn.Dropout2d(p=dropout)
if ndim == 3:
self.dropout = nn.Dropout3d(p=dropout)

# Initialize weights using Kaiming uniform distribution for better training start
for conv_layer in self.base_conv:
nn.init.kaiming_uniform_(conv_layer.weight, nonlinearity='linear')

for conv_layer in self.spline_conv:
nn.init.kaiming_uniform_(conv_layer.weight, nonlinearity='linear')

def forward_fast_kan(self, x, group_index):

# Apply base activation to input and then linear transform with base weights
base_output = self.base_conv[group_index](self.base_activation(x))
if self.dropout is not None:
x = self.dropout(x)
# spline_basis = self.rbf(self.layer_norm[group_index](x)) # EDITED
spline_basis = self.rbf(x) # EDITED
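        # fold the trailing num_grids axis into channels for the spline conv:
        # (B, C, *spatial, G) -> (B, C, G, *spatial) -> (B, C*G, *spatial)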
spline_basis = spline_basis.moveaxis(-1, 2).flatten(1, 2)
spline_output = self.spline_conv[group_index](spline_basis)
x = base_output + spline_output

return x

def forward(self, x):
split_x = torch.split(x, self.inputdim // self.groups, dim=1)
output = []
for group_ind, _x in enumerate(split_x):
y = self.forward_fast_kan(_x, group_ind)
output.append(y.clone())
y = torch.cat(output, dim=1)
return y


class FastKANConvND(nn.Module):
def __init__(self, config):
super(FastKANConvND, self).__init__()

self.structure = structureLoader(config["structures"])
self.base_activation = getAct(config.get("base_activation", "silu"))

self.conv_class = config.get("conv_class", nn.Conv2d)
# self.norm_class = config.get("norm_class", nn.BatchNorm2d) # EDITED

self.kernel_size = config.get("kernel_size", 3)
self.groups = config.get("groups", 1)
self.padding = config.get("padding", (self.kernel_size-1)//2)
self.stride = config.get("stride", 1)
self.dilation = config.get("dilation", 1)
self.ndim = config.get("ndim", 2)
self.grid_size = config.get("grid_size", 8)
self.grid_range = config.get("grid_range", [-2,2])
self.dropout = config.get("dropout", 0.0)

self.layers = nn.ModuleList([FastKANConvNDLayer(conv_class = self.conv_class,
# norm_class = self.norm_class, # EDITED
input_dim = self.structure[i],
output_dim = self.structure[i+1],
kernel_size = self.kernel_size,
groups = self.groups,
padding = self.padding,
stride = self.stride,
dilation = self.dilation,
ndim = self.ndim,
grid_size = self.grid_size,
base_activation = self.base_activation,
grid_range = self.grid_range,
dropout = self.dropout) for i in range(len(self.structure)-1)])

def forward(self, x):
for i, layer in enumerate(self.layers):
x = layer(x)
return x
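Construction sketch (not part of the diff): building FastKANConvND from a config, again assuming structureLoader accepts a plain list, here of channel counts; conv_class falls back to nn.Conv2d when the key is absent.

import torch
from models.FastKANConv import FastKANConvND  # import path is an assumption

config = {
    "structures": [2, 16, 16, 2],   # assumed format: channel counts per conv layer
    "base_activation": "silu",
    "kernel_size": 3,
    "groups": 1,
    "grid_size": 8,
    "grid_range": [-2, 2],
    "dropout": 0.0,
}

model = FastKANConvND(config)
x = torch.randn(4, 2, 64, 64)       # (batch, channels, H, W)
y = model(x)                        # spatial size preserved by the default padding=(kernel_size-1)//2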
4 changes: 0 additions & 4 deletions src/models/KAN.py

This file was deleted.
