Commit

Merge pull request #22 from boettiger-lab/greencrabmonthNorm
add monthly normalized env
jiangjingzhi2003 authored Dec 31, 2024
2 parents 47a60cb + 7bd1485 commit 24fe46d
Showing 2 changed files with 57 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/rl4greencrab/__init__.py
@@ -1,6 +1,7 @@
from rl4greencrab.envs.green_crab_ipm import greenCrabEnv, greenCrabSimplifiedEnv
from rl4greencrab.envs.time_series import timeSeriesEnv
from rl4greencrab.envs.green_crab_monthly_env import greenCrabMonthEnv
from rl4greencrab.envs.green_crab_monthly_env_norm import greenCrabMonthEnvNormalized
from rl4greencrab.agents.const_action import constAction, constActionNatUnits, multiConstAction
from rl4greencrab.agents.const_escapement import constEsc
from rl4greencrab.utils.simulate import simulator, get_simulator, evaluate_agent
@@ -21,4 +22,8 @@
register(
id="monthenv",
entry_point="rl4greencrab.envs.green_crab_monthly_env:greenCrabMonthEnv",
)
register(
id="monthenvnorm",
entry_point="rl4greencrab.envs.green_crab_monthly_env_norm:greenCrabMonthEnvNormalized",
)
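
For context, a minimal sketch of how the newly registered environment could be created through the gymnasium registry (assuming rl4greencrab is installed; importing the package runs the register() calls above, and "monthenvnorm" is the id registered in this diff):

import gymnasium as gym
import rl4greencrab  # importing the package runs the register() calls in __init__.py

env = gym.make("monthenvnorm")
obs, info = env.reset(seed=42)
obs, rew, term, trunc, info = env.step(env.action_space.sample())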
52 changes: 52 additions & 0 deletions src/rl4greencrab/envs/green_crab_monthly_env_norm.py
@@ -0,0 +1,52 @@
import gymnasium as gym
import logging
import numpy as np

from gymnasium import spaces
from scipy.stats import norm
from rl4greencrab.envs.green_crab_monthly_env import greenCrabMonthEnv

logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)

class greenCrabMonthEnvNormalized(greenCrabMonthEnv):
    def __init__(self, config={}):
        super().__init__(config=config)
        self.observation_space = spaces.Box(
            np.array([-1], dtype=np.float32),
            np.array([1], dtype=np.float32),
            dtype=np.float32,
        )
        self.action_space = spaces.Box(
            np.float32([-1, -1, -1]),
            np.float32([1, 1, 1]),
            dtype=np.float32,
        )
        self.max_action = config.get('max_action', 2000)  # ad hoc, based on previous values
        self.cpue_normalization = config.get('cpue_normalization', 100)

    def step(self, action):
        # convert the normalized action in [-1, 1] to natural units (trap counts)
        action_natural_units = np.maximum(self.max_action * (1 + action) / 2, 0.)
        obs, rew, term, trunc, info = super().step(
            np.float32(action_natural_units)
        )
        # rescale CPUE from [0, 1] to [-1, 1]
        normalized_cpue = 2 * self.cpue_2(obs, action_natural_units) - 1
        # observation = np.float32(np.append(normalized_cpue, action))
        observation = normalized_cpue
        rew = 10 * rew  # use larger rewards; possibly makes training easier?
        return observation, rew, term, trunc, info

    def reset(self, *, seed=42, options=None):
        _, info = super().reset(seed=seed, options=options)

        # completely new obs: every episode starts at the lowest normalized CPUE
        return -np.ones(shape=self.observation_space.shape, dtype=np.float32), info

    def cpue_2(self, obs, action_natural_units):
        # If no traps are set, catch-per-unit-effort is 0/0. Formally NaN, but we call it 0.
        if np.sum(action_natural_units) <= 0:
            return np.float32([0])
        # We can't tell which trap caught each crab here. Perhaps too simple, but maybe realistic.
        cpue_2 = np.float32([
            np.sum(obs[0]) / (self.cpue_normalization * np.sum(action_natural_units)),
        ])
        return cpue_2

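As a sanity check on the normalization, a small standalone sketch of the action and CPUE rescaling used in step above (max_action = 2000 and cpue_normalization = 100 are the config defaults; the numbers are illustrative only):

import numpy as np

max_action = 2000  # default from the config above
for a in (-1.0, 0.0, 1.0):
    action = np.float32([a, a, a])
    natural = np.maximum(max_action * (1 + action) / 2, 0.)
    print(a, natural)  # -1 -> [0 0 0], 0 -> [1000 1000 1000], 1 -> [2000 2000 2000]

# CPUE in [0, 1] maps to an observation in [-1, 1]:
for cpue in (0.0, 0.5, 1.0):
    print(cpue, 2 * cpue - 1)  # 0 -> -1, 0.5 -> 0, 1 -> 1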