Skip to content

Commit 6903ae1

Browse files
committed
fix_: add SetChainUserRpcProviders and SetChainEnabled API
1 parent 0945def commit 6903ae1

File tree

4 files changed

+65
-0
lines changed

4 files changed

+65
-0
lines changed

rpc/network/db/network_db.go

+20
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ type NetworksPersistenceInterface interface {
3030
DeleteAllNetworks() error
3131

3232
GetRpcPersistence() RpcProvidersPersistenceInterface
33+
SetEnabled(chainID uint64, enabled bool) error
3334
}
3435

3536
// NetworksPersistence manages networks and their providers.
@@ -255,3 +256,22 @@ func (n *NetworksPersistence) DeleteNetwork(chainID uint64) error {
255256

256257
return nil
257258
}
259+
260+
// SetEnabled updates the enabled status of a network.
261+
func (n *NetworksPersistence) SetEnabled(chainID uint64, enabled bool) error {
262+
q := sq.Update("networks").
263+
Set("enabled", enabled).
264+
Where(sq.Eq{"chain_id": chainID})
265+
266+
query, args, err := q.ToSql()
267+
if err != nil {
268+
return fmt.Errorf("failed to build update query: %w", err)
269+
}
270+
271+
_, err = n.db.Exec(query, args...)
272+
if err != nil {
273+
return fmt.Errorf("failed to execute update query for chain_id %d: %w", chainID, err)
274+
}
275+
276+
return nil
277+
}

rpc/network/db/network_db_test.go

+25
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,28 @@ func (s *NetworksPersistenceTestSuite) TestValidationForNetworksAndProviders() {
195195
s.Require().NoError(err)
196196
s.Require().Len(allNetworks, 0, "No invalid networks should be saved")
197197
}
198+
199+
func (s *NetworksPersistenceTestSuite) TestSetEnabled() {
200+
network := testutil.CreateNetwork(api.OptimismChainID, "Optimism Mainnet", DefaultProviders(api.OptimismChainID))
201+
s.addAndVerifyNetworks([]*params.Network{network})
202+
203+
// Disable the network
204+
err := s.networksPersistence.SetEnabled(network.ChainID, false)
205+
s.Require().NoError(err)
206+
207+
// Verify the network is disabled
208+
updatedNetwork, err := s.networksPersistence.GetNetworkByChainID(network.ChainID)
209+
s.Require().NoError(err)
210+
s.Require().Len(updatedNetwork, 1)
211+
s.Require().False(updatedNetwork[0].Enabled)
212+
213+
// Enable the network
214+
err = s.networksPersistence.SetEnabled(network.ChainID, true)
215+
s.Require().NoError(err)
216+
217+
// Verify the network is enabled
218+
updatedNetwork, err = s.networksPersistence.GetNetworkByChainID(network.ChainID)
219+
s.Require().NoError(err)
220+
s.Require().Len(updatedNetwork, 1)
221+
s.Require().True(updatedNetwork[0].Enabled)
222+
}

rpc/network/network.go

+10
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ type ManagerInterface interface {
3434
GetTestNetworksEnabled() (bool, error)
3535

3636
SetUserRpcProviders(chainID uint64, providers []params.RpcProvider) error
37+
SetEnabled(chainID uint64, enabled bool) error
3738
}
3839

3940
type Manager struct {
@@ -172,6 +173,15 @@ func (nm *Manager) SetUserRpcProviders(chainID uint64, userProviders []params.Rp
172173
return rpcPersistence.SetRpcProviders(chainID, networkhelper.GetUserProviders(userProviders))
173174
}
174175

176+
// SetEnabled updates the enabled status of a network
177+
func (nm *Manager) SetEnabled(chainID uint64, enabled bool) error {
178+
err := nm.networkPersistence.SetEnabled(chainID, enabled)
179+
if err != nil {
180+
return fmt.Errorf("failed to set enabled status: %w", err)
181+
}
182+
return nil
183+
}
184+
175185
// Find locates a network by ChainID.
176186
func (nm *Manager) Find(chainID uint64) *params.Network {
177187
networks, err := nm.networkPersistence.GetNetworkByChainID(chainID)

services/wallet/api.go

+10
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,16 @@ func (api *API) AddEthereumChain(ctx context.Context, network params.Network) er
400400
return api.s.rpcClient.NetworkManager.Upsert(&network)
401401
}
402402

403+
func (api *API) SetChainUserRpcProviders(ctx context.Context, chainID uint64, rpcProviders []params.RpcProvider) error {
404+
logutils.ZapLogger().Debug("call to SetChainUserRpcProviders")
405+
return api.s.rpcClient.NetworkManager.SetUserRpcProviders(chainID, rpcProviders)
406+
}
407+
408+
func (api *API) SetChainEnabled(ctx context.Context, chainID uint64, enabled bool) error {
409+
logutils.ZapLogger().Debug("call to SetChainEnabled")
410+
return api.s.rpcClient.NetworkManager.SetEnabled(chainID, enabled)
411+
}
412+
403413
func (api *API) DeleteEthereumChain(ctx context.Context, chainID uint64) error {
404414
logutils.ZapLogger().Debug("call to DeleteEthereumChain")
405415
return api.s.rpcClient.NetworkManager.Delete(chainID)

0 commit comments

Comments
 (0)