diff --git a/src/stage2/transport/integrators.py b/src/stage2/transport/integrators.py index e55efad..6e97564 100644 --- a/src/stage2/transport/integrators.py +++ b/src/stage2/transport/integrators.py @@ -61,7 +61,7 @@ def __forward_fn(self): try: sampler = sampler_dict[self.sampler_type] - except: + except KeyError: raise NotImplementedError("Smapler type not implemented.") return sampler diff --git a/src/stage2/transport/transport.py b/src/stage2/transport/transport.py index 11922f5..b8476d1 100644 --- a/src/stage2/transport/transport.py +++ b/src/stage2/transport/transport.py @@ -194,7 +194,7 @@ def training_losses( - x1: datapoint - model_kwargs: additional arguments for the model """ - if model_kwargs == None: + if model_kwargs is None: model_kwargs = {} t, x0, x1 = self.sample(x1) diff --git a/src/train.py b/src/train.py index 5c3132e..902bccb 100644 --- a/src/train.py +++ b/src/train.py @@ -220,11 +220,11 @@ def guidance_value(key: str, default: float) -> float: if args.compile: try: rae.encode = torch.compile(rae.encode) - except: + except Exception: print('RAE ENCODE compile meets error, falling back to no compile') try: model.forward = torch.compile(model.forward) - except: + except Exception: print('MODEL FORWARD compile meets error, falling back to no compile') else: raise NotImplementedError('ARGS>COMPILE')