Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions app/src/api/digitalTwin.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { api } from './client';

export type Snapshot = {
snapshot_id: number;
total_expenses: number;
total_income: number;
total_bills: number;
net_worth: number;
created_at: string;
};

export type SimulationResult = {
snapshot_id: number;
total_income: number;
total_expenses: number;
total_bills: number;
net_worth: number;
adjustments: Record<string, number>;
};

export async function createSnapshot(): Promise<Snapshot> {
return api<Snapshot>('/digital-twin/create', { method: 'POST' });
}

export async function listSnapshots(): Promise<Snapshot[]> {
return api<Snapshot[]>('/digital-twin/snapshots');
}

export async function simulate(
snapshotId: number,
adjustments: Record<string, number>,
): Promise<SimulationResult> {
return api<SimulationResult>('/digital-twin/simulate', {
method: 'POST',
body: { snapshot_id: snapshotId, adjustments },
});
}
2 changes: 2 additions & 0 deletions packages/backend/app/routes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .categories import bp as categories_bp
from .docs import bp as docs_bp
from .dashboard import bp as dashboard_bp
from .digital_twin import bp as digital_twin_bp


def register_routes(app: Flask):
Expand All @@ -18,3 +19,4 @@ def register_routes(app: Flask):
app.register_blueprint(categories_bp, url_prefix="/categories")
app.register_blueprint(docs_bp, url_prefix="/docs")
app.register_blueprint(dashboard_bp, url_prefix="/dashboard")
app.register_blueprint(digital_twin_bp, url_prefix="/digital-twin")
103 changes: 103 additions & 0 deletions packages/backend/app/routes/digital_twin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import json
import logging
from datetime import datetime
from flask import Blueprint, jsonify, request
from flask_jwt_extended import jwt_required, get_jwt_identity
from sqlalchemy import func
from ..extensions import db
from ..models import AuditLog, Expense, Bill

bp = Blueprint("digital_twin", __name__)
logger = logging.getLogger("finmind.digital_twin")


@bp.post("/create")
@jwt_required()
def create_snapshot():
uid = int(get_jwt_identity())

total_expenses = float(
db.session.query(func.coalesce(func.sum(Expense.amount), 0))
.filter(Expense.user_id == uid, Expense.expense_type != "INCOME")
.scalar()
)
total_income = float(
db.session.query(func.coalesce(func.sum(Expense.amount), 0))
.filter(Expense.user_id == uid, Expense.expense_type == "INCOME")
.scalar()
)
total_bills = float(
db.session.query(func.coalesce(func.sum(Bill.amount), 0))
.filter(Bill.user_id == uid, Bill.active.is_(True))
.scalar()
)

snapshot = {
"total_expenses": total_expenses,
"total_income": total_income,
"total_bills": total_bills,
"net_worth": round(total_income - total_expenses, 2),
"created_at": datetime.utcnow().isoformat(),
}
log = AuditLog(user_id=uid, action=f"digital_twin_snapshot:{json.dumps(snapshot)}")
db.session.add(log)
db.session.commit()
snapshot["snapshot_id"] = log.id
logger.info("Created digital twin snapshot user=%s", uid)
return jsonify(snapshot), 201


@bp.get("/snapshots")
@jwt_required()
def list_snapshots():
uid = int(get_jwt_identity())
rows = (
db.session.query(AuditLog)
.filter(AuditLog.user_id == uid, AuditLog.action.like("digital_twin_snapshot:%"))
.order_by(AuditLog.created_at.desc())
.all()
)
snapshots = []
for r in rows:
try:
data = json.loads(r.action.split(":", 1)[1])
data["snapshot_id"] = r.id
snapshots.append(data)
except (json.JSONDecodeError, IndexError):
continue
return jsonify(snapshots)


@bp.post("/simulate")
@jwt_required()
def simulate():
uid = int(get_jwt_identity())
data = request.get_json() or {}
snapshot_id = data.get("snapshot_id")
adjustments = data.get("adjustments", {})
if not snapshot_id:
return jsonify(error="snapshot_id required"), 400

log = db.session.query(AuditLog).filter(
AuditLog.id == snapshot_id,
AuditLog.user_id == uid,
AuditLog.action.like("digital_twin_snapshot:%"),
).first()
if not log:
return jsonify(error="snapshot not found"), 404

try:
snapshot = json.loads(log.action.split(":", 1)[1])
except (json.JSONDecodeError, IndexError):
return jsonify(error="corrupt snapshot"), 500

projected = {
"total_income": snapshot["total_income"] + adjustments.get("income_change", 0),
"total_expenses": snapshot["total_expenses"] + adjustments.get("expense_change", 0),
"total_bills": snapshot["total_bills"] + adjustments.get("bills_change", 0),
}
projected["net_worth"] = round(projected["total_income"] - projected["total_expenses"], 2)
projected["snapshot_id"] = snapshot_id
projected["adjustments"] = adjustments
logger.info("Simulated digital twin user=%s snapshot=%s", uid, snapshot_id)
return jsonify(projected)
76 changes: 76 additions & 0 deletions packages/backend/tests/test_digital_twin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from datetime import date


def test_digital_twin_requires_auth(client):
r = client.post("/digital-twin/create")
assert r.status_code in (401, 422)

r = client.get("/digital-twin/snapshots")
assert r.status_code in (401, 422)

r = client.post("/digital-twin/simulate", json={"snapshot_id": 1})
assert r.status_code in (401, 422)


def test_create_snapshot(client, auth_header):
r = client.post("/digital-twin/create", headers=auth_header)
assert r.status_code == 201
data = r.get_json()
assert "total_expenses" in data
assert "total_income" in data
assert "total_bills" in data
assert "net_worth" in data
assert "snapshot_id" in data


def test_list_snapshots(client, auth_header):
client.post("/digital-twin/create", headers=auth_header)
client.post("/digital-twin/create", headers=auth_header)

r = client.get("/digital-twin/snapshots", headers=auth_header)
assert r.status_code == 200
snapshots = r.get_json()
assert len(snapshots) >= 2
assert all("snapshot_id" in s for s in snapshots)


def test_simulate(client, auth_header):
# Seed some data
client.post(
"/expenses",
json={
"amount": 1000,
"description": "Salary",
"date": date.today().isoformat(),
"expense_type": "INCOME",
},
headers=auth_header,
)
client.post(
"/expenses",
json={
"amount": 200,
"description": "Groceries",
"date": date.today().isoformat(),
"expense_type": "EXPENSE",
},
headers=auth_header,
)

r = client.post("/digital-twin/create", headers=auth_header)
assert r.status_code == 201
snapshot_id = r.get_json()["snapshot_id"]

r = client.post(
"/digital-twin/simulate",
json={
"snapshot_id": snapshot_id,
"adjustments": {"income_change": 500, "expense_change": -100},
},
headers=auth_header,
)
assert r.status_code == 200
result = r.get_json()
assert result["total_income"] == 1500
assert result["total_expenses"] == 100
assert result["net_worth"] == 1400