diff --git a/backtesting/_plotting.py b/backtesting/_plotting.py index 338454da..32cbf5cd 100644 --- a/backtesting/_plotting.py +++ b/backtesting/_plotting.py @@ -162,27 +162,35 @@ def try_mean_first(indicator): equity_data = equity_data.resample(freq, label='right').agg(_EQUITY_AGG).dropna(how='all') assert equity_data.index.equals(df.index) - def _weighted_returns(s, trades=trades): - df = trades.loc[s.index] - return ((df['Size'].abs() * df['ReturnPct']) / df['Size'].abs().sum()).sum() - - def _group_trades(column): - def f(s, new_index=pd.Index(df.index.astype(np.int64)), bars=trades[column]): - if s.size: - # Via int64 because on pandas recently broken datetime - mean_time = int(bars.loc[s.index].astype(np.int64).mean()) - new_bar_idx = new_index.get_indexer([mean_time], method='nearest')[0] - return new_bar_idx - return f - if len(trades): # Avoid pandas "resampling on Int64 index" error - trades = trades.assign(count=1).resample(freq, on='ExitTime', label='right').agg(dict( + abs_sizes = trades['Size'].abs() + trades = trades.assign( + count=1, + _abs_size=abs_sizes, + _weighted_ret=abs_sizes * trades['ReturnPct'], + ) + trades = trades.resample(freq, on='ExitTime', label='right').agg(dict( TRADES_AGG, - ReturnPct=_weighted_returns, count='sum', - EntryBar=_group_trades('EntryTime'), - ExitBar=_group_trades('ExitTime'), + _abs_size='sum', + _weighted_ret='sum', + EntryTime='mean', + ExitTime='mean', )).dropna() + trades['ReturnPct'] = ( + trades['_weighted_ret'] / trades['_abs_size'] + ) + trades.drop(columns=['_abs_size', '_weighted_ret'], inplace=True) + + new_index = pd.Index(df.index.astype(np.int64)) + trades['EntryBar'] = new_index.get_indexer( + trades['EntryTime'].astype(np.int64), + method='nearest', + ) + trades['ExitBar'] = new_index.get_indexer( + trades['ExitTime'].astype(np.int64), + method='nearest', + ) return df, indicators, equity_data, trades diff --git a/backtesting/test/_test.py b/backtesting/test/_test.py index 63045ce1..4883f0ae 100644 --- a/backtesting/test/_test.py +++ b/backtesting/test/_test.py @@ -787,6 +787,70 @@ def init(self): # Give browser time to open before tempfile is removed time.sleep(1) + def test_resample_trades_vectorized(self): + """Vectorized trade resampling produces correct weighted returns and bar indices.""" + import backtesting._plotting as _plotting + from backtesting.lib import OHLCV_AGG, TRADES_AGG, _EQUITY_AGG + + bt = Backtest(GOOG, SmaCross) + results = bt.run() + trades = results['_trades'] + if trades.empty: + return + + df_ohlcv = bt._data.copy() + equity_data = results['_equity_curve'].copy(deep=False) + + freq = '1ME' + df_resampled = df_ohlcv.resample( + freq, label='right').agg(OHLCV_AGG).dropna() + + # --- Reference (original callback) implementation --- + def _weighted_returns(s, _trades=trades): + d = _trades.loc[s.index] + return ( + (d['Size'].abs() * d['ReturnPct']) + / d['Size'].abs().sum() + ).sum() + + def _group_trades(column): + def f(s, + new_index=pd.Index( + df_resampled.index.astype(np.int64)), + bars=trades[column]): + if s.size: + mean_time = int( + bars.loc[s.index].astype(np.int64).mean() + ) + return new_index.get_indexer( + [mean_time], method='nearest')[0] + return f + + ref = trades.assign(count=1).resample( + freq, on='ExitTime', label='right', + ).agg(dict( + TRADES_AGG, + ReturnPct=_weighted_returns, + count='sum', + EntryBar=_group_trades('EntryTime'), + ExitBar=_group_trades('ExitTime'), + )).dropna() + + # --- Vectorized implementation --- + eq = equity_data.resample( + freq, label='right').agg(_EQUITY_AGG).dropna(how='all') + _, _, _, vec_trades = _plotting._maybe_resample_data( + freq, df_ohlcv.copy(), [], eq, trades.copy(), + ) + + cols = ['Size', 'EntryBar', 'ExitBar', 'EntryPrice', + 'ExitPrice', 'PnL', 'ReturnPct', 'count'] + assert_frame_equal( + vec_trades[cols].reset_index(drop=True), + ref[cols].reset_index(drop=True), + check_exact=False, check_dtype=False, atol=1e-10, + ) + def test_indicator_name(self): test_self = self