Building Custom Environments¶
This guide shows how to extend TorchTrade for custom trading environments.
When to Build a Custom Environment¶
Build a custom environment when you need:
- Custom asset types: Options, forex, commodities
- Complex order types: Market-on-close, iceberg orders
- Custom state: Order book data, sentiment, news
- Specific trading rules: Pattern day trading, portfolio constraints
- New exchange integrations: Unsupported brokers/APIs
Environment Architecture¶
TorchTrade environments inherit from TorchRL's EnvBase:
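A minimal skeleton of that relationship (the class name is a placeholder; the required methods are detailed in the table below):
from torchrl.envs import EnvBase

class MyTradingEnv(EnvBase):
    """Custom environments subclass EnvBase directly or extend an existing TorchTrade env."""

    def _reset(self, tensordict=None, **kwargs): ...  # initial observation
    def _step(self, tensordict): ...                  # apply action, return next obs / reward / done
    def _set_seed(self, seed): ...                    # seed RNGs for reproducibility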
Required Methods¶
| Method | Purpose | Returns |
|---|---|---|
| _reset() | Initialize episode state | TensorDict with initial observation |
| _step(tensordict) | Execute action, update state | TensorDict with next observation |
| _set_seed(seed) | Set random seed for reproducibility | None |
| observation_spec | Define observation space | Composite |
| action_spec | Define action space | Categorical or Bounded |
| reward_spec | Define reward space | Unbounded |
Example 1: Simple Custom Environment¶
Minimal environment from scratch:
from torchrl.envs import EnvBase
from torchrl.data import Categorical, Composite, Unbounded
from tensordict import TensorDict
import torch


class SimpleCustomEnv(EnvBase):
    """
    Minimal custom trading environment.

    State: [price, position]
    Actions: 0 (HOLD), 1 (BUY), 2 (SELL)
    Reward: Simple return on closed trades
    """

    def __init__(self, prices: torch.Tensor, **kwargs):
        super().__init__(**kwargs)
        self.prices = prices
        self.current_step = 0
        self.position = 0  # 0 or 1
        self.entry_price = 0.0

        # Define specs
        self.observation_spec = Composite({
            "price": Unbounded(shape=(1,)),
            "position": Unbounded(shape=(1,)),
        })
        self.action_spec = Categorical(n=3)
        self.reward_spec = Unbounded(shape=(1,))

    def _reset(self, tensordict=None, **kwargs):
        """Reset to initial state"""
        self.current_step = 0
        self.position = 0
        self.entry_price = 0.0
        return TensorDict({
            "price": torch.tensor([self.prices[0].item()]),
            "position": torch.tensor([0.0]),
        }, batch_size=self.batch_size)

    def _step(self, tensordict: TensorDict) -> TensorDict:
        """Execute one step"""
        action = tensordict["action"].item()
        current_price = self.prices[self.current_step].item()

        # Execute action
        reward = 0.0
        if action == 1 and self.position == 0:  # BUY
            self.position = 1
            self.entry_price = current_price
        elif action == 2 and self.position == 1:  # SELL
            reward = (current_price - self.entry_price) / self.entry_price
            self.position = 0
            self.entry_price = 0.0

        # Move to next step
        self.current_step += 1
        done = self.current_step >= len(self.prices) - 1

        # Build output tensordict
        next_price = self.prices[self.current_step].item() if not done else current_price
        return TensorDict({
            "price": torch.tensor([next_price]),
            "position": torch.tensor([float(self.position)]),
            "reward": torch.tensor([reward]),
            "done": torch.tensor([done]),
        }, batch_size=self.batch_size)

    def _set_seed(self, seed: int):
        """Set random seed"""
        torch.manual_seed(seed)


# Usage
prices = torch.randn(1000).cumsum(0) + 100  # Random walk prices
env = SimpleCustomEnv(prices, batch_size=[])

td = env.reset()
for _ in range(100):
    td["action"] = env.action_spec.rand()  # Random action
    td = env.step(td)                      # step() takes and returns a TensorDict
    if td["next", "done"].item():
        break
    td = td["next"]                        # Advance to the next observation
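TorchRL can also drive the environment for you: rollout collects a trajectory (using random actions when no policy is passed) and stops at the first done flag:
trajectory = env.rollout(max_steps=200)
print(trajectory)                          # Stacked TensorDict of the episode
print(trajectory["next", "reward"].sum())  # Total reward over the trajectory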
Example 2: Extending Existing Environments¶
Extend SequentialTradingEnv to add custom features:
from torchtrade.envs.offline import SequentialTradingEnv, SequentialTradingEnvConfig
from torchrl.data import Unbounded
from tensordict import TensorDict
import torch


class CustomLongOnlyEnv(SequentialTradingEnv):
    """
    Extended SequentialTradingEnv with sentiment data.
    """

    def __init__(self, df, config: SequentialTradingEnvConfig, sentiment_data: torch.Tensor):
        super().__init__(df, config)
        self.sentiment_data = sentiment_data  # Timeseries sentiment scores

        # Extend observation spec with the new field
        self.observation_spec["sentiment"] = Unbounded(shape=(1,))

    def _reset(self, tensordict=None, **kwargs):
        """Add sentiment to observations"""
        obs = super()._reset(tensordict, **kwargs)

        # Add current sentiment
        sentiment_idx = self.sampler.reset_index
        obs["sentiment"] = torch.tensor([self.sentiment_data[sentiment_idx].item()])
        return obs

    def _step(self, tensordict: TensorDict) -> TensorDict:
        """Add sentiment to step observations"""
        obs = super()._step(tensordict)

        # Add current sentiment
        sentiment_idx = self.sampler.current_index
        obs["sentiment"] = torch.tensor([self.sentiment_data[sentiment_idx].item()])
        return obs


# Usage
import pandas as pd

df = pd.read_csv("prices.csv")
sentiment = torch.randn(len(df))  # Random sentiment scores

config = SequentialTradingEnvConfig(
    time_frames=["1min", "5min"],
    window_sizes=[12, 8],
    execute_on=(5, "Minute"),
)

env = CustomLongOnlyEnv(df, config, sentiment)

# Policy network sees sentiment in observations
obs = env.reset()
print(obs.keys())  # [..., 'sentiment']
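The new key can then be wired into a policy like any other observation. The sketch below uses tensordict.nn.TensorDictModule; the encoder architecture is illustrative:
import torch.nn as nn
from tensordict.nn import TensorDictModule

# Map the "sentiment" observation to an embedding the policy can consume
sentiment_encoder = TensorDictModule(
    nn.Sequential(nn.Linear(1, 16), nn.ReLU()),
    in_keys=["sentiment"],
    out_keys=["sentiment_embedding"],
)

td = env.reset()
td = sentiment_encoder(td)
print(td["sentiment_embedding"].shape)  # torch.Size([16])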
Design Patterns¶
1. Composition Over Inheritance¶
Prefer composing existing components:
class CustomEnv(SequentialTradingEnv):
    def __init__(self, df, config, custom_component):
        super().__init__(df, config)
        self.custom_component = custom_component  # Inject custom logic

    def _step(self, tensordict):
        obs = super()._step(tensordict)
        # Modify obs with custom_component
        obs["custom_feature"] = self.custom_component.compute(obs)
        return obs
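For instance, the injected component only needs a compute method. The RollingVolatility class below is illustrative and assumes the observation exposes a scalar price entry; adapt the key to your environment:
import torch

class RollingVolatility:
    """Illustrative component: rolling standard deviation of recent prices."""

    def __init__(self, window: int = 20):
        self.window = window
        self.prices = []

    def compute(self, obs) -> torch.Tensor:
        self.prices.append(obs["price"].item())  # assumes a scalar "price" key
        recent = torch.tensor(self.prices[-self.window:])
        vol = recent.std() if len(recent) > 1 else torch.tensor(0.0)
        return vol.reshape(1)

env = CustomEnv(df, config, custom_component=RollingVolatility(window=20))
Remember to register the new custom_feature key in the observation spec, as shown in the next section.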
2. Observation Spec Extension¶
Always update observation specs when adding new fields:
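For example, the composition pattern above needs a matching spec entry so that observation_spec stays in sync with what _step emits (a minimal sketch):
from torchrl.data import Unbounded

class CustomEnv(SequentialTradingEnv):
    def __init__(self, df, config, custom_component):
        super().__init__(df, config)
        self.custom_component = custom_component
        # Declare the new field so spec checks and replay buffers see it
        self.observation_spec["custom_feature"] = Unbounded(shape=(1,))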
3. State Management¶
TorchTrade provides structured state management classes for consistent state handling across all environments. See State Management for full details.
Quick Reference - PositionState:
The PositionState dataclass encapsulates position-related state in a structured way:
from torchtrade.envs.core.state import PositionState


class CustomEnv(EnvBase):
    def __init__(self):
        super().__init__()
        # Use PositionState to group position-related variables
        self.position = PositionState()
        # Provides: position.current_position, position.position_size,
        #           position.position_value, position.entry_price,
        #           position.unrealized_pnlpc, position.hold_counter

        # Other environment state
        self.current_step = 0
        self.cash = 1000.0
        self.current_timestamp = None

    def _reset(self, tensordict=None, **kwargs):
        """Reset all state including position"""
        self.position.reset()  # Resets all position fields to defaults
        self.current_step = 0
        self.cash = 1000.0
        # ... reset other state

    def _step(self, tensordict):
        """Use position state in step logic"""
        # `action` and `current_price` come from the env's own action decoding
        # and market-data lookup (omitted here)
        if action == "BUY" and self.position.position_size == 0:
            self.position.position_size = 100
            self.position.entry_price = current_price
            self.position.current_position = 1.0
        # ... rest of step logic
Benefits of PositionState:
- Groups related state variables together (better organization)
- Provides a single .reset() method for all position fields
- Makes position state explicit and easier to track
- Used consistently across all TorchTrade environments
Source: torchtrade/envs/core/state.py
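For example, a mark-to-market helper called from _step could refresh the derived fields on each bar. This is an illustrative sketch: the _update_position_state name and the percentage-PnL convention are assumptions, not TorchTrade's actual implementation:
def _update_position_state(self, current_price: float) -> None:
    """Illustrative: refresh derived fields for an open long position."""
    if self.position.position_size > 0:
        self.position.position_value = self.position.position_size * current_price
        # Assumed convention: unrealized PnL as a fraction of the entry price
        self.position.unrealized_pnlpc = (
            current_price - self.position.entry_price
        ) / self.position.entry_price
        self.position.hold_counter += 1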
Quick Reference - HistoryTracker:
Use HistoryTracker to record episode data for analysis and visualization:
from torchtrade.envs.core.state import HistoryTracker


class CustomEnv(EnvBase):
    def __init__(self):
        super().__init__()
        self.history = HistoryTracker()

    def _step(self, tensordict):
        # ... execute step logic ...
        self.history.record_step(
            price=current_price,
            action=action.item(),
            reward=reward,
            portfolio_value=self.cash + self.position.position_value,
            position=self.position.position_size,
            action_type="long",  # or "short", "hold", "close", etc.
        )

    def _reset(self, tensordict=None, **kwargs):
        self.history.reset()  # Clear history at episode start
        # ... rest of reset logic
Access history via history.to_dict() — returns prices, actions, rewards, portfolio values, positions, and action types.
Source: torchtrade/envs/core/state.py
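After an episode, the recorded history can feed quick analysis or plotting. The sketch below assumes key names (portfolio_values, rewards) based on the description above; adjust them to the actual dictionary returned by to_dict():
history = env.history.to_dict()

# Assumed keys: "portfolio_values" and "rewards"
portfolio_values = history["portfolio_values"]
total_return = (portfolio_values[-1] - portfolio_values[0]) / portfolio_values[0]
print(f"Total return: {total_return:.2%}")
print(f"Steps recorded: {len(history['rewards'])}")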
Testing Custom Environments¶
1. Spec Compliance¶
Verify specs match actual outputs:
env = CustomEnv(...)

# Check reset
td = env.reset()
assert env.observation_spec.is_in(td), "Reset observation doesn't match spec"

# Check step
td["action"] = env.action_spec.rand()
td = env.step(td)
assert env.observation_spec.is_in(td["next"]), "Step observation doesn't match spec"
assert env.reward_spec.is_in(td["next", "reward"]), "Reward doesn't match spec"
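TorchRL also ships check_env_specs, which performs a short rollout and raises if the declared specs and actual outputs disagree; it is a convenient first test for any custom environment:
from torchrl.envs.utils import check_env_specs

check_env_specs(env)  # Raises if observations, rewards, or done flags violate the specs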
2. Episode Completion¶
Ensure episodes terminate correctly:
env = CustomEnv(...)

td = env.reset()
for i in range(10000):  # Safety limit
    td["action"] = env.action_spec.rand()
    td = env.step(td)
    if td["next", "done"].item():
        print(f"Episode ended at step {i}")
        break
    td = td["next"]
else:
    raise AssertionError("Episode never ended!")
3. Reward Sanity¶
Check reward values are reasonable:
rewards = []
for episode in range(100):
    td = env.reset()
    episode_reward = 0.0
    done = False
    while not done:
        td["action"] = env.action_spec.rand()
        td = env.step(td)
        episode_reward += td["next", "reward"].item()
        done = td["next", "done"].item()
        td = td["next"]
    rewards.append(episode_reward)

print(f"Mean reward: {sum(rewards)/len(rewards):.2f}")
print(f"Reward range: [{min(rewards):.2f}, {max(rewards):.2f}]")
Common Pitfalls¶
| Issue | Problem | Solution |
|---|---|---|
| Spec mismatch | Observation shape != spec shape | Update observation_spec in __init__ |
| Forgotten batch_size | TensorDict missing batch_size | Always pass batch_size=self.batch_size |
| Missing done signal | Episode never ends | Set done=True in terminal state |
| Mutable state | State persists across episodes | Reset ALL state variables in _reset() |
| Incorrect device | Tensors on wrong device | Use self.device for all tensors |
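For the last two rows, the fix usually amounts to threading self.batch_size and self.device through every tensor and TensorDict you build (illustrative values below):
def _reset(self, tensordict=None, **kwargs):
    # Inside your EnvBase subclass: keep device and batch_size consistent
    return TensorDict(
        {
            "price": torch.tensor([100.0], device=self.device),
            "position": torch.zeros(1, device=self.device),
        },
        batch_size=self.batch_size,
        device=self.device,
    )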
Next Steps¶
- Offline Environments - Existing environment architecture
- TorchRL EnvBase - Base class documentation