Skip to content

Commit 9a4d3db

Browse files
committed
some layers can now export to normal nn.Module for serialization
1 parent 5ad1c54 commit 9a4d3db

2 files changed

Lines changed: 38 additions & 0 deletions

File tree

diffabs/deeppoly.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,19 @@ class Linear(nn.Linear):
445445
def __str__(self):
    """ Show the concrete layer's repr prefixed with the abstract domain name
        (Dom is presumably the module-level domain descriptor — its .name is used).
    """
    return '{}.{}'.format(Dom.name, super().__str__())
447447

448+
@classmethod
def from_module(cls, src: nn.Linear) -> 'Linear':
    """ Alternate constructor: build an abstract Linear layer from a concrete
        nn.Linear, copying its weights (and bias, when present).

    :param src: the concrete layer whose parameters are copied.
    :return: a new abstract layer with identical parameters.
    """
    with_bias = src.bias is not None
    # Construct via cls rather than the hard-coded class name, so that
    # subclasses calling from_module() receive an instance of themselves.
    new_lin = cls(src.in_features, src.out_features, with_bias)
    new_lin.load_state_dict(src.state_dict())
    return new_lin
454+
455+
def export(self) -> nn.Linear:
    """ Serialize this abstract layer back into a plain nn.Linear with the
        same shape and parameters, suitable for standard torch serialization.
    """
    concrete = nn.Linear(self.in_features, self.out_features,
                         bias=self.bias is not None)
    concrete.load_state_dict(self.state_dict())
    return concrete
460+
448461
def forward(self, *ts: Union[Tensor, Ele]) -> Union[Tensor, Ele, Tuple[Tensor, ...]]:
449462
"""
450463
:param ts: either Tensor, Ele, or Ele tensors
@@ -814,6 +827,9 @@ class ReLU(nn.ReLU):
814827
def __str__(self):
    """ Show the concrete layer's repr prefixed with the abstract domain name
        (Dom is presumably the module-level domain descriptor — its .name is used).
    """
    return '{}.{}'.format(Dom.name, super().__str__())
816829

830+
def export(self) -> nn.ReLU:
    """ Serialize this abstract layer back into a plain nn.ReLU.

    Preserves the inplace flag (set by nn.ReLU.__init__, which this class
    inherits) instead of always returning a default-configured instance.
    """
    return nn.ReLU(inplace=self.inplace)
832+
817833
def forward(self, *ts: Union[Tensor, Ele]) -> Union[Tensor, Ele, Tuple[Tensor, ...]]:
818834
""" According to paper, it approximates E by either of the two cases, whichever has smaller areas.
819835
Mathematically, it can be proved that the (linear) approximation is optimal in terms of approximated areas.
@@ -933,6 +949,9 @@ class Tanh(nn.Tanh):
933949
def __str__(self):
    """ Show the concrete layer's repr prefixed with the abstract domain name
        (Dom is presumably the module-level domain descriptor — its .name is used).
    """
    return '{}.{}'.format(Dom.name, super().__str__())
935951

952+
def export(self) -> nn.Tanh:
    """ Serialize this abstract layer back into a plain nn.Tanh.

    Tanh carries no parameters, so a freshly constructed instance is
    behaviorally identical.
    """
    exported = nn.Tanh()
    return exported
954+
936955
def forward(self, *ts: Union[Tensor, Ele]) -> Union[Tensor, Ele, Tuple[Tensor, ...]]:
937956
""" For both LB' and UB', it chooses the smaller slope between LB-UB and LB'/UB'. Specifically,
938957
when L > 0, LB' chooses LB-UB, otherwise LB';

diffabs/interval.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,19 @@ class Linear(nn.Linear):
271271
def __str__(self):
    """ Show the concrete layer's repr prefixed with the abstract domain name
        (Dom is presumably the module-level domain descriptor — its .name is used).
    """
    return '{}.{}'.format(Dom.name, super().__str__())
273273

274+
@classmethod
def from_module(cls, src: nn.Linear) -> 'Linear':
    """ Alternate constructor: build an abstract Linear layer from a concrete
        nn.Linear, copying its weights (and bias, when present).

    :param src: the concrete layer whose parameters are copied.
    :return: a new abstract layer with identical parameters.
    """
    with_bias = src.bias is not None
    # Construct via cls rather than the hard-coded class name, so that
    # subclasses calling from_module() receive an instance of themselves.
    new_lin = cls(src.in_features, src.out_features, with_bias)
    new_lin.load_state_dict(src.state_dict())
    return new_lin
280+
281+
def export(self) -> nn.Linear:
    """ Serialize this abstract layer back into a plain nn.Linear with the
        same shape and parameters, suitable for standard torch serialization.
    """
    concrete = nn.Linear(self.in_features, self.out_features,
                         bias=self.bias is not None)
    concrete.load_state_dict(self.state_dict())
    return concrete
286+
274287
def forward(self, *ts: Union[Tensor, Ele]) -> Union[Tensor, Ele, Tuple[Tensor, ...]]:
275288
""" Re-implement the forward computation by myself, because F.linear() may apply optimization using
276289
torch.addmm() which requires inputs to be tensor.
@@ -416,6 +429,9 @@ class ReLU(nn.ReLU):
416429
def __str__(self):
    """ Show the concrete layer's repr prefixed with the abstract domain name
        (Dom is presumably the module-level domain descriptor — its .name is used).
    """
    return '{}.{}'.format(Dom.name, super().__str__())
418431

432+
def export(self) -> nn.ReLU:
    """ Serialize this abstract layer back into a plain nn.ReLU.

    Preserves the inplace flag (set by nn.ReLU.__init__, which this class
    inherits) instead of always returning a default-configured instance.
    """
    return nn.ReLU(inplace=self.inplace)
434+
419435
def forward(self, *ts: Union[Tensor, Ele]) -> Union[Tensor, Ele, Tuple[Tensor, ...]]:
    """ Apply the concrete nn.ReLU.forward to the inputs, distributed over
        plain tensors and abstract elements by _distribute_to_super.
    """
    base_forward = super().forward
    return _distribute_to_super(base_forward, *ts)
421437
pass
@@ -425,6 +441,9 @@ class Tanh(nn.Tanh):
425441
def __str__(self):
    """ Show the concrete layer's repr prefixed with the abstract domain name
        (Dom is presumably the module-level domain descriptor — its .name is used).
    """
    return '{}.{}'.format(Dom.name, super().__str__())
427443

444+
def export(self) -> nn.Tanh:
    """ Serialize this abstract layer back into a plain nn.Tanh.

    Tanh carries no parameters, so a freshly constructed instance is
    behaviorally identical.
    """
    exported = nn.Tanh()
    return exported
446+
428447
def forward(self, *ts: Union[Tensor, Ele]) -> Union[Tensor, Ele, Tuple[Tensor, ...]]:
    """ Apply the concrete nn.Tanh.forward to the inputs, distributed over
        plain tensors and abstract elements by _distribute_to_super.
    """
    base_forward = super().forward
    return _distribute_to_super(base_forward, *ts)
430449
pass

0 commit comments

Comments
 (0)