diff --git a/TIMESNET_REPRODUCTION_GUIDE.md b/TIMESNET_REPRODUCTION_GUIDE.md new file mode 100644 index 000000000..ae27da2dc --- /dev/null +++ b/TIMESNET_REPRODUCTION_GUIDE.md @@ -0,0 +1,416 @@ +# TimesNet 论文复现指南 + +这份文档对应仓库里的 `article/TimesNet.pdf`,目标是把仓库现有的 TimesNet 实验脚本真正跑通,并把结果按论文口径汇总出来,方便你判断“有没有复现到论文水平”。 + +## 1. 先明确复现目标 + +`TimesNet.pdf` 不是只做了一个任务,而是覆盖了 5 类任务: + +- 长期预测 `long_term_forecast` +- 短期预测 `short_term_forecast` +- 缺失值填补 `imputation` +- 分类 `classification` +- 异常检测 `anomaly_detection` + +所以“复现 TimesNet 论文”通常不是跑一个命令,而是按这 5 类任务分别跑对应脚本,再把结果和论文表格对齐。 + +## 2. 你现在应该按什么顺序做 + +推荐顺序如下: + +1. 先确认环境和数据都正常。 +2. 先做一个很小的冒烟测试,确保代码、数据、GPU、依赖都通。 +3. 先跑长期预测,因为这是最标准、最容易判断是否跑通的主线任务。 +4. 再跑短期预测、缺失值填补、分类、异常检测。 +5. 最后统一汇总结果,和论文表格做对照。 + +这样做的原因很简单: + +- 如果你一上来直接跑全量脚本,一旦环境、GPU、数据路径有问题,会浪费很长时间。 +- 长期预测的输出最直观,能最快验证训练和测试流程是否正常。 +- 论文主表里很多结果是“平均值”,不是单次运行日志里的某一个值,所以最后必须做一次汇总。 + +## 3. 这次我给你补了什么 + +为了避免你直接跑原始脚本时踩坑,我新增了两个工具: + +- `scripts/reproduce_timesnet.sh` + 作用:统一调用仓库原始 TimesNet 脚本,并自动绕开原脚本里写死的 `CUDA_VISIBLE_DEVICES=2/4/5/7`。 +- `scripts/summarize_timesnet_results.py` + 作用:把你跑出来的结果自动聚合成一个对照报告 `TIMESNET_REPRODUCTION_RESULTS.md`。 + +为什么要这样做: + +- 这台机器实际只有 `GPU 0` 可见。 +- 原始脚本里很多 `CUDA_VISIBLE_DEVICES` 写成了 `2`、`4`、`5`、`7`,直接跑很容易把 GPU 配错。 +- 分类和异常检测日志里输出的是 `0~1` 的比例值,但论文表格写的是百分数,这一步手工对很容易看错。 + +## 4. 环境和数据先检查什么 + +你已经指定使用 Anaconda 的 `timesnet` 环境,所以默认直接用它。 + +先检查: + +```bash +conda run -n timesnet python --version +conda run -n timesnet python -c "import torch; print(torch.__version__, torch.cuda.is_available())" +nvidia-smi +``` + +再检查数据目录: + +```bash +ls dataset/ETT-small +ls dataset/m4 +ls dataset/electricity +ls dataset/weather +ls dataset/traffic +ls dataset/exchange_rate +ls dataset/illness +ls dataset/PSM +ls dataset/SMD +ls dataset/SMAP +ls dataset/MSL +ls dataset/SWaT +ls dataset/EthanolConcentration +``` + +你这个仓库里这些公开数据已经在 `dataset/` 下,基础条件是满足的。 + +## 5. 第一步不要直接跑全量,先做冒烟测试 + +先用一个 1 epoch 的小命令确认训练入口正常: + +```bash +conda run --no-capture-output -n timesnet python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/ETT-small/ \ + --data_path ETTh1.csv \ + --model_id smoke_ETTh1_96_96 \ + --model TimesNet \ + --data ETTh1 \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len 96 \ + --e_layers 2 \ + --d_layers 1 \ + --factor 3 \ + --enc_in 7 \ + --dec_in 7 \ + --c_out 7 \ + --d_model 16 \ + --d_ff 32 \ + --top_k 5 \ + --train_epochs 1 \ + --patience 1 \ + --num_workers 0 \ + --gpu 0 \ + --des smoke \ + --itr 1 +``` + +如果这个命令能正常训练并打印出类似 `mse:... mae:...`,说明下面的全量复现可以开始了。 + +## 6. 正式复现时用哪个脚本 + +统一入口: + +```bash +bash scripts/reproduce_timesnet.sh [target] +``` + +支持的任务: + +- `long_term` +- `short_term` +- `imputation` +- `classification` +- `anomaly` +- `all` + +### 6.1 长期预测 + +先跑长期预测最合适。 + +跑全部长期预测数据集: + +```bash +bash scripts/reproduce_timesnet.sh long_term all +``` + +如果你只想先跑一个数据集: + +```bash +bash scripts/reproduce_timesnet.sh long_term etth1 +bash scripts/reproduce_timesnet.sh long_term ettm1 +bash scripts/reproduce_timesnet.sh long_term traffic +``` + +这一步对应论文主文 Table 2,仓库脚本对应论文 Appendix Table 13 的完整结果。 + +### 6.2 短期预测 + +M4 是单独一套流程: + +```bash +bash scripts/reproduce_timesnet.sh short_term +``` + +它会依次跑: + +- Yearly +- Quarterly +- Monthly +- Weekly +- Daily +- Hourly + +全部 6 个频率跑完以后,才能按论文口径计算加权平均 `SMAPE / MASE / OWA`。 + +### 6.3 缺失值填补 + +跑全部: + +```bash +bash scripts/reproduce_timesnet.sh imputation all +``` + +只跑一个: + +```bash +bash scripts/reproduce_timesnet.sh imputation etth1 +bash scripts/reproduce_timesnet.sh imputation weather +``` + +这一步对应论文主文 Table 4,仓库脚本对应论文 Appendix Table 16 的完整结果。 + +### 6.4 分类 + +分类脚本会一次跑完 10 个 UEA 子数据集: + +```bash +bash scripts/reproduce_timesnet.sh classification +``` + +这一步对应论文 Figure 5 和 Appendix Table 17。 + +### 6.5 异常检测 + +跑全部: + +```bash +bash scripts/reproduce_timesnet.sh anomaly all +``` + +只跑一个: + +```bash +bash scripts/reproduce_timesnet.sh anomaly psm +bash scripts/reproduce_timesnet.sh anomaly smd +``` + +注意: + +- `SWaT` 脚本不是只跑一个配置,而是连续试多个配置。 +- 最后和论文比的时候,应该按数据集取最优 F1,而不是只看最后一条日志。 + +## 7. 每一步跑完以后结果会落到哪里 + +主要看这几个目录和文件: + +- `checkpoints/` + 训练好的模型权重。 +- `results/` + 主要任务的 `metrics.npy`、分类结果文件等。 +- `test_results/` + 一些可视化图和测试输出。 +- `m4_results/TimesNet/` + M4 预测结果 CSV。 +- `result_long_term_forecast.txt` + 长期预测日志摘要。 +- `result_imputation.txt` + 缺失值填补日志摘要。 +- `result_anomaly_detection.txt` + 异常检测日志摘要。 + +## 8. 跑完以后怎么和论文对表 + +不要人工一条条翻日志,直接运行汇总脚本: + +```bash +conda run --no-capture-output -n timesnet python scripts/summarize_timesnet_results.py +``` + +运行后会生成: + +```bash +TIMESNET_REPRODUCTION_RESULTS.md +``` + +这个文件会自动做几件事: + +- 把长期预测按 4 个预测长度做平均,再和论文 Table 2 对照。 +- 把缺失值填补按 4 个 mask ratio 做平均,再和论文 Table 4 对照。 +- 从 `m4_results/TimesNet/` 重新计算 `SMAPE / MASE / OWA`,再和论文 Table 3 对照。 +- 把分类任务从 `0~1` 精度转换成百分数,再和论文 Table 17 对照。 +- 把异常检测从 `0~1` F1 转成百分数,并且同一数据集取最优 F1,再和论文 Table 5 对照。 + +## 9. 论文里应该对齐到哪些数 + +下面这些是你最终最应该盯住的主表目标值。 + +来源: + +- 论文 arXiv: https://arxiv.org/abs/2210.02186 +- 可读 HTML 版: https://ar5iv.labs.arxiv.org/html/2210.02186 + +### 9.1 长期预测 Table 2 + +这些数是 4 个预测长度的平均值: + +| Dataset | Paper MSE | Paper MAE | +| --- | --- | --- | +| ETTm1 | 0.400 | 0.406 | +| ETTm2 | 0.291 | 0.333 | +| ETTh1 | 0.458 | 0.450 | +| ETTh2 | 0.414 | 0.427 | +| Electricity | 0.192 | 0.295 | +| Traffic | 0.620 | 0.336 | +| Weather | 0.259 | 0.287 | +| Exchange | 0.416 | 0.443 | +| ILI | 2.139 | 0.931 | + +### 9.2 短期预测 Table 3 + +M4 加权平均目标值: + +| Metric | Paper | +| --- | --- | +| SMAPE | 11.829 | +| MASE | 1.585 | +| OWA | 0.851 | + +### 9.3 缺失值填补 Table 4 + +这些数是 4 个 mask ratio 的平均值: + +| Dataset | Paper MSE | Paper MAE | +| --- | --- | --- | +| ETTm1 | 0.027 | 0.107 | +| ETTm2 | 0.022 | 0.088 | +| ETTh1 | 0.078 | 0.187 | +| ETTh2 | 0.049 | 0.146 | +| Electricity | 0.092 | 0.210 | +| Weather | 0.030 | 0.054 | + +### 9.4 分类 + +论文主文给出的平均准确率目标值: + +| Metric | Paper | +| --- | --- | +| Average Accuracy (%) | 73.6 | + +更细的 10 个子数据集精度,见论文 Appendix Table 17。汇总脚本已经把这些目标值内置进去了。 + +### 9.5 异常检测 Table 5 + +这里应该对照的是仓库当前 Inception 版 TimesNet 的 F1: + +| Dataset | Paper F1 (%) | +| --- | --- | +| SMD | 85.12 | +| MSL | 84.18 | +| SMAP | 70.85 | +| SWAT | 92.10 | +| PSM | 95.21 | +| Average | 85.49 | + +## 10. 你最容易踩的坑 + +### 10.1 直接跑原始脚本 + +原始脚本很多写死了: + +```bash +export CUDA_VISIBLE_DEVICES=2 +export CUDA_VISIBLE_DEVICES=4 +export CUDA_VISIBLE_DEVICES=5 +export CUDA_VISIBLE_DEVICES=7 +``` + +如果你的机器没有这些卡号,脚本就会跑错设备,甚至直接退回 CPU 或失败。 + +所以建议直接使用: + +```bash +bash scripts/reproduce_timesnet.sh ... +``` + +### 10.2 只看一条日志,不做平均 + +论文主表很多不是某一个单次实验值,而是: + +- 长期预测:4 个预测长度平均 +- 缺失值填补:4 个 mask ratio 平均 +- 短期预测:M4 6 个频率加权汇总 +- 分类:10 个 UEA 子集平均 +- 异常检测:5 个数据集平均 + +所以一定要跑汇总脚本,而不是只截某一行日志。 + +### 10.3 分类和异常检测的单位看错 + +仓库日志里通常打印: + +- `accuracy:0.357` +- `F-score : 0.952` + +论文里对应的是: + +- `35.7%` +- `95.2%` + +汇总脚本已经做了自动换算。 + +### 10.4 SWaT 不是单配置 + +`scripts/anomaly_detection/SWAT/TimesNet.sh` 会连续尝试多组超参。 + +所以: + +- 你不能只看最后一条。 +- 应该按同一数据集取最优 F1 再和论文比。 + +## 11. 我建议你的实际执行顺序 + +如果你是第一次复现,最稳妥的执行顺序就是这 6 条: + +```bash +conda run --no-capture-output -n timesnet python -u run.py --task_name long_term_forecast --is_training 1 --root_path ./dataset/ETT-small/ --data_path ETTh1.csv --model_id smoke_ETTh1_96_96 --model TimesNet --data ETTh1 --features M --seq_len 96 --label_len 48 --pred_len 96 --e_layers 2 --d_layers 1 --factor 3 --enc_in 7 --dec_in 7 --c_out 7 --d_model 16 --d_ff 32 --top_k 5 --train_epochs 1 --patience 1 --num_workers 0 --gpu 0 --des smoke --itr 1 + +bash scripts/reproduce_timesnet.sh long_term all +bash scripts/reproduce_timesnet.sh short_term +bash scripts/reproduce_timesnet.sh imputation all +bash scripts/reproduce_timesnet.sh classification +bash scripts/reproduce_timesnet.sh anomaly all + +conda run --no-capture-output -n timesnet python scripts/summarize_timesnet_results.py +``` + +## 12. 最后怎么判断算“复现成功” + +通常按下面的标准判断: + +- 长期预测和缺失值填补的平均 `MSE / MAE` 与论文接近。 +- M4 的 `SMAPE / MASE / OWA` 接近论文。 +- 分类平均准确率接近 `73.6%`。 +- 异常检测平均 F1 接近 `85.49%`。 +- 单个数据集有轻微浮动是正常的,尤其当 CUDA、PyTorch、小数精度、DataLoader worker 数、随机性略有差异时。 + +如果你愿意更严格一点,就看 `TIMESNET_REPRODUCTION_RESULTS.md` 里的 `Delta` 列: + +- 越接近 `0` 越好。 +- 如果某一整个任务普遍偏差较大,优先检查环境、GPU、脚本是否完整跑完、是否误把比例值当成百分数、以及是否把平均值算错了。 diff --git a/TIMESNET_REPRODUCTION_RESULTS.md b/TIMESNET_REPRODUCTION_RESULTS.md new file mode 100644 index 000000000..9f4b1a449 --- /dev/null +++ b/TIMESNET_REPRODUCTION_RESULTS.md @@ -0,0 +1,76 @@ +# TimesNet Reproduction Summary + +This file is generated by `scripts/summarize_timesnet_results.py`. + +## Long-Term Forecasting + +| Dataset | Ours MSE | Paper MSE | Delta | Ours MAE | Paper MAE | Delta | Runs | +| --- | --- | --- | --- | --- | --- | --- | --- | +| ETTm1 | - | 0.400 | - | - | 0.406 | - | - | +| ETTm2 | - | 0.291 | - | - | 0.333 | - | - | +| ETTh1 | - | 0.458 | - | - | 0.450 | - | - | +| ETTh2 | - | 0.414 | - | - | 0.427 | - | - | +| Electricity | - | 0.192 | - | - | 0.295 | - | - | +| Traffic | - | 0.620 | - | - | 0.336 | - | - | +| Weather | - | 0.259 | - | - | 0.287 | - | - | +| Exchange | - | 0.416 | - | - | 0.443 | - | - | +| ILI | - | 2.139 | - | - | 0.931 | - | - | + +## Short-Term Forecasting + +| Metric | Ours | Paper | Delta | +| --- | --- | --- | --- | +| SMAPE | - | 11.829 | - | +| MASE | - | 1.585 | - | +| OWA | - | 0.851 | - | + +## Imputation + +| Dataset | Ours MSE | Paper MSE | Delta | Ours MAE | Paper MAE | Delta | Runs | +| --- | --- | --- | --- | --- | --- | --- | --- | +| ETTm1 | - | 0.027 | - | - | 0.107 | - | - | +| ETTm2 | - | 0.022 | - | - | 0.088 | - | - | +| ETTh1 | - | 0.078 | - | - | 0.187 | - | - | +| ETTh2 | - | 0.049 | - | - | 0.146 | - | - | +| Electricity | - | 0.092 | - | - | 0.210 | - | - | +| Weather | - | 0.030 | - | - | 0.054 | - | - | + +## Classification + +| Dataset | Ours Acc(%) | Paper Acc(%) | Delta | Runs | +| --- | --- | --- | --- | --- | +| EthanolConcentration | - | 35.7 | - | - | +| FaceDetection | - | 68.6 | - | - | +| Handwriting | - | 32.1 | - | - | +| Heartbeat | - | 78.0 | - | - | +| JapaneseVowels | - | 98.4 | - | - | +| PEMS-SF | - | 89.6 | - | - | +| SelfRegulationSCP1 | - | 91.8 | - | - | +| SelfRegulationSCP2 | - | 57.2 | - | - | +| SpokenArabicDigits | - | 99.0 | - | - | +| UWaveGestureLibrary | - | 85.3 | - | - | + +| Dataset | Ours Acc(%) | Paper Acc(%) | Delta | Covered | +| --- | --- | --- | --- | --- | +| Average | - | 73.6 | - | 0 | + +## Anomaly Detection + +| Dataset | Ours F1(%) | Paper F1(%) | Delta | Runs | +| --- | --- | --- | --- | --- | +| SMD | - | 85.12 | - | - | +| MSL | - | 84.18 | - | - | +| SMAP | - | 70.85 | - | - | +| SWAT | - | 92.10 | - | - | +| PSM | - | 95.21 | - | - | + +| Dataset | Ours F1(%) | Paper F1(%) | Delta | Covered | +| --- | --- | --- | --- | --- | +| Average | - | 85.49 | - | 0 | + +## Notes + +- Long-term forecasting and imputation paper values are the averages reported in the main paper tables. +- Classification and anomaly detection values in the repository logs are ratios in `[0, 1]`; this summary converts them to percentages to match the paper. +- For anomaly detection, if a dataset has multiple runs, this summary keeps the best F1 because `SWaT` is searched over several settings in the provided script. +- If a row is `-`, the corresponding experiment has not been completed yet. diff --git a/article/TimesNet.pdf b/article/TimesNet.pdf new file mode 100644 index 000000000..942086067 Binary files /dev/null and b/article/TimesNet.pdf differ diff --git a/reproduction_logs/latest b/reproduction_logs/latest new file mode 120000 index 000000000..87918db8a --- /dev/null +++ b/reproduction_logs/latest @@ -0,0 +1 @@ +/root/zm/Time-Series-Library-meter-main_success/reproduction_logs/timesnet_20260321T083250Z \ No newline at end of file diff --git a/reproduction_logs/timesnet_20260321T083059Z/metadata.txt b/reproduction_logs/timesnet_20260321T083059Z/metadata.txt new file mode 100644 index 000000000..ca5f7859c --- /dev/null +++ b/reproduction_logs/timesnet_20260321T083059Z/metadata.txt @@ -0,0 +1,7 @@ +start_utc=2026-03-21T08:30:59Z +root_dir=/root/zm/Time-Series-Library-meter-main_success +conda_env=timesnet +gpu=0 +hostname=ubuntu2204-cygtest-hgxtest + +[python] diff --git a/reproduction_logs/timesnet_20260321T083250Z/TIMESNET_REPRODUCTION_RESULTS.md b/reproduction_logs/timesnet_20260321T083250Z/TIMESNET_REPRODUCTION_RESULTS.md new file mode 100644 index 000000000..9f4b1a449 --- /dev/null +++ b/reproduction_logs/timesnet_20260321T083250Z/TIMESNET_REPRODUCTION_RESULTS.md @@ -0,0 +1,76 @@ +# TimesNet Reproduction Summary + +This file is generated by `scripts/summarize_timesnet_results.py`. + +## Long-Term Forecasting + +| Dataset | Ours MSE | Paper MSE | Delta | Ours MAE | Paper MAE | Delta | Runs | +| --- | --- | --- | --- | --- | --- | --- | --- | +| ETTm1 | - | 0.400 | - | - | 0.406 | - | - | +| ETTm2 | - | 0.291 | - | - | 0.333 | - | - | +| ETTh1 | - | 0.458 | - | - | 0.450 | - | - | +| ETTh2 | - | 0.414 | - | - | 0.427 | - | - | +| Electricity | - | 0.192 | - | - | 0.295 | - | - | +| Traffic | - | 0.620 | - | - | 0.336 | - | - | +| Weather | - | 0.259 | - | - | 0.287 | - | - | +| Exchange | - | 0.416 | - | - | 0.443 | - | - | +| ILI | - | 2.139 | - | - | 0.931 | - | - | + +## Short-Term Forecasting + +| Metric | Ours | Paper | Delta | +| --- | --- | --- | --- | +| SMAPE | - | 11.829 | - | +| MASE | - | 1.585 | - | +| OWA | - | 0.851 | - | + +## Imputation + +| Dataset | Ours MSE | Paper MSE | Delta | Ours MAE | Paper MAE | Delta | Runs | +| --- | --- | --- | --- | --- | --- | --- | --- | +| ETTm1 | - | 0.027 | - | - | 0.107 | - | - | +| ETTm2 | - | 0.022 | - | - | 0.088 | - | - | +| ETTh1 | - | 0.078 | - | - | 0.187 | - | - | +| ETTh2 | - | 0.049 | - | - | 0.146 | - | - | +| Electricity | - | 0.092 | - | - | 0.210 | - | - | +| Weather | - | 0.030 | - | - | 0.054 | - | - | + +## Classification + +| Dataset | Ours Acc(%) | Paper Acc(%) | Delta | Runs | +| --- | --- | --- | --- | --- | +| EthanolConcentration | - | 35.7 | - | - | +| FaceDetection | - | 68.6 | - | - | +| Handwriting | - | 32.1 | - | - | +| Heartbeat | - | 78.0 | - | - | +| JapaneseVowels | - | 98.4 | - | - | +| PEMS-SF | - | 89.6 | - | - | +| SelfRegulationSCP1 | - | 91.8 | - | - | +| SelfRegulationSCP2 | - | 57.2 | - | - | +| SpokenArabicDigits | - | 99.0 | - | - | +| UWaveGestureLibrary | - | 85.3 | - | - | + +| Dataset | Ours Acc(%) | Paper Acc(%) | Delta | Covered | +| --- | --- | --- | --- | --- | +| Average | - | 73.6 | - | 0 | + +## Anomaly Detection + +| Dataset | Ours F1(%) | Paper F1(%) | Delta | Runs | +| --- | --- | --- | --- | --- | +| SMD | - | 85.12 | - | - | +| MSL | - | 84.18 | - | - | +| SMAP | - | 70.85 | - | - | +| SWAT | - | 92.10 | - | - | +| PSM | - | 95.21 | - | - | + +| Dataset | Ours F1(%) | Paper F1(%) | Delta | Covered | +| --- | --- | --- | --- | --- | +| Average | - | 85.49 | - | 0 | + +## Notes + +- Long-term forecasting and imputation paper values are the averages reported in the main paper tables. +- Classification and anomaly detection values in the repository logs are ratios in `[0, 1]`; this summary converts them to percentages to match the paper. +- For anomaly detection, if a dataset has multiple runs, this summary keeps the best F1 because `SWaT` is searched over several settings in the provided script. +- If a row is `-`, the corresponding experiment has not been completed yet. diff --git a/reproduction_logs/timesnet_20260321T083250Z/metadata.txt b/reproduction_logs/timesnet_20260321T083250Z/metadata.txt new file mode 100644 index 000000000..57ef25856 --- /dev/null +++ b/reproduction_logs/timesnet_20260321T083250Z/metadata.txt @@ -0,0 +1,36 @@ +start_utc=2026-03-21T08:32:48Z +root_dir=/root/zm/Time-Series-Library-meter-main_success +conda_env=timesnet +gpu=0 +hostname=ubuntu2204-cygtest-hgxtest + +[python] +Python 3.10.19 + +[torch] +torch_version= 2.6.0+cu124 +cuda_available= True +device_count= 1 +device0= NVIDIA A800 80GB PCIe + +[nvidia-smi] +Sat Mar 21 08:32:50 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA A800 80GB PCIe Off | 00000000:00:1A.0 Off | 0 | +| N/A 55C P0 69W / 300W | 1MiB / 81920MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ diff --git a/scripts/long_term_forecast/ETT_script/TimesNet_ETTh1.ps b/scripts/long_term_forecast/ETT_script/TimesNet_ETTh1.ps new file mode 100644 index 000000000..7d8fba206 Binary files /dev/null and b/scripts/long_term_forecast/ETT_script/TimesNet_ETTh1.ps differ diff --git a/scripts/long_term_forecast/ETT_script/TimesNet_ETTh1.sh b/scripts/long_term_forecast/ETT_script/TimesNet_ETTh1.sh index 6a03e5f50..9191e9b12 100644 --- a/scripts/long_term_forecast/ETT_script/TimesNet_ETTh1.sh +++ b/scripts/long_term_forecast/ETT_script/TimesNet_ETTh1.sh @@ -1,4 +1,4 @@ -export CUDA_VISIBLE_DEVICES=2 +export CUDA_VISIBLE_DEVICES=0 model_name=TimesNet diff --git a/scripts/reproduce_timesnet.sh b/scripts/reproduce_timesnet.sh new file mode 100755 index 000000000..216a4f0ed --- /dev/null +++ b/scripts/reproduce_timesnet.sh @@ -0,0 +1,197 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +CONDA_ENV="${CONDA_ENV:-timesnet}" +GPU="${GPU:-0}" +TASK="${1:-help}" +TARGET="${2:-all}" + +task="$(printf '%s' "$TASK" | tr '[:upper:]' '[:lower:]')" +target="$(printf '%s' "$TARGET" | tr '[:upper:]' '[:lower:]')" + +usage() { + cat <<'EOF' +Usage: + bash scripts/reproduce_timesnet.sh [target] + +Tasks: + long_term [all|etth1|etth2|ettm1|ettm2|ecl|traffic|weather|exchange|ili] + short_term [all] + imputation [all|etth1|etth2|ettm1|ettm2|ecl|weather] + anomaly [all|smd|msl|smap|swat|psm] + classification [all] + all + +Environment variables: + CONDA_ENV Conda environment name. Default: timesnet + GPU CUDA_VISIBLE_DEVICES value. Default: 0 + +Examples: + bash scripts/reproduce_timesnet.sh long_term etth1 + bash scripts/reproduce_timesnet.sh imputation all + bash scripts/reproduce_timesnet.sh short_term + bash scripts/reproduce_timesnet.sh anomaly psm + bash scripts/reproduce_timesnet.sh classification + bash scripts/reproduce_timesnet.sh all +EOF +} + +run_repo_script() { + local script_path="$1" + if [[ ! -f "$script_path" ]]; then + echo "Script not found: $script_path" >&2 + exit 1 + fi + + echo + echo "==> Running $script_path" + echo " conda env: $CONDA_ENV" + echo " CUDA_VISIBLE_DEVICES: $GPU" + + if [[ "${CONDA_DEFAULT_ENV:-}" == "$CONDA_ENV" ]]; then + { + echo "set -euo pipefail" + grep -v '^export CUDA_VISIBLE_DEVICES=' "$script_path" + } | CUDA_VISIBLE_DEVICES="$GPU" bash -s + else + { + echo "set -euo pipefail" + grep -v '^export CUDA_VISIBLE_DEVICES=' "$script_path" + } | CUDA_VISIBLE_DEVICES="$GPU" conda run --no-capture-output -n "$CONDA_ENV" bash -s + fi +} + +run_long_term() { + case "$target" in + all) + run_repo_script "scripts/long_term_forecast/ETT_script/TimesNet_ETTh1.sh" + run_repo_script "scripts/long_term_forecast/ETT_script/TimesNet_ETTh2.sh" + run_repo_script "scripts/long_term_forecast/ETT_script/TimesNet_ETTm1.sh" + run_repo_script "scripts/long_term_forecast/ETT_script/TimesNet_ETTm2.sh" + run_repo_script "scripts/long_term_forecast/ECL_script/TimesNet.sh" + run_repo_script "scripts/long_term_forecast/Traffic_script/TimesNet.sh" + run_repo_script "scripts/long_term_forecast/Weather_script/TimesNet.sh" + run_repo_script "scripts/long_term_forecast/Exchange_script/TimesNet.sh" + run_repo_script "scripts/long_term_forecast/ILI_script/TimesNet.sh" + ;; + etth1) run_repo_script "scripts/long_term_forecast/ETT_script/TimesNet_ETTh1.sh" ;; + etth2) run_repo_script "scripts/long_term_forecast/ETT_script/TimesNet_ETTh2.sh" ;; + ettm1) run_repo_script "scripts/long_term_forecast/ETT_script/TimesNet_ETTm1.sh" ;; + ettm2) run_repo_script "scripts/long_term_forecast/ETT_script/TimesNet_ETTm2.sh" ;; + ecl|electricity) run_repo_script "scripts/long_term_forecast/ECL_script/TimesNet.sh" ;; + traffic) run_repo_script "scripts/long_term_forecast/Traffic_script/TimesNet.sh" ;; + weather) run_repo_script "scripts/long_term_forecast/Weather_script/TimesNet.sh" ;; + exchange) run_repo_script "scripts/long_term_forecast/Exchange_script/TimesNet.sh" ;; + ili) run_repo_script "scripts/long_term_forecast/ILI_script/TimesNet.sh" ;; + *) + echo "Unsupported long_term target: $TARGET" >&2 + usage + exit 1 + ;; + esac +} + +run_short_term() { + case "$target" in + all) run_repo_script "scripts/short_term_forecast/TimesNet_M4.sh" ;; + *) + echo "Unsupported short_term target: $TARGET" >&2 + usage + exit 1 + ;; + esac +} + +run_imputation() { + case "$target" in + all) + run_repo_script "scripts/imputation/ETT_script/TimesNet_ETTh1.sh" + run_repo_script "scripts/imputation/ETT_script/TimesNet_ETTh2.sh" + run_repo_script "scripts/imputation/ETT_script/TimesNet_ETTm1.sh" + run_repo_script "scripts/imputation/ETT_script/TimesNet_ETTm2.sh" + run_repo_script "scripts/imputation/ECL_script/TimesNet.sh" + run_repo_script "scripts/imputation/Weather_script/TimesNet.sh" + ;; + etth1) run_repo_script "scripts/imputation/ETT_script/TimesNet_ETTh1.sh" ;; + etth2) run_repo_script "scripts/imputation/ETT_script/TimesNet_ETTh2.sh" ;; + ettm1) run_repo_script "scripts/imputation/ETT_script/TimesNet_ETTm1.sh" ;; + ettm2) run_repo_script "scripts/imputation/ETT_script/TimesNet_ETTm2.sh" ;; + ecl|electricity) run_repo_script "scripts/imputation/ECL_script/TimesNet.sh" ;; + weather) run_repo_script "scripts/imputation/Weather_script/TimesNet.sh" ;; + *) + echo "Unsupported imputation target: $TARGET" >&2 + usage + exit 1 + ;; + esac +} + +run_anomaly() { + case "$target" in + all) + run_repo_script "scripts/anomaly_detection/SMD/TimesNet.sh" + run_repo_script "scripts/anomaly_detection/MSL/TimesNet.sh" + run_repo_script "scripts/anomaly_detection/SMAP/TimesNet.sh" + run_repo_script "scripts/anomaly_detection/SWAT/TimesNet.sh" + run_repo_script "scripts/anomaly_detection/PSM/TimesNet.sh" + ;; + smd) run_repo_script "scripts/anomaly_detection/SMD/TimesNet.sh" ;; + msl) run_repo_script "scripts/anomaly_detection/MSL/TimesNet.sh" ;; + smap) run_repo_script "scripts/anomaly_detection/SMAP/TimesNet.sh" ;; + swat) run_repo_script "scripts/anomaly_detection/SWAT/TimesNet.sh" ;; + psm) run_repo_script "scripts/anomaly_detection/PSM/TimesNet.sh" ;; + *) + echo "Unsupported anomaly target: $TARGET" >&2 + usage + exit 1 + ;; + esac +} + +run_classification() { + case "$target" in + all) run_repo_script "scripts/classification/TimesNet.sh" ;; + *) + echo "Unsupported classification target: $TARGET" >&2 + usage + exit 1 + ;; + esac +} + +case "$task" in + long_term|long-term) + run_long_term + ;; + short_term|short-term) + run_short_term + ;; + imputation) + run_imputation + ;; + anomaly|anomaly_detection|anomaly-detection) + run_anomaly + ;; + classification) + run_classification + ;; + all) + target="all" + run_long_term + run_short_term + run_imputation + run_classification + run_anomaly + ;; + help|-h|--help) + usage + ;; + *) + echo "Unsupported task: $TASK" >&2 + usage + exit 1 + ;; +esac diff --git a/scripts/run_timesnet_paper_all.sh b/scripts/run_timesnet_paper_all.sh new file mode 100755 index 000000000..5dc011f83 --- /dev/null +++ b/scripts/run_timesnet_paper_all.sh @@ -0,0 +1,81 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +CONDA_ENV="${CONDA_ENV:-timesnet}" +GPU="${GPU:-0}" +LOG_DIR="${1:-$ROOT_DIR/reproduction_logs/timesnet_$(date -u +%Y%m%dT%H%M%SZ)}" +CONDA_SETUP="${CONDA_SETUP:-/root/miniconda3/etc/profile.d/conda.sh}" + +mkdir -p "$LOG_DIR" +mkdir -p "$ROOT_DIR/reproduction_logs" +ln -sfn "$LOG_DIR" "$ROOT_DIR/reproduction_logs/latest" + +if [[ ! -f "$CONDA_SETUP" ]]; then + echo "Conda setup script not found: $CONDA_SETUP" >&2 + exit 1 +fi + +source "$CONDA_SETUP" +conda activate "$CONDA_ENV" + +write_metadata() { + { + echo "start_utc=$(date -u +%Y-%m-%dT%H:%M:%SZ)" + echo "root_dir=$ROOT_DIR" + echo "conda_env=$CONDA_ENV" + echo "gpu=$GPU" + echo "hostname=$(hostname)" + echo + echo "[python]" + python --version + echo + echo "[torch]" + python - <<'PY' +import torch +print('torch_version=', torch.__version__) +print('cuda_available=', torch.cuda.is_available()) +print('device_count=', torch.cuda.device_count()) +if torch.cuda.is_available(): + print('device0=', torch.cuda.get_device_name(0)) +PY + echo + echo "[nvidia-smi]" + nvidia-smi + } > "$LOG_DIR/metadata.txt" +} + +refresh_summary() { + python scripts/summarize_timesnet_results.py \ + --output "$LOG_DIR/TIMESNET_REPRODUCTION_RESULTS.md" + python scripts/summarize_timesnet_results.py \ + --output "$ROOT_DIR/TIMESNET_REPRODUCTION_RESULTS.md" +} + +run_step() { + local step_name="$1" + shift + local step_log="$LOG_DIR/${step_name}.log" + + { + echo "[$(date -u +%Y-%m-%dT%H:%M:%SZ)] START $step_name" + "$@" + echo "[$(date -u +%Y-%m-%dT%H:%M:%SZ)] END $step_name" + } 2>&1 | tee "$step_log" + + refresh_summary | tee -a "$step_log" +} + +write_metadata +refresh_summary > "$LOG_DIR/summary_refresh.log" 2>&1 || true + +run_step 01_long_term env CONDA_ENV="$CONDA_ENV" GPU="$GPU" bash scripts/reproduce_timesnet.sh long_term all +run_step 02_short_term env CONDA_ENV="$CONDA_ENV" GPU="$GPU" bash scripts/reproduce_timesnet.sh short_term +run_step 03_imputation env CONDA_ENV="$CONDA_ENV" GPU="$GPU" bash scripts/reproduce_timesnet.sh imputation all +run_step 04_classification env CONDA_ENV="$CONDA_ENV" GPU="$GPU" bash scripts/reproduce_timesnet.sh classification +run_step 05_anomaly env CONDA_ENV="$CONDA_ENV" GPU="$GPU" bash scripts/reproduce_timesnet.sh anomaly all + +refresh_summary > "$LOG_DIR/summary_refresh.log" 2>&1 +date -u +%Y-%m-%dT%H:%M:%SZ > "$LOG_DIR/completed_at.txt" diff --git a/scripts/summarize_timesnet_results.py b/scripts/summarize_timesnet_results.py new file mode 100755 index 000000000..ab53777ed --- /dev/null +++ b/scripts/summarize_timesnet_results.py @@ -0,0 +1,462 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import re +import sys +from collections import defaultdict +from pathlib import Path + +import numpy as np + +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +from utils.m4_summary import M4Summary + +LONG_TERM_TARGETS = { + "ETTm1": {"mse": 0.400, "mae": 0.406}, + "ETTm2": {"mse": 0.291, "mae": 0.333}, + "ETTh1": {"mse": 0.458, "mae": 0.450}, + "ETTh2": {"mse": 0.414, "mae": 0.427}, + "Electricity": {"mse": 0.192, "mae": 0.295}, + "Traffic": {"mse": 0.620, "mae": 0.336}, + "Weather": {"mse": 0.259, "mae": 0.287}, + "Exchange": {"mse": 0.416, "mae": 0.443}, + "ILI": {"mse": 2.139, "mae": 0.931}, +} + +IMPUTATION_TARGETS = { + "ETTm1": {"mse": 0.027, "mae": 0.107}, + "ETTm2": {"mse": 0.022, "mae": 0.088}, + "ETTh1": {"mse": 0.078, "mae": 0.187}, + "ETTh2": {"mse": 0.049, "mae": 0.146}, + "Electricity": {"mse": 0.092, "mae": 0.210}, + "Weather": {"mse": 0.030, "mae": 0.054}, +} + +SHORT_TERM_TARGETS = { + "SMAPE": 11.829, + "MASE": 1.585, + "OWA": 0.851, +} + +CLASSIFICATION_TARGETS = { + "EthanolConcentration": 35.7, + "FaceDetection": 68.6, + "Handwriting": 32.1, + "Heartbeat": 78.0, + "JapaneseVowels": 98.4, + "PEMS-SF": 89.6, + "SelfRegulationSCP1": 91.8, + "SelfRegulationSCP2": 57.2, + "SpokenArabicDigits": 99.0, + "UWaveGestureLibrary": 85.3, + "Average": 73.6, +} + +ANOMALY_TARGETS = { + "SMD": 85.12, + "MSL": 84.18, + "SMAP": 70.85, + "SWAT": 92.10, + "PSM": 95.21, + "Average": 85.49, +} + +LONG_TERM_ALIASES = { + "ettm1": "ETTm1", + "ettm2": "ETTm2", + "etth1": "ETTh1", + "etth2": "ETTh2", + "ecl": "Electricity", + "electricity": "Electricity", + "traffic": "Traffic", + "weather": "Weather", + "exchange": "Exchange", + "ili": "ILI", +} + +IMPUTATION_ALIASES = { + "ettm1": "ETTm1", + "ettm2": "ETTm2", + "etth1": "ETTh1", + "etth2": "ETTh2", + "ecl": "Electricity", + "electricity": "Electricity", + "weather": "Weather", +} + +ANOMALY_ALIASES = { + "smd": "SMD", + "msl": "MSL", + "smap": "SMAP", + "swat": "SWAT", + "psm": "PSM", +} + +LONG_TERM_RE = re.compile(r"^long_term_forecast_(?P.+?)_TimesNet_(?P[^_]+)_ft") +IMPUTATION_RE = re.compile(r"^imputation_(?P.+?)_TimesNet_(?P[^_]+)_ft") +CLASSIFICATION_RE = re.compile(r"^classification_(?P.+?)_TimesNet_(?P[^_]+)_ft") +ANOMALY_RE = re.compile(r"^anomaly_detection_(?P.+?)_TimesNet_(?P[^_]+)_ft") + + +def fmt_float(value: float, digits: int = 3) -> str: + return f"{value:.{digits}f}" + + +def fmt_delta(value: float, digits: int = 3) -> str: + return f"{value:+.{digits}f}" + + +def add_table(lines: list[str], headers: list[str], rows: list[list[str]]) -> None: + lines.append("| " + " | ".join(headers) + " |") + lines.append("| " + " | ".join(["---"] * len(headers)) + " |") + for row in rows: + lines.append("| " + " | ".join(row) + " |") + lines.append("") + + +def should_skip(name: str) -> bool: + lowered = name.lower() + return "smoke" in lowered + + +def normalize_from_model_id(model_id: str, mapping: dict[str, str]) -> str | None: + prefix = model_id.split("_")[0].lower() + return mapping.get(prefix) + + +def summarize_long_term(results_dir: Path) -> dict[str, dict[str, float | int]]: + grouped: dict[str, list[tuple[float, float]]] = defaultdict(list) + for metrics_path in results_dir.glob("*/metrics.npy"): + setting = metrics_path.parent.name + if should_skip(setting): + continue + match = LONG_TERM_RE.match(setting) + if not match: + continue + dataset = normalize_from_model_id(match.group("model_id"), LONG_TERM_ALIASES) + if dataset is None: + continue + metrics = np.load(metrics_path) + mae = float(metrics[0]) + mse = float(metrics[1]) + grouped[dataset].append((mse, mae)) + + summary: dict[str, dict[str, float | int]] = {} + for dataset, values in grouped.items(): + summary[dataset] = { + "count": len(values), + "mse": float(np.mean([item[0] for item in values])), + "mae": float(np.mean([item[1] for item in values])), + } + return summary + + +def summarize_imputation(results_dir: Path) -> dict[str, dict[str, float | int]]: + grouped: dict[str, list[tuple[float, float]]] = defaultdict(list) + for metrics_path in results_dir.glob("*/metrics.npy"): + setting = metrics_path.parent.name + if should_skip(setting): + continue + match = IMPUTATION_RE.match(setting) + if not match: + continue + dataset = normalize_from_model_id(match.group("model_id"), IMPUTATION_ALIASES) + if dataset is None: + continue + metrics = np.load(metrics_path) + mae = float(metrics[0]) + mse = float(metrics[1]) + grouped[dataset].append((mse, mae)) + + summary: dict[str, dict[str, float | int]] = {} + for dataset, values in grouped.items(): + summary[dataset] = { + "count": len(values), + "mse": float(np.mean([item[0] for item in values])), + "mae": float(np.mean([item[1] for item in values])), + } + return summary + + +def summarize_classification(results_dir: Path) -> dict[str, dict[str, float | int]]: + grouped: dict[str, list[float]] = defaultdict(list) + for result_file in results_dir.glob("*/result_classification.txt"): + setting = result_file.parent.name + if should_skip(setting): + continue + match = CLASSIFICATION_RE.match(setting) + if not match: + continue + dataset = match.group("model_id") + text = result_file.read_text() + metric_match = re.search(r"accuracy:([0-9.]+)", text) + if not metric_match: + continue + accuracy = float(metric_match.group(1)) * 100.0 + grouped[dataset].append(accuracy) + + summary: dict[str, dict[str, float | int]] = {} + for dataset, values in grouped.items(): + summary[dataset] = { + "count": len(values), + "accuracy": float(max(values)), + } + return summary + + +def summarize_anomaly(root_dir: Path) -> dict[str, dict[str, float | int]]: + result_file = root_dir / "result_anomaly_detection.txt" + grouped: dict[str, list[float]] = defaultdict(list) + if not result_file.exists(): + return {} + + current_setting: str | None = None + for raw_line in result_file.read_text().splitlines(): + line = raw_line.strip() + if not line: + continue + if line.startswith("anomaly_detection_"): + current_setting = line + continue + if current_setting is None: + continue + if "F-score" not in line: + continue + + if should_skip(current_setting): + current_setting = None + continue + + match = ANOMALY_RE.match(current_setting) + score_match = re.search(r"F-score : ([0-9.]+)", line) + if not match or not score_match: + current_setting = None + continue + + dataset = normalize_from_model_id(match.group("model_id"), ANOMALY_ALIASES) + if dataset is None: + current_setting = None + continue + + f1 = float(score_match.group(1)) * 100.0 + grouped[dataset].append(f1) + current_setting = None + + summary: dict[str, dict[str, float | int]] = {} + for dataset, values in grouped.items(): + summary[dataset] = { + "count": len(values), + "f1": float(max(values)), + } + return summary + + +def summarize_short_term(root_dir: Path) -> dict[str, float]: + m4_dir = root_dir / "m4_results" / "TimesNet" + required_files = [ + "Yearly_forecast.csv", + "Quarterly_forecast.csv", + "Monthly_forecast.csv", + "Weekly_forecast.csv", + "Daily_forecast.csv", + "Hourly_forecast.csv", + ] + if not all((m4_dir / name).exists() for name in required_files): + return {} + + summary = M4Summary(str(m4_dir) + "/", str(root_dir / "dataset" / "m4")) + smape, owa, _, mase = summary.evaluate() + return { + "SMAPE": float(smape["Average"]), + "MASE": float(mase["Average"]), + "OWA": float(owa["Average"]), + } + + +def build_markdown( + long_term: dict[str, dict[str, float | int]], + short_term: dict[str, float], + imputation: dict[str, dict[str, float | int]], + classification: dict[str, dict[str, float | int]], + anomaly: dict[str, dict[str, float | int]], +) -> str: + lines: list[str] = [] + lines.append("# TimesNet Reproduction Summary") + lines.append("") + lines.append("This file is generated by `scripts/summarize_timesnet_results.py`.") + lines.append("") + + lines.append("## Long-Term Forecasting") + lines.append("") + rows: list[list[str]] = [] + for dataset, target in LONG_TERM_TARGETS.items(): + ours = long_term.get(dataset) + if ours is None: + rows.append([dataset, "-", fmt_float(target["mse"]), "-", "-", fmt_float(target["mae"]), "-", "-"]) + continue + rows.append( + [ + dataset, + fmt_float(float(ours["mse"])), + fmt_float(target["mse"]), + fmt_delta(float(ours["mse"]) - target["mse"]), + fmt_float(float(ours["mae"])), + fmt_float(target["mae"]), + fmt_delta(float(ours["mae"]) - target["mae"]), + str(int(ours["count"])), + ] + ) + add_table(lines, ["Dataset", "Ours MSE", "Paper MSE", "Delta", "Ours MAE", "Paper MAE", "Delta", "Runs"], rows) + + lines.append("## Short-Term Forecasting") + lines.append("") + rows = [] + for metric_name, target in SHORT_TERM_TARGETS.items(): + ours = short_term.get(metric_name) + if ours is None: + rows.append([metric_name, "-", fmt_float(target), "-"]) + continue + rows.append([metric_name, fmt_float(ours), fmt_float(target), fmt_delta(ours - target)]) + add_table(lines, ["Metric", "Ours", "Paper", "Delta"], rows) + + lines.append("## Imputation") + lines.append("") + rows = [] + for dataset, target in IMPUTATION_TARGETS.items(): + ours = imputation.get(dataset) + if ours is None: + rows.append([dataset, "-", fmt_float(target["mse"]), "-", "-", fmt_float(target["mae"]), "-", "-"]) + continue + rows.append( + [ + dataset, + fmt_float(float(ours["mse"])), + fmt_float(target["mse"]), + fmt_delta(float(ours["mse"]) - target["mse"]), + fmt_float(float(ours["mae"])), + fmt_float(target["mae"]), + fmt_delta(float(ours["mae"]) - target["mae"]), + str(int(ours["count"])), + ] + ) + add_table(lines, ["Dataset", "Ours MSE", "Paper MSE", "Delta", "Ours MAE", "Paper MAE", "Delta", "Runs"], rows) + + lines.append("## Classification") + lines.append("") + rows = [] + available_accs = [] + for dataset, target in CLASSIFICATION_TARGETS.items(): + if dataset == "Average": + continue + ours = classification.get(dataset) + if ours is None: + rows.append([dataset, "-", fmt_float(target, 1), "-", "-"]) + continue + available_accs.append(float(ours["accuracy"])) + rows.append( + [ + dataset, + fmt_float(float(ours["accuracy"]), 1), + fmt_float(target, 1), + fmt_delta(float(ours["accuracy"]) - target, 1), + str(int(ours["count"])), + ] + ) + add_table(lines, ["Dataset", "Ours Acc(%)", "Paper Acc(%)", "Delta", "Runs"], rows) + + average_row = [] + if available_accs: + average_accuracy = float(np.mean(available_accs)) + average_row.append( + [ + "Average", + fmt_float(average_accuracy, 1), + fmt_float(CLASSIFICATION_TARGETS["Average"], 1), + fmt_delta(average_accuracy - CLASSIFICATION_TARGETS["Average"], 1), + str(len(available_accs)), + ] + ) + else: + average_row.append(["Average", "-", fmt_float(CLASSIFICATION_TARGETS["Average"], 1), "-", "0"]) + add_table(lines, ["Dataset", "Ours Acc(%)", "Paper Acc(%)", "Delta", "Covered"], average_row) + + lines.append("## Anomaly Detection") + lines.append("") + rows = [] + available_f1 = [] + for dataset, target in ANOMALY_TARGETS.items(): + if dataset == "Average": + continue + ours = anomaly.get(dataset) + if ours is None: + rows.append([dataset, "-", fmt_float(target, 2), "-", "-"]) + continue + available_f1.append(float(ours["f1"])) + rows.append( + [ + dataset, + fmt_float(float(ours["f1"]), 2), + fmt_float(target, 2), + fmt_delta(float(ours["f1"]) - target, 2), + str(int(ours["count"])), + ] + ) + add_table(lines, ["Dataset", "Ours F1(%)", "Paper F1(%)", "Delta", "Runs"], rows) + + average_row = [] + if available_f1: + average_f1 = float(np.mean(available_f1)) + average_row.append( + [ + "Average", + fmt_float(average_f1, 2), + fmt_float(ANOMALY_TARGETS["Average"], 2), + fmt_delta(average_f1 - ANOMALY_TARGETS["Average"], 2), + str(len(available_f1)), + ] + ) + else: + average_row.append(["Average", "-", fmt_float(ANOMALY_TARGETS["Average"], 2), "-", "0"]) + add_table(lines, ["Dataset", "Ours F1(%)", "Paper F1(%)", "Delta", "Covered"], average_row) + + lines.append("## Notes") + lines.append("") + lines.append("- Long-term forecasting and imputation paper values are the averages reported in the main paper tables.") + lines.append("- Classification and anomaly detection values in the repository logs are ratios in `[0, 1]`; this summary converts them to percentages to match the paper.") + lines.append("- For anomaly detection, if a dataset has multiple runs, this summary keeps the best F1 because `SWaT` is searched over several settings in the provided script.") + lines.append("- If a row is `-`, the corresponding experiment has not been completed yet.") + lines.append("") + return "\n".join(lines) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Summarize TimesNet reproduction results and compare with the paper.") + parser.add_argument("--root", type=Path, default=ROOT, help="Repository root.") + parser.add_argument( + "--output", + type=Path, + default=ROOT / "TIMESNET_REPRODUCTION_RESULTS.md", + help="Output markdown file.", + ) + args = parser.parse_args() + + root_dir = args.root.resolve() + results_dir = root_dir / "results" + + long_term = summarize_long_term(results_dir) + short_term = summarize_short_term(root_dir) + imputation = summarize_imputation(results_dir) + classification = summarize_classification(results_dir) + anomaly = summarize_anomaly(root_dir) + + markdown = build_markdown(long_term, short_term, imputation, classification, anomaly) + args.output.write_text(markdown) + print(f"Wrote summary to {args.output}") + + +if __name__ == "__main__": + main()