diff --git a/.gitignore b/.gitignore index 1dbe456..1389767 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,14 @@ xcuserdata/ report.junit report.html *.xcresult + +# Local ML artifacts +checkpoints/ +data/ +HealthyLLM/Supporting Files/LocalLLM/model.safetensors +HealthyLLM/Supporting Files/OpenTSLM/sleep_cot.csv +HealthyLLM/Supporting Files/OpenTSLM/ecg_qa_cot_test.csv +HealthyLLM/Supporting Files/OpenTSLM/ecg_qa_waveforms/ +HealthyLLM/Supporting Files/OpenTSLM/ecg_qa_template_answers.json +HealthyLLM/Supporting Files/OpenTSLM/ecg_qa_waveforms.json +jsons/model.safetensors diff --git a/HealthyLLM.xcodeproj/project.pbxproj b/HealthyLLM.xcodeproj/project.pbxproj index d554b4f..8ca439a 100644 --- a/HealthyLLM.xcodeproj/project.pbxproj +++ b/HealthyLLM.xcodeproj/project.pbxproj @@ -544,7 +544,7 @@ CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; DEVELOPMENT_ASSET_PATHS = "\"HealthyLLM/Supporting Files/Preview Content\""; - DEVELOPMENT_TEAM = CQRZ4E7K9U; + DEVELOPMENT_TEAM = V5G338H5V6; ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; INFOPLIST_KEY_ITSAppUsesNonExemptEncryption = NO; @@ -586,7 +586,7 @@ CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; DEVELOPMENT_ASSET_PATHS = "\"HealthyLLM/Supporting Files/Preview Content\""; - DEVELOPMENT_TEAM = CQRZ4E7K9U; + DEVELOPMENT_TEAM = V5G338H5V6; ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; INFOPLIST_KEY_ITSAppUsesNonExemptEncryption = NO; @@ -628,7 +628,7 @@ CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; DEVELOPMENT_ASSET_PATHS = "\"HealthyLLMStudy/Supporting Files/Preview Content\""; - DEVELOPMENT_TEAM = CQRZ4E7K9U; + DEVELOPMENT_TEAM = C496LC49DH; ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; INFOPLIST_FILE = "HealthyLLMStudy/Supporting Files/Info.plist"; @@ -653,7 +653,7 @@ "@executable_path/Frameworks", ); MARKETING_VERSION = 1.0; - PRODUCT_BUNDLE_IDENTIFIER = edu.maxrosenblattl.bdhg.healthyllm; + PRODUCT_BUNDLE_IDENTIFIER = edu.maxrosenblattl.bdhg.healthyllm.co; PRODUCT_NAME = "$(TARGET_NAME)"; SUPPORTED_PLATFORMS = "iphoneos iphonesimulator"; SUPPORTS_MACCATALYST = NO; @@ -675,7 +675,7 @@ CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; DEVELOPMENT_ASSET_PATHS = "\"HealthyLLMStudy/Supporting Files/Preview Content\""; - DEVELOPMENT_TEAM = CQRZ4E7K9U; + DEVELOPMENT_TEAM = C496LC49DH; ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; INFOPLIST_FILE = "HealthyLLMStudy/Supporting Files/Info.plist"; @@ -700,7 +700,7 @@ "@executable_path/Frameworks", ); MARKETING_VERSION = 1.0; - PRODUCT_BUNDLE_IDENTIFIER = edu.maxrosenblattl.bdhg.healthyllm; + PRODUCT_BUNDLE_IDENTIFIER = edu.maxrosenblattl.bdhg.healthyllm.co; PRODUCT_NAME = "$(TARGET_NAME)"; PROVISIONING_PROFILE_SPECIFIER = ""; SUPPORTED_PLATFORMS = "iphoneos iphonesimulator"; diff --git a/HealthyLLM.xcodeproj/xcshareddata/xcschemes/HealthyLLM.xcscheme b/HealthyLLM.xcodeproj/xcshareddata/xcschemes/HealthyLLM.xcscheme index 739a213..23a7000 100644 --- a/HealthyLLM.xcodeproj/xcshareddata/xcschemes/HealthyLLM.xcscheme +++ b/HealthyLLM.xcodeproj/xcshareddata/xcschemes/HealthyLLM.xcscheme @@ -79,6 +79,38 @@ ReferencedContainer = "container:HealthyLLM.xcodeproj"> + + + + + + + + + + + + + + [ElectrocardiogramData] { + let samples = try await fetchElectrocardiogramSamples(healthKit, limit: limit) + + var records: [ElectrocardiogramData] = [] + for sample in samples { + let voltages = try await fetchVoltageMeasurements(for: sample, healthKit: healthKit) + records.append( + .init( + startDate: sample.startDate, + endDate: sample.endDate, + classification: String(describing: sample.classification), + symptomsStatus: String(describing: sample.symptomsStatus), + averageHeartRate: sample.averageHeartRate?.doubleValue( + for: .count().unitDivided(by: .minute()) + ), + samplingFrequency: sample.samplingFrequency?.doubleValue(for: .hertz()), + numberOfVoltageMeasurements: voltages.count, + voltages: voltages + ) + ) + } + + return records + } + + private func fetchElectrocardiogramSamples( + _ healthKit: HealthKit, + limit: Int + ) async throws -> [HKElectrocardiogram] { + try await withCheckedThrowingContinuation { continuation in + let query = HKSampleQuery( + sampleType: HKObjectType.electrocardiogramType(), + predicate: nil, + limit: limit, + sortDescriptors: [NSSortDescriptor(key: HKSampleSortIdentifierStartDate, ascending: false)] + ) { _, samples, error in + if let error { + continuation.resume(throwing: error) + return + } + + continuation.resume(returning: samples as? [HKElectrocardiogram] ?? []) + } + + healthKit.healthStore.execute(query) + } + } + + private func fetchVoltageMeasurements( + for sample: HKElectrocardiogram, + healthKit: HealthKit + ) async throws -> [Double] { + try await withCheckedThrowingContinuation { continuation in + var voltages: [Double] = [] + var didResume = false + + let query = HKElectrocardiogramQuery(electrocardiogram: sample) { _, measurement, done, error in + if let error, !didResume { + didResume = true + continuation.resume(throwing: error) + return + } + + if let measurement, + let voltageQuantity = measurement.quantity(for: .appleWatchSimilarToLeadI) { + voltages.append(voltageQuantity.doubleValue(for: .volt())) + } + + if done, !didResume { + didResume = true + continuation.resume(returning: voltages) + } + } + + healthKit.healthStore.execute(query) + } + } +} \ No newline at end of file diff --git a/HealthyLLM/HealthyLLM/Fetcher/HealthDataFetcher.swift b/HealthyLLM/HealthyLLM/Fetcher/HealthDataFetcher.swift index ebcfc18..8f88799 100644 --- a/HealthyLLM/HealthyLLM/Fetcher/HealthDataFetcher.swift +++ b/HealthyLLM/HealthyLLM/Fetcher/HealthDataFetcher.swift @@ -17,7 +17,8 @@ class HealthDataFetcher: DefaultInitializable, Module, EnvironmentAccessible { let readTypes = Set([ HKSeriesType.activitySummaryType(), HKSeriesType.workoutRoute(), - HKSeriesType.workoutType() + HKSeriesType.workoutType(), + HKObjectType.electrocardiogramType() ]).union( Set(allHKQuantityTypeIdentifiers().map { HKQuantityType($0) }) ) diff --git a/HealthyLLM/HealthyLLM/HealthContextGenerator.swift b/HealthyLLM/HealthyLLM/HealthContextGenerator.swift index feca36a..149152e 100644 --- a/HealthyLLM/HealthyLLM/HealthContextGenerator.swift +++ b/HealthyLLM/HealthyLLM/HealthContextGenerator.swift @@ -11,4 +11,92 @@ import Spezi class HealthContextGenerator: DefaultInitializable, Module, EnvironmentAccessible { required init() { } + + func buildSystemPrompt(userInfo: UserInfo?, electrocardiograms: [ElectrocardiogramData]) -> HealthyLLMContextEntity { + var prompt = """ + You are a health assistant. Provide a careful, non-diagnostic interpretation of the user's health data. + If ECG data is present, focus on signal quality, rhythm regularity, heart rate, and any notable abnormalities. + Do not claim to diagnose a condition. + """ + + if let userInfo { + prompt += "\n\nUser profile:\n" + prompt += userInfo.asJSONRepresentation(.prettyPrinted) ?? "No Data" + } + + if !electrocardiograms.isEmpty { + prompt += "\n\nLatest ECG samples:\n" + prompt += electrocardiograms.enumerated().map { index, sample in + formatECGSample(sample, index: index + 1) + }.joined(separator: "\n\n") + } else { + prompt += "\n\nLatest ECG samples: No Data" + } + + prompt += "\n\nRespond with a concise clinical-style summary and a safety note if the tracing looks concerning." + + return .init(.system, content: prompt) + } + + private func formatECGSample(_ sample: ElectrocardiogramData, index: Int) -> String { + let voltages = sample.voltages + let average = voltages.isEmpty ? nil : voltages.reduce(0, +) / Double(voltages.count) + let minimum = voltages.min() + let maximum = voltages.max() + let preview = voltages.prefix(24).map { String(format: "%.4f", $0) }.joined(separator: ", ") + + let normalizedVoltages = zNormalize(voltages) + let normalizedPreview = normalizedVoltages.prefix(48).map { String(format: "%.6f", $0) }.joined(separator: ", ") + + let averageHeartRateText = sample.averageHeartRate.map(String.init(describing:)) ?? "No Data" + let samplingFrequencyText = sample.samplingFrequency.map(String.init(describing:)) ?? "No Data" + let voltageMeanText = average.map(String.init(describing:)) ?? "No Data" + let voltageMinText = minimum.map(String.init(describing:)) ?? "No Data" + let voltageMaxText = maximum.map(String.init(describing:)) ?? "No Data" + + let sleepCotStylePrompt = """ + sleep_cot_style_sample: + pre_prompt: You are given a short single-lead ECG time series segment. Analyze rhythm, signal quality, and notable concerns conservatively. + time_series_text: + - The following is the ECG time series with mean \(String(format: "%.6f", average ?? 0)) and min/max \(String(format: "%.6f", minimum ?? 0))/\(String(format: "%.6f", maximum ?? 0)). + time_series_normalized_preview: [\(normalizedPreview)] + post_prompt: First summarize waveform quality and rhythm regularity, then provide brief safety guidance and when to seek care. + """ + + var lines: [String] = [ + "ECG sample #\(index)", + "start_date: \(sample.startDate.formatted(date: .abbreviated, time: .shortened))", + "end_date: \(sample.endDate.formatted(date: .abbreviated, time: .shortened))", + "classification: \(sample.classification)", + "symptoms_status: \(sample.symptomsStatus)", + "average_heart_rate: \(averageHeartRateText)", + "sampling_frequency_hz: \(samplingFrequencyText)", + "voltage_count: \(sample.numberOfVoltageMeasurements)", + "voltage_mean: \(voltageMeanText)", + "voltage_min: \(voltageMinText)", + "voltage_max: \(voltageMaxText)" + ] + + if !preview.isEmpty { + lines.append("voltage_preview: [\(preview)]") + } + + lines.append(sleepCotStylePrompt) + + return lines.joined(separator: "\n") + } + + private func zNormalize(_ values: [Double]) -> [Double] { + guard !values.isEmpty else { + return [] + } + + let mean = values.reduce(0, +) / Double(values.count) + let variance = values.reduce(0) { partial, value in + let delta = value - mean + return partial + delta * delta + } / Double(values.count) + let std = max(sqrt(variance), 1e-6) + return values.map { ($0 - mean) / std } + } } diff --git a/HealthyLLM/HealthyLLM/HealthDataInterpreter.swift b/HealthyLLM/HealthyLLM/HealthDataInterpreter.swift index 9261ef0..504280f 100644 --- a/HealthyLLM/HealthyLLM/HealthDataInterpreter.swift +++ b/HealthyLLM/HealthyLLM/HealthDataInterpreter.swift @@ -8,27 +8,53 @@ import Foundation import HealthKit +import Hub import OSLog import Spezi import SpeziChat import SpeziHealthKit import SpeziHealthKitUI +import MLX +import MLXLMCommon +import MLXLLM import SpeziLLM import SpeziLLMLocal @Observable class HealthDataInterpreter: DefaultInitializable, Module, EnvironmentAccessible { @ObservationIgnored private let logger = Logger(subsystem: "HealthyLLM", category: "HealthDataInterpreter") - + + enum LoadingStage: String { + case idle = "Idle" + case stagingModel = "Staging local model files" + case configuringParameters = "Configuring LLM parameters" + case creatingSession = "Creating LLM session" + case openingSession = "Opening LLM session (loading weights)" + case ready = "Ready" + case failed = "Failed" + } + private(set) var loaded = false + private(set) var loadingStage: LoadingStage = .idle + private(set) var loadingDetail: String = "" @ObservationIgnored @Dependency(LLMRunner.self) private var llmRunner: LLMRunner @ObservationIgnored @Dependency(HealthDataFetcher.self) private var healthDataFetcher: HealthDataFetcher + @ObservationIgnored @Dependency(HealthContextGenerator.self) private var healthContextGenerator: HealthContextGenerator + @ObservationIgnored @Dependency(OpenTSLMInferenceService.self) private var openTSLMInferenceService: OpenTSLMInferenceService @ObservationIgnored private var functionCallParameters: LLMLocalParameters? @ObservationIgnored private var functionCallSamplingParameters: LLMLocalSamplingParameters? @ObservationIgnored private var defaultParameters: LLMLocalParameters? + @ObservationIgnored private var defaultSamplingParameters: LLMLocalSamplingParameters? @ObservationIgnored private var sharedSession: LLMLocalSession? + @ObservationIgnored private let requiredModelFiles = [ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "model.safetensors" + ] private(set) var context: HealthyLLMContext = [] private(set) var advancedContext: HealthyLLMContext = [] @@ -36,37 +62,461 @@ class HealthDataInterpreter: DefaultInitializable, Module, EnvironmentAccessible required init() { } func setup() async throws { + logger.info("setup(): starting. modelID=\(Constants.llmModelName, privacy: .public) destination=\(Constants.llmLocalModelDirectory.path, privacy: .public)") + + // Make the MLX factory build our embedding-capable Llama for "llama"/"mistral" + // model types, so the session's model can be primed with OpenTSLM soft-prompt + // embeddings. Idempotent; must run before any LLMModelFactory.shared.loadContainer. + EmbeddingLlamaModelRegistration.register() + + await MainActor.run { + loadingStage = .stagingModel + loadingDetail = Constants.llmModelName + } + + if Constants.skipLLMLoad { + logger.info("setup(): skipping Llama staging/load (HEALTHYLLM_SKIP_LLM_LOAD=1)") + await MainActor.run { + loaded = true + loadingStage = .ready + loadingDetail = "Encoder-only (Llama skipped)" + } + return + } + + do { + try await stageLocalModelIfNeeded() + try removeNonBaseWeightSafetensorsFromModelDirectory() + } catch { + logger.error("setup(): stageLocalModelIfNeeded threw: \(error.localizedDescription, privacy: .public)") + await MainActor.run { + loadingStage = .failed + loadingDetail = "Staging failed: \(error.localizedDescription)" + } + throw error + } + + await MainActor.run { + loadingStage = .configuringParameters + } + logger.info("setup(): staging complete; configuring parameters") + + let chatTemplate: String? = Constants.useCustomChatTemplate ? Constants.llmModelChatTemplate : nil + logger.info("setup(): chatTemplate=\(chatTemplate == nil ? "tokenizer-default" : "custom-jinja", privacy: .public)") + functionCallParameters = .init( maxOutputLength: 32, - chatTemplate: Constants.llmModelChatTemplate + chatTemplate: chatTemplate ) functionCallSamplingParameters = .init( topP: 1.0, temperature: 0.001, penaltyRepeat: 1.3 ) - + defaultParameters = .init( - maxOutputLength: 1024, - chatTemplate: Constants.llmModelChatTemplate + maxOutputLength: Constants.llmDefaultMaxOutputLength, + chatTemplate: chatTemplate + ) + defaultSamplingParameters = .init( + topP: 1.0, + temperature: 0.7, + penaltyRepeat: 1.2 ) guard let defaultParameters else { + logger.error("setup(): defaultParameters unexpectedly nil after assignment") + await MainActor.run { + loadingStage = .failed + loadingDetail = "defaultParameters nil" + } return } - + + await MainActor.run { + loadingStage = .creatingSession + } + logger.info("setup(): creating LLMLocalSchema and session") + let schema = LLMLocalSchema( model: .custom(id: Constants.llmModelName), parameters: defaultParameters, + samplingParameters: defaultSamplingParameters ?? .init(), injectIntoContext: true ) - + sharedSession = llmRunner.callAsFunction(with: schema) guard let sharedSession else { + logger.error("setup(): llmRunner.callAsFunction returned nil session") + await MainActor.run { + loadingStage = .failed + loadingDetail = "llmRunner returned nil session" + } + return + } + + await MainActor.run { + loadingStage = .openingSession + loadingDetail = "Loading weights — this can take a while on first launch" + } + let setupStart = Date() + let fileManager = FileManager.default + let modelDirectory = Constants.llmLocalModelDirectory + + if hasRequiredModelFiles(in: modelDirectory, fileManager: fileManager) { + logger.info("setup(): loading MLX container from local directory (bf16 base weights only)") + do { + let container = try await LLMModelFactory.shared.loadContainer( + configuration: ModelConfiguration(directory: modelDirectory) + ) + await MainActor.run { + sharedSession.modelContainer = container + sharedSession.state = .ready + } + let setupDuration = Date().timeIntervalSince(setupStart) + logger.info("setup(): direct loadContainer succeeded in \(setupDuration, privacy: .public)s") + await MainActor.run { + loaded = true + loadingStage = .ready + loadingDetail = String(format: "Loaded in %.1fs", setupDuration) + } + return + } catch { + logger.error("setup(): direct loadContainer failed: \(String(reflecting: error), privacy: .public)") + } + } + + logger.info("setup(): calling sharedSession.setup() — Hub snapshot fallback") + + do { + try await sharedSession.setup() + } catch { + logger.error("setup(): sharedSession.setup() threw after \(Date().timeIntervalSince(setupStart), privacy: .public)s: \(error.localizedDescription, privacy: .public)") + await MainActor.run { + loadingStage = .failed + loadingDetail = "Session setup failed: \(error.localizedDescription)" + } + throw HealthDataInterpreterError.modelNotLoaded + } + + let setupDuration = Date().timeIntervalSince(setupStart) + logger.info("setup(): sharedSession.setup() completed in \(setupDuration, privacy: .public)s") + + await MainActor.run { + loaded = true + loadingStage = .ready + loadingDetail = String(format: "Loaded in %.1fs", setupDuration) + } + } + + private func stageLocalModelIfNeeded() async throws { + let fileManager = FileManager.default + let destinationURL = Constants.llmLocalModelDirectory + logger.info("stageLocalModelIfNeeded: destination=\(destinationURL.path, privacy: .public)") + + do { + try fileManager.createDirectory(at: destinationURL, withIntermediateDirectories: true) + } catch { + logger.error("Failed creating local model destination directory: \(error.localizedDescription, privacy: .public)") + return + } + + try stageOpenTSLMLoRACheckpointIfNeeded() + + if hasRequiredModelFiles(in: destinationURL, fileManager: fileManager) { + logger.info("stageLocalModelIfNeeded: destination already has all required files (\(self.requiredModelFiles.joined(separator: ", "), privacy: .public)) — skipping copy") + try removeNonBaseWeightSafetensorsFromModelDirectory() + return + } + + let missing = requiredModelFiles.filter { fileName in + !fileManager.fileExists(atPath: destinationURL.appendingPathComponent(fileName).path) + } + logger.info("stageLocalModelIfNeeded: destination is missing files: \(missing.joined(separator: ", "), privacy: .public)") + + guard let sourceURL = resolveLocalModelSourceDirectory() else { + logger.warning("stageLocalModelIfNeeded: no local model source directory found. The app will rely on the download flow / Hub cache. Required files still missing at destination.") + return + } + logger.info("stageLocalModelIfNeeded: copying from \(sourceURL.path, privacy: .public)") + + do { + try copyRequiredModelFiles(from: sourceURL, to: destinationURL) + try removeNonBaseWeightSafetensorsFromModelDirectory() + + if hasRequiredModelFiles(in: destinationURL, fileManager: fileManager) { + logger.info("stageLocalModelIfNeeded: staged local model from \(sourceURL.path, privacy: .public) to \(destinationURL.path, privacy: .public)") + } else { + let stillMissing = requiredModelFiles.filter { fileName in + !fileManager.fileExists(atPath: destinationURL.appendingPathComponent(fileName).path) + } + logger.error("stageLocalModelIfNeeded: staging finished but required files still missing: \(stillMissing.joined(separator: ", "), privacy: .public)") + } + } catch { + logger.error("stageLocalModelIfNeeded: failed: \(error.localizedDescription, privacy: .public)") + throw error + } + } + + /// Stage LoRA under ``Constants/openTSLMDocumentsDirectory`` — never into the Llama HF folder (MLX would load it as base weights). + private func stageOpenTSLMLoRACheckpointIfNeeded() throws { + let fileManager = FileManager.default + let destinationDirectory = Constants.openTSLMDocumentsDirectory + let destinationLoRA = destinationDirectory + .appendingPathComponent("\(Constants.openTSLMLoRACheckpointName).safetensors") + + if fileManager.fileExists(atPath: destinationLoRA.path) { + return + } + + guard let sourceLoRA = resolveLoRACheckpointSource() else { + if Constants.requireLoRACheckpoint { + throw NSError( + domain: "HealthDataInterpreter", + code: 10, + userInfo: [NSLocalizedDescriptionKey: "HEALTHYLLM_REQUIRE_LORA=1 but no LoRA checkpoint was found. Set HEALTHYLLM_OPEN_TSLM_LORA_CHECKPOINT or bundle \(Constants.openTSLMLoRACheckpointName).safetensors under OpenTSLM/."] + ) + } + return + } + + try fileManager.createDirectory(at: destinationDirectory, withIntermediateDirectories: true) + try fileManager.copyItem(at: sourceLoRA, to: destinationLoRA) + logger.info("LoRA checkpoint staged at \(destinationLoRA.path, privacy: .public)") + } + + private func resolveLoRACheckpointSource() -> URL? { + let fileManager = FileManager.default + var candidates: [URL] = [] + + if !Constants.openTSLMLoRACheckpointPath.isEmpty { + candidates.append(URL(fileURLWithPath: Constants.openTSLMLoRACheckpointPath)) + } + + if let bundled = Bundle.main.url( + forResource: Constants.openTSLMLoRACheckpointName, + withExtension: "safetensors", + subdirectory: Constants.openTSLMBundleSubdirectory + ) { + candidates.append(bundled) + } + + if let bundleRoot = Bundle.main.resourceURL { + candidates.append( + bundleRoot + .appendingPathComponent(Constants.openTSLMBundleSubdirectory, isDirectory: true) + .appendingPathComponent("\(Constants.openTSLMLoRACheckpointName).safetensors") + ) + } + + return candidates.first { fileManager.fileExists(atPath: $0.path) } + } + + /// Remove adapter / OpenTSLM safetensors from the Llama model directory so ``loadContainer`` only sees base weights. + private func removeNonBaseWeightSafetensorsFromModelDirectory() throws { + let fileManager = FileManager.default + let modelDirectory = Constants.llmLocalModelDirectory + guard let items = try? fileManager.contentsOfDirectory(at: modelDirectory, includingPropertiesForKeys: nil) else { return } - try await sharedSession.setup() - loaded = true + for item in items where item.pathExtension == "safetensors" { + let name = item.lastPathComponent + let isBaseWeight = name == "model.safetensors" + || (name.hasPrefix("model-") && name.hasSuffix(".safetensors")) + guard !isBaseWeight else { + continue + } + try fileManager.removeItem(at: item) + logger.info("Removed non-base safetensors from model dir: \(name, privacy: .public)") + } + } + + /// Copy only Llama base checkpoint files — never OpenTSLM encoder/projector/LoRA weights. + private func copyRequiredModelFiles(from sourceURL: URL, to destinationURL: URL) throws { + let fileManager = FileManager.default + + for fileName in requiredModelFiles { + let sourceFile = sourceURL.appendingPathComponent(fileName) + let destinationFile = destinationURL.appendingPathComponent(fileName) + guard fileManager.fileExists(atPath: sourceFile.path) else { + continue + } + if fileManager.fileExists(atPath: destinationFile.path) { + continue + } + try fileManager.copyItem(at: sourceFile, to: destinationFile) + } + + let sourceItems = try fileManager.contentsOfDirectory(at: sourceURL, includingPropertiesForKeys: nil) + for item in sourceItems where item.pathExtension == "safetensors" { + let name = item.lastPathComponent + let isBaseWeight = name == "model.safetensors" + || (name.hasPrefix("model-") && name.hasSuffix(".safetensors")) + guard isBaseWeight else { + continue + } + let destinationFile = destinationURL.appendingPathComponent(name) + if !fileManager.fileExists(atPath: destinationFile.path) { + try fileManager.copyItem(at: item, to: destinationFile) + } + } + } + + private func hasRequiredModelFiles(in directoryURL: URL, fileManager: FileManager) -> Bool { + requiredModelFiles.allSatisfy { fileName in + fileManager.fileExists(atPath: directoryURL.appendingPathComponent(fileName).path) + } + } + + private func resolveLocalModelSourceDirectory() -> URL? { + let fileManager = FileManager.default + + if let overridePath = Constants.localModelSourcePathOverride, + let overrideURL = existingDirectoryURL(at: overridePath, fileManager: fileManager) { + return overrideURL + } + + if let bundledLocalModelURL = Bundle.main.resourceURL { + let bundledDirectory = bundledLocalModelURL.appendingPathComponent(Constants.localModelBundleSubdirectory, isDirectory: true) + if fileManager.fileExists(atPath: bundledDirectory.path) { + return bundledDirectory + } + + let bundledRootFiles = requiredModelFiles.allSatisfy { fileName in + fileManager.fileExists(atPath: bundledLocalModelURL.appendingPathComponent(fileName).path) + } + + if bundledRootFiles { + return bundledLocalModelURL + } + } + + // Fallback: detect a downloaded model snapshot from Hugging Face cache. + let sanitizedRepoID = Constants.llmModelName.replacingOccurrences(of: "/", with: "--") + let hostSnapshotsPath = "\(Constants.hostHuggingFaceCacheRoot)/models--\(sanitizedRepoID)/snapshots" + let hostSnapshotsURL = URL(fileURLWithPath: hostSnapshotsPath, isDirectory: true) + + if let hostSnapshot = newestSnapshotDirectory(in: hostSnapshotsURL, fileManager: fileManager) { + return hostSnapshot + } + + let snapshotsPath = "~/.cache/huggingface/hub/models--\(sanitizedRepoID)/snapshots" + let snapshotsURL = URL(fileURLWithPath: NSString(string: snapshotsPath).expandingTildeInPath, isDirectory: true) + return newestSnapshotDirectory(in: snapshotsURL, fileManager: fileManager) + } + + private func newestSnapshotDirectory(in snapshotsURL: URL, fileManager: FileManager) -> URL? { + guard fileManager.fileExists(atPath: snapshotsURL.path) else { + return nil + } + + let directoryContents = try? fileManager.contentsOfDirectory( + at: snapshotsURL, + includingPropertiesForKeys: [.contentModificationDateKey], + options: [.skipsHiddenFiles] + ) + + guard let directoryContents else { + return nil + } + + var newestURL: URL? + var newestDate = Date.distantPast + + for url in directoryContents { + var isDirectory: ObjCBool = false + guard fileManager.fileExists(atPath: url.path, isDirectory: &isDirectory), isDirectory.boolValue else { + continue + } + let modified = (try? url.resourceValues(forKeys: [.contentModificationDateKey]).contentModificationDate) ?? .distantPast + if modified > newestDate { + newestDate = modified + newestURL = url + } + } + + return newestURL + } + + private func existingDirectoryURL(at rawPath: String, fileManager: FileManager) -> URL? { + let expandedPath = NSString(string: rawPath).expandingTildeInPath + let url = URL(fileURLWithPath: expandedPath, isDirectory: true) + var isDirectory: ObjCBool = false + + guard fileManager.fileExists(atPath: url.path, isDirectory: &isDirectory), isDirectory.boolValue else { + return nil + } + + return url + } + + private func copyDirectoryContents(from sourceURL: URL, to destinationURL: URL) throws { + let fileManager = FileManager.default + let items = try fileManager.contentsOfDirectory(at: sourceURL, includingPropertiesForKeys: nil) + + for item in items { + let destinationItem = destinationURL.appendingPathComponent(item.lastPathComponent) + var isDirectory: ObjCBool = false + let exists = fileManager.fileExists(atPath: item.path, isDirectory: &isDirectory) + + guard exists else { + continue + } + + if isDirectory.boolValue { + try fileManager.createDirectory(at: destinationItem, withIntermediateDirectories: true) + try copyDirectoryContents(from: item, to: destinationItem) + } else if !fileManager.fileExists(atPath: destinationItem.path) { + try fileManager.copyItem(at: item, to: destinationItem) + } + } + } + + private func ecgSamplesForPrompt(_ healthKit: HealthKit) async -> [ElectrocardiogramData] { + let fetched = (try? await healthDataFetcher.fetchElectrocardiograms(healthKit, limit: 1)) ?? [] + + if Constants.includeHardcodedECGSample { + var merged = fetched + merged.append(hardcodedECGSample()) + return merged + } + + if fetched.isEmpty { + return [] + } + + return fetched + } + + private func hardcodedECGSample() -> ElectrocardiogramData { + let samplingFrequency = 256.0 + let sampleCount = Constants.hardcodedECGSampleLength + let durationSeconds = Double(sampleCount) / samplingFrequency + let endDate = Date() + let startDate = endDate.addingTimeInterval(-durationSeconds) + + let voltages: [Double] = (0 ..< sampleCount).map { index in + let t = Double(index) / samplingFrequency + + // Synthetic ECG-like waveform with deterministic QRS spikes. + let base = 0.025 * sin(2.0 * .pi * 1.2 * t) + let pWave = 0.010 * sin(2.0 * .pi * 4.0 * t) + let qrsPhase = t.remainder(dividingBy: 0.86) + let qrs = qrsPhase < 0.018 ? 0.72 * exp(-pow((qrsPhase - 0.006) * 120.0, 2.0)) : 0.0 + let tWave = 0.040 * exp(-pow((qrsPhase - 0.24) * 14.0, 2.0)) + return base + pWave + qrs + tWave + } + + return .init( + startDate: startDate, + endDate: endDate, + classification: "sinusRhythm_sample", + symptomsStatus: "notSet_sample", + averageHeartRate: 70.0, + samplingFrequency: samplingFrequency, + numberOfVoltageMeasurements: voltages.count, + voltages: voltages + ) } func queryLLM(with context: Chat, healthKit: HealthKit) async throws { @@ -83,9 +533,139 @@ class HealthDataInterpreter: DefaultInitializable, Module, EnvironmentAccessible self.context.append(.init(.user, content: userPrompt.content)) self.advancedContext.append(.init(.user, content: userPrompt.content)) - - try await checkForFunctionCall(prompt: userPrompt.content, healthKit: healthKit) - try await defaultResponse(healthKit) + + if shouldRunOpenTSLMSampleInference(for: userPrompt.content) { + await prepareForOpenTSLMSampleInference(keeping: userPrompt.content) + do { + let inferenceResult = try await openTSLMInferenceService.runSleepSampleInference( + llmRunner: llmRunner, + llmSession: sharedSession + ) + let reply = """ + I ran the OpenTSLM sample inference directly in the iOS app using your local checkpoints and sleep_cot sample data. + + \(inferenceResult) + """ + self.context.append(.init(.assistant, content: reply, completed: true)) + self.advancedContext.append(.init(.assistant, content: reply, completed: true)) + } catch { + let failure = "OpenTSLM sample inference failed in-app: \(error.localizedDescription)" + self.context.append(.init(.assistant, content: failure, completed: true)) + self.advancedContext.append(.init(.assistant, content: failure, completed: true)) + } + return + } + + if shouldRunOpenTSLMECGSampleInference(for: userPrompt.content) { + await prepareForOpenTSLMSampleInference(keeping: userPrompt.content) + do { + let useLLM = Constants.openTSLMRunSampleLLMGeneration && !Constants.skipLLMLoad + let inferenceResult = try await openTSLMInferenceService.runECGSampleInference( + llmRunner: useLLM ? llmRunner : nil, + llmSession: useLLM ? sharedSession : nil + ) + let reply = """ + I ran the OpenTSLM ECG-QA CoT sample path directly in the iOS app using the bundled CoT CSV + PTB-XL waveform sidecar (same loader indexing as Python). + + \(inferenceResult) + """ + self.context.append(.init(.assistant, content: reply, completed: true)) + self.advancedContext.append(.init(.assistant, content: reply, completed: true)) + } catch { + let failure = "OpenTSLM ECG sample inference failed in-app: \(error.localizedDescription)" + self.context.append(.init(.assistant, content: failure, completed: true)) + self.advancedContext.append(.init(.assistant, content: failure, completed: true)) + } + return + } + + // The ECG feature (auto-prompt) runs the OpenTSLM soft-prompt model on the user's + // actual HealthKit recording, rather than the generic chat path. + if userPrompt.content == Constants.ecgAutoPrompt { + logger.info("queryLLM: routing ECG auto-prompt to OpenTSLM") + do { + let ecgSamples = await ecgSamplesForPrompt(healthKit) + guard let ecg = ecgSamples.first(where: { !$0.voltages.isEmpty }) else { + logger.info("queryLLM: no ECG with voltages available (samples=\(ecgSamples.count, privacy: .public))") + let reply = "I couldn't find an ECG reading to analyze. Record one with the ECG app on your Apple Watch and try again." + self.context.append(.init(.assistant, content: reply, completed: true)) + self.advancedContext.append(.init(.assistant, content: reply, completed: true)) + return + } + logger.info("queryLLM: ECG selected voltages=\(ecg.voltages.count, privacy: .public); running OpenTSLM (on-device, no streaming — may take ~15-30s)") + let analysis = try await openTSLMInferenceService.runECGInference( + voltages: ecg.voltages, + samplingFrequency: ecg.samplingFrequency ?? 512.0, + classification: ecg.classification, + symptomsStatus: ecg.symptomsStatus, + averageHeartRate: ecg.averageHeartRate, + llmRunner: llmRunner, + llmSession: sharedSession + ) + logger.info("queryLLM: ECG analysis returned \(analysis.count, privacy: .public) chars") + self.context.append(.init(.assistant, content: analysis, completed: true)) + self.advancedContext.append(.init(.assistant, content: analysis, completed: true)) + } catch { + logger.error("queryLLM: OpenTSLM ECG analysis failed: \(String(reflecting: error), privacy: .public)") + let failure = "ECG analysis failed: \(error.localizedDescription)" + self.context.append(.init(.assistant, content: failure, completed: true)) + self.advancedContext.append(.init(.assistant, content: failure, completed: true)) + } + return + } + + if userPrompt.content != Constants.ecgAutoPrompt { + do { + try await checkForFunctionCall(prompt: userPrompt.content, healthKit: healthKit) + } catch { + logger.error("queryLLM: checkForFunctionCall threw \(String(reflecting: error), privacy: .public) — localizedDescription=\(error.localizedDescription, privacy: .public)") + throw error + } + releaseLLMSessionBetweenGenerations() + } + + do { + try await defaultResponse(healthKit) + } catch { + logger.error("queryLLM: defaultResponse threw \(String(reflecting: error), privacy: .public) — localizedDescription=\(error.localizedDescription, privacy: .public)") + throw error + } + } + + /// Ends any in-flight generation and clears MLX GPU cache before a new `generate()` call. + private func releaseLLMSessionBetweenGenerations() { + sharedSession?.cancel() + GPU.clearCache() + } + + /// Drop prior chat / HealthKit system context so OpenTSLM sample inference runs with minimal memory. + private func prepareForOpenTSLMSampleInference(keeping prompt: String) async { + let command = HealthyLLMContextEntity(.user, content: prompt) + context = [command] + advancedContext = [command] + sharedSession?.cancel() + await MainActor.run { + sharedSession?.customContext = [] + } + GPU.clearCache() + } + + private func shouldRunOpenTSLMSampleInference(for prompt: String) -> Bool { + let normalized = prompt + .trimmingCharacters(in: .whitespacesAndNewlines) + .lowercased() + + // Keep this behind an explicit command so ECG/health prompts are not hijacked + // by the sample Sleep-EDF demo path. + return normalized == "/opentslm-sleep-sample" + } + + private func shouldRunOpenTSLMECGSampleInference(for prompt: String) -> Bool { + let normalized = prompt + .trimmingCharacters(in: .whitespacesAndNewlines) + .lowercased() + + return normalized == Constants.openTSLMECGSampleCommand } func resetChat() async { @@ -114,11 +694,20 @@ class HealthDataInterpreter: DefaultInitializable, Module, EnvironmentAccessible .init(.user, content: prompt) ] advancedContext.append(contentsOf: context) - - let functionCallLLMOutput = try await llmRunner.oneShot( - on: sharedSession, - customContext: context.map(\.asDictionary) - ) + + let dictContext = context.map(\.asDictionary) + logger.info("checkForFunctionCall: sending \(dictContext.count) messages to LLM. roles=\(dictContext.map { $0["role"] ?? "?" }.joined(separator: ","), privacy: .public)") + + let functionCallLLMOutput: String + do { + functionCallLLMOutput = try await llmRunner.oneShot( + on: sharedSession, + customContext: dictContext + ) + } catch { + logger.error("checkForFunctionCall: oneShot threw \(String(reflecting: error), privacy: .public)") + throw error + } logger.debug("Function Call LLM Finished with: \(functionCallLLMOutput)") @@ -139,44 +728,46 @@ class HealthDataInterpreter: DefaultInitializable, Module, EnvironmentAccessible sharedSession.update( parameters: defaultParameters, + samplingParameters: defaultSamplingParameters, injectIntoContext: true ) - if !context.map(\.content).contains(PromptGenerator.systemPrompt) { + if !context.contains(where: { $0.role == .system }) { let userInfo = await healthDataFetcher.fetchUser(healthKit) - context.append(PromptGenerator.buildSystemPrompt(for: .default, userInfo: userInfo)) - advancedContext.append(PromptGenerator.buildSystemPrompt(for: .default, userInfo: userInfo)) + let electrocardiograms = await ecgSamplesForPrompt(healthKit) + let systemPrompt = healthContextGenerator.buildSystemPrompt( + userInfo: userInfo, + electrocardiograms: electrocardiograms + ) + context.insert(systemPrompt, at: 0) + advancedContext.insert(systemPrompt, at: 0) } await MainActor.run { sharedSession.customContext = context.map(\.asDictionary) } - logger.debug("defaultResponse: Started with context") - - for try await stringPiece in try await sharedSession.generate() { - logger.debug("defaultResponse: Received string piece: \(stringPiece)") - - guard let lastContextEntity = context.last, - lastContextEntity.role == .assistant else { - context.append(.init(.assistant, content: stringPiece, completed: false)) - advancedContext.append(.init(.assistant, content: stringPiece, completed: false)) - continue - } + logger.info("defaultResponse: generating (\(self.context.count, privacy: .public) messages)") - context[context.count - 1] = .init( - .assistant, - content: lastContextEntity.content + stringPiece, - completed: false, - id: lastContextEntity.id - ) - advancedContext[advancedContext.count - 1] = .init( - .assistant, - content: lastContextEntity.content + stringPiece, - completed: false, - id: advancedContext.last!.id - ) + releaseLLMSessionBetweenGenerations() + + var assistantOutput = "" + do { + for try await stringPiece in try await sharedSession.generate() { + assistantOutput += stringPiece + } + } catch { + logger.error("defaultResponse: generate threw \(String(reflecting: error), privacy: .public)") + throw error } + + sharedSession.cancel() + GPU.clearCache() + logger.info("defaultResponse: finished (\(assistantOutput.count, privacy: .public) chars)") + + let assistantMessage = HealthyLLMContextEntity(.assistant, content: assistantOutput, completed: true) + context.append(assistantMessage) + advancedContext.append(assistantMessage) } /// Returns a bool representing if a function call has been made diff --git a/HealthyLLM/HealthyLLM/HealthyLLMChatView.swift b/HealthyLLM/HealthyLLM/HealthyLLMChatView.swift index 7debfd1..9e3bd76 100644 --- a/HealthyLLM/HealthyLLM/HealthyLLMChatView.swift +++ b/HealthyLLM/HealthyLLM/HealthyLLMChatView.swift @@ -19,6 +19,9 @@ struct HealthyLLMChatView: View { let firstPrompt: String @State private var showErrorAlert = false @State private var errorMessage = "" + @State private var lastSubmittedUserMessageID: UUID? + @State private var isSubmittingPrompt = false + @State private var didSendInitialPrompt = false var body: some View { NavigationStack { @@ -33,12 +36,31 @@ struct HealthyLLMChatView: View { ) } } set: { newValue in + guard let userPrompt = newValue.last, + userPrompt.role == .user, + userPrompt.complete, + userPrompt.id != lastSubmittedUserMessageID, + !isSubmittingPrompt else { + return + } + + lastSubmittedUserMessageID = userPrompt.id + isSubmittingPrompt = true + Task { + defer { + Task { @MainActor in + isSubmittingPrompt = false + } + } + do { try await healthDataInterpreter.queryLLM(with: newValue, healthKit: healthKit) } catch { - showErrorAlert = true - errorMessage = "Error querying LLM: \(error.localizedDescription)" + await MainActor.run { + showErrorAlert = true + errorMessage = "Error querying LLM: \(error.localizedDescription)" + } } } } @@ -57,8 +79,20 @@ struct HealthyLLMChatView: View { } } .task { + guard !didSendInitialPrompt else { + return + } + + didSendInitialPrompt = true await healthDataInterpreter.resetChat() - contextBinding.wrappedValue.append(.init(role: .user, content: firstPrompt)) + + do { + let initialContext: Chat = [.init(role: .user, content: firstPrompt)] + try await healthDataInterpreter.queryLLM(with: initialContext, healthKit: healthKit) + } catch { + showErrorAlert = true + errorMessage = "Error querying LLM: \(error.localizedDescription)" + } } } .alert("ERROR_ALERT_TITLE", isPresented: $showErrorAlert) { diff --git a/HealthyLLM/HealthyLLM/HealthyLLMContext.swift b/HealthyLLM/HealthyLLM/HealthyLLMContext.swift index adbe3e5..1438dc4 100644 --- a/HealthyLLM/HealthyLLM/HealthyLLMContext.swift +++ b/HealthyLLM/HealthyLLM/HealthyLLMContext.swift @@ -21,8 +21,8 @@ struct HealthyLLMContextEntity: Identifiable { var inChatName: String { switch self { - case .assistant: return "assisstant" - case .toolCall: return "assisstant" + case .assistant: return "assistant" + case .toolCall: return "assistant" case .toolResponse: return "tool" case .user: return "user" case .system: return "system" diff --git a/HealthyLLM/HealthyLLM/HealthyLLMView.swift b/HealthyLLM/HealthyLLM/HealthyLLMView.swift index f2ce134..7239922 100644 --- a/HealthyLLM/HealthyLLM/HealthyLLMView.swift +++ b/HealthyLLM/HealthyLLM/HealthyLLMView.swift @@ -7,20 +7,31 @@ // import SwiftUI +import SpeziHealthKit +import Hub +import OSLog +import SpeziLLMLocal +import SpeziLLMLocalDownload extension String: Identifiable { public var id: Self { self } } struct HealthyLLMView: View { + private static let logger = Logger(subsystem: "HealthyLLM", category: "HealthyLLMView") + @AppStorage(StorageKeys.onboardingFlowComplete) var completedOnboardingFlow = false @Environment(HealthDataInterpreter.self) private var healthDataInterpreter + @Environment(HealthKit.self) private var healthKit @State private var showSettings = false @State private var showWelcome = false @State private var showErrorAlert = false @State private var errorMessage = "" @State private var firstPrompt: String? + @State private var showModelDownload = false + @State private var didAttemptInitialization = false + @State private var needsModelDownload = false var body: some View { NavigationStack { @@ -60,23 +71,93 @@ struct HealthyLLMView: View { self.showWelcome = false } } + .sheet(isPresented: $showModelDownload) { + LLMLocalDownloadView( + model: .custom(id: Constants.llmModelName), + downloadDescription: "Download the \(Constants.llmModelName) model from Hugging Face." + ) { + showModelDownload = false + Task { + await initializeInterpreterIfPossible() + } + } + } .alert("ERROR_ALERT_TITLE", isPresented: $showErrorAlert) { Button("ERROR_ALERT_CANCEL", role: .cancel) {} } message: { Text(errorMessage) } .task { - do { - if ProcessInfo.processInfo.environment["XCODE_RUNNING_FOR_PREVIEWS"] == "1" { - } else { - try await healthDataInterpreter.setup() - } - } catch { - errorMessage = error.localizedDescription - showErrorAlert = true + if didAttemptInitialization { + return + } + + didAttemptInitialization = true + await initializeInterpreterIfPossible() + } + } + + @MainActor + private func initializeInterpreterIfPossible() async { + if ProcessInfo.processInfo.environment["XCODE_RUNNING_FOR_PREVIEWS"] == "1" { + Self.logger.info("initializeInterpreterIfPossible: skipping (running in previews)") + return + } + + let modelExists = localModelExists() + let sourceExists = localModelSourceExists() + Self.logger.info("initializeInterpreterIfPossible: modelExists=\(modelExists, privacy: .public) sourceExists=\(sourceExists, privacy: .public) modelID=\(Constants.llmModelName, privacy: .public) destination=\(Constants.llmLocalModelDirectory.path, privacy: .public)") + + guard modelExists || sourceExists || Constants.skipLLMLoad else { + Self.logger.info("initializeInterpreterIfPossible: no local model or source found — surfacing download UI") + needsModelDownload = true + if !showModelDownload { + showModelDownload = true + } + return + } + + needsModelDownload = false + + Self.logger.info("initializeInterpreterIfPossible: calling healthDataInterpreter.setup()") + do { + try await healthDataInterpreter.setup() + if Constants.autoOpenTSLMECGSampleOnLaunch { + Self.logger.info("initializeInterpreterIfPossible: setup() returned; auto-opening OpenTSLM ECG sample") + firstPrompt = Constants.openTSLMECGSampleCommand + } else if Constants.autoECGPromptOnLaunch { + Self.logger.info("initializeInterpreterIfPossible: setup() returned; auto-opening ECG prompt") + firstPrompt = Constants.ecgAutoPrompt + } else { + Self.logger.info("initializeInterpreterIfPossible: setup() returned; waiting for user input") } + } catch { + Self.logger.error("initializeInterpreterIfPossible: setup() threw: \(error.localizedDescription, privacy: .public)") + errorMessage = error.localizedDescription + showErrorAlert = true } } + + private func localModelExists() -> Bool { + // Use the same check Spezi/HubApi uses, so we never disagree about whether the model is present. + LLMLocalDownloadManager.modelExist(model: .custom(id: Constants.llmModelName)) + } + + private func localModelSourceExists() -> Bool { + let fileManager = FileManager.default + + if let overridePath = Constants.localModelSourcePathOverride, + fileManager.fileExists(atPath: NSString(string: overridePath).expandingTildeInPath) { + return true + } + + if let bundledLocalModelURL = Bundle.main.resourceURL? + .appendingPathComponent(Constants.localModelBundleSubdirectory, isDirectory: true) { + return fileManager.fileExists(atPath: bundledLocalModelURL.path) + } + + return false + } private var settingsButton: some View { Button( @@ -92,10 +173,41 @@ struct HealthyLLMView: View { } private var loadingChatView: some View { - VStack { - Text("LOADING_CHAT_VIEW") - ProgressView() + VStack(spacing: 12) { + if needsModelDownload { + Image(systemName: "tray.and.arrow.down") + .font(.largeTitle) + Text("Model not found on this device") + .font(.headline) + Text("Repo: \(Constants.llmModelName)") + .font(.caption) + .foregroundStyle(.secondary) + Text("Tap below to download. The model is gated — make sure you have accepted its license on Hugging Face and have HF_TOKEN set.") + .font(.subheadline) + .multilineTextAlignment(.center) + .foregroundStyle(.secondary) + .padding(.horizontal) + Button("Open download") { + showModelDownload = true + } + .buttonStyle(.borderedProminent) + } else { + Text("LOADING_CHAT_VIEW") + .font(.headline) + ProgressView() + Text(healthDataInterpreter.loadingStage.rawValue) + .font(.subheadline) + .foregroundStyle(.secondary) + if !healthDataInterpreter.loadingDetail.isEmpty { + Text(healthDataInterpreter.loadingDetail) + .font(.caption) + .foregroundStyle(.tertiary) + .multilineTextAlignment(.center) + .padding(.horizontal) + } + } } + .padding() } } diff --git a/HealthyLLM/HealthyLLM/Models.swift b/HealthyLLM/HealthyLLM/Models.swift index b45c965..0c71ef9 100644 --- a/HealthyLLM/HealthyLLM/Models.swift +++ b/HealthyLLM/HealthyLLM/Models.swift @@ -22,6 +22,17 @@ struct HealthData: Encodable { let values: [String: [Double]] } +struct ElectrocardiogramData: Encodable { + let startDate: Date + let endDate: Date + let classification: String + let symptomsStatus: String + let averageHeartRate: Double? + let samplingFrequency: Double? + let numberOfVoltageMeasurements: Int + let voltages: [Double] +} + struct WorkoutData: Encodable { let name: String let date: String diff --git a/HealthyLLM/HealthyLLM/OpenTSLM/ECGQACoTDataset.swift b/HealthyLLM/HealthyLLM/OpenTSLM/ECGQACoTDataset.swift new file mode 100644 index 0000000..5954df9 --- /dev/null +++ b/HealthyLLM/HealthyLLM/OpenTSLM/ECGQACoTDataset.swift @@ -0,0 +1,344 @@ +import Foundation + +/// On-device ECG-QA CoT loader mirroring OpenTSLM's ``ECGQACoTQADataset`` + ``ecgqa_cot_loader``. +/// +/// Reads split CSV metadata (question, rationale, ecg_id, …) and lazily attaches PTB-XL waveforms +/// from a sidecar JSON generated by ``scripts/export_ecg_qa_ios_assets.py`` via the Python loader. +public final class ECGQACoTDataset { + + public enum Split: String { + case train + case validation + case test + + var csvBaseName: String { + switch self { + case .train: return "ecg_qa_cot_train" + case .validation: return "ecg_qa_cot_val" + case .test: return "ecg_qa_cot_test" + } + } + } + + public struct RowMetadata { + public let split: Split + public let rowIndex: Int + public let ecgId: Int + public let templateId: Int + public let questionType: String + public let question: String + public let clinicalContext: String + public let correctAnswer: String + } + + private struct Row { + let ecgId: Int + let templateId: Int + let questionType: String + let question: String + let clinicalContext: String + let correctAnswer: String + let rationale: String + } + + private struct WaveformEntry: Decodable { + let timeSeries: [[Float]] + let timeSeriesText: [String] + + enum CodingKeys: String, CodingKey { + case timeSeries = "time_series" + case timeSeriesText = "time_series_text" + } + } + + private let split: Split + private let rows: [Row] + private let waveformsDirectory: URL + private var waveformCache: [String: WaveformEntry] = [:] + private let templateAnswers: [Int: [String]] + + public init( + csvURL: URL, + waveformsDirectory: URL, + templateAnswersURL: URL? = nil, + split: Split = .test, + maxRows: Int = 5000 + ) throws { + self.split = split + self.waveformsDirectory = waveformsDirectory + self.rows = try Self.readRows(from: csvURL, maxRows: maxRows) + if let templateAnswersURL { + self.templateAnswers = try Self.readTemplateAnswers(from: templateAnswersURL) + } else { + self.templateAnswers = [:] + } + } + + public var count: Int { rows.count } + + public func rowMetadata(at index: Int) -> RowMetadata { + precondition(index >= 0 && index < rows.count, "row index out of range") + let row = rows[index] + return RowMetadata( + split: split, + rowIndex: index, + ecgId: row.ecgId, + templateId: row.templateId, + questionType: row.questionType, + question: row.question, + clinicalContext: row.clinicalContext, + correctAnswer: row.correctAnswer + ) + } + + public func sample(at index: Int) throws -> OpenTSLMSPSample { + precondition(index >= 0 && index < rows.count, "sample index out of range") + let row = rows[index] + let waveform = try loadWaveform(for: row.ecgId) + + return OpenTSLMSPSample( + prePrompt: Self.prePrompt(clinicalContext: row.clinicalContext, question: row.question), + timeSeriesText: waveform.timeSeriesText, + timeSeries: waveform.timeSeries, + postPrompt: Self.postPrompt(possibleAnswers: templateAnswers[row.templateId] ?? []), + label: row.correctAnswer, + answer: row.rationale + ) + } + + private func loadWaveform(for ecgId: Int) throws -> WaveformEntry { + let key = String(ecgId) + if let cached = waveformCache[key] { + return cached + } + + guard let entry = loadWaveformFile(for: key) else { + throw NSError( + domain: "ECGQACoTDataset", + code: 5, + userInfo: [ + NSLocalizedDescriptionKey: + "Missing PTB-XL waveform for ecg_id \(ecgId) under \(waveformsDirectory.path). " + + "Run scripts/export_ecg_qa_ios_assets.py to regenerate loader waveforms.", + ] + ) + } + waveformCache[key] = entry + return entry + } + + private func loadWaveformFile(for ecgId: String) -> WaveformEntry? { + let url = waveformsDirectory.appendingPathComponent("\(ecgId).json") + guard FileManager.default.fileExists(atPath: url.path), + let data = try? Data(contentsOf: url), + let object = try? JSONSerialization.jsonObject(with: data) as? [String: Any], + let rawSeries = object["time_series"] as? [[Any]], + let rawText = object["time_series_text"] as? [String] + else { + return nil + } + + guard let series = try? rawSeries.map({ lead in + try lead.map { value -> Float in + if let f = value as? Float { return f } + if let d = value as? Double { return Float(d) } + if let i = value as? Int { return Float(i) } + throw NSError(domain: "ECGQACoTDataset", code: 7, userInfo: [NSLocalizedDescriptionKey: "Invalid waveform value"]) + } + }) else { + return nil + } + + return WaveformEntry(timeSeries: series, timeSeriesText: rawText) + } + + private static func readTemplateAnswers(from url: URL) throws -> [Int: [String]] { + let data = try Data(contentsOf: url) + let decoded = try JSONSerialization.jsonObject(with: data) + guard let object = decoded as? [String: Any], + let rawTemplates = object["template_answers"] as? [String: [Any]] + else { + return [:] + } + + var answers: [Int: [String]] = [:] + for (templateId, values) in rawTemplates { + guard let intId = Int(templateId) else { continue } + answers[intId] = values.compactMap { $0 as? String } + } + return answers + } + + private static func readRows(from csvURL: URL, maxRows: Int) throws -> [Row] { + guard maxRows > 0 else { + throw NSError(domain: "ECGQACoTDataset", code: 8, userInfo: [NSLocalizedDescriptionKey: "maxRows must be > 0"]) + } + + let handle = try FileHandle(forReadingFrom: csvURL) + let prefix = try handle.read(upToCount: 20 * 1024 * 1024) ?? Data() + let csvPrefix = String(decoding: prefix, as: UTF8.self) + let safePrefix = csvPrefix.prefix(upTo: csvPrefix.lastIndex(of: "\n") ?? csvPrefix.endIndex) + + let records = parseCSV(String(safePrefix)) + guard let header = records.first else { + throw NSError(domain: "ECGQACoTDataset", code: 1, userInfo: [NSLocalizedDescriptionKey: "CSV has no header"]) + } + + let columnIndex = Dictionary(uniqueKeysWithValues: header.enumerated().map { ($1, $0) }) + guard let questionIdx = columnIndex["question"], + let answerIdx = columnIndex["answer"], + let templateIdx = columnIndex["template_id"], + let questionTypeIdx = columnIndex["question_type"], + let ecgIdIdx = columnIndex["ecg_id"], + let rationaleIdx = columnIndex["rationale"] + else { + throw NSError(domain: "ECGQACoTDataset", code: 2, userInfo: [NSLocalizedDescriptionKey: "CSV missing required ECG-QA CoT columns"]) + } + + let clinicalIdx = columnIndex["clinical_context"] + + var rows: [Row] = [] + rows.reserveCapacity(maxRows) + + for record in records.dropFirst() where rows.count < maxRows { + guard !record.isEmpty else { continue } + guard questionIdx < record.count, + answerIdx < record.count, + templateIdx < record.count, + questionTypeIdx < record.count, + ecgIdIdx < record.count, + rationaleIdx < record.count + else { continue } + + let ecgId = try parseECGId(record[ecgIdIdx]) + let templateId = Int(record[templateIdx].trimmingCharacters(in: .whitespacesAndNewlines)) ?? 0 + let clinicalContext = (clinicalIdx != nil && clinicalIdx! < record.count && !record[clinicalIdx!].isEmpty) + ? record[clinicalIdx!] + : "12-lead ECG recording." + + rows.append( + Row( + ecgId: ecgId, + templateId: templateId, + questionType: record[questionTypeIdx], + question: record[questionIdx], + clinicalContext: clinicalContext, + correctAnswer: record[answerIdx], + rationale: record[rationaleIdx] + ) + ) + } + + guard !rows.isEmpty else { + throw NSError(domain: "ECGQACoTDataset", code: 3, userInfo: [NSLocalizedDescriptionKey: "CSV has no ECG-QA CoT rows"]) + } + + return rows + } + + private static func parseECGId(_ raw: String) throws -> Int { + let cleaned = raw.trimmingCharacters(in: .whitespacesAndNewlines).trimmingCharacters(in: CharacterSet(charactersIn: "[]")) + guard let value = Int(cleaned) else { + throw NSError(domain: "ECGQACoTDataset", code: 4, userInfo: [NSLocalizedDescriptionKey: "Invalid ecg_id '\(raw)'"]) + } + return value + } + + // Verbatim from OpenTSLM `ECGQACoTQADataset.py:_get_pre_prompt`. + private static func prePrompt(clinicalContext: String, question: String) -> String { + """ + You are an expert cardiologist analyzing an ECG (electrocardiogram). + + Clinical Context: \(clinicalContext) + + Your task is to examine the ECG signal and answer the following medical question: + + Question: \(question) + + Instructions: + - Begin by analyzing the time series without assuming a specific answer. + - Think step-by-step about what the observed patterns suggest regarding the cardiac condition. + - Write your rationale as a single, natural paragraph — do not use bullet points, numbered steps, or section headings. + - Do **not** mention any final answer until the very end. + - Consider the ECG morphology, intervals, and any abnormalities that relate to the question. + """ + } + + // Verbatim from OpenTSLM `ECGQACoTQADataset.py:_get_post_prompt`. + private static func postPrompt(possibleAnswers: [String]) -> String { + if possibleAnswers.isEmpty { + return """ + Based on your analysis of the ECG data, provide your answer. + Make sure that your last word is the answer. You MUST end your response with "Answer: " + """ + } + + let answersText = possibleAnswers.joined(separator: ", ") + return """ + Based on your analysis of the ECG data, select your answer from the following options: + \(answersText) + + - Make sure that your last word is the answer. You MUST end your response with "Answer: " + """ + } + + private static func parseCSV(_ csv: String) -> [[String]] { + var records: [[String]] = [] + var row: [String] = [] + var field = "" + var insideQuotes = false + + var i = csv.startIndex + while i < csv.endIndex { + let ch = csv[i] + if ch == "\"" { + let nextIndex = csv.index(after: i) + if insideQuotes && nextIndex < csv.endIndex && csv[nextIndex] == "\"" { + field.append("\"") + i = csv.index(after: nextIndex) + continue + } + insideQuotes.toggle() + i = nextIndex + continue + } + + if ch == "," && !insideQuotes { + row.append(field) + field.removeAll(keepingCapacity: true) + i = csv.index(after: i) + continue + } + + if (ch == "\n" || ch == "\r") && !insideQuotes { + row.append(field) + field.removeAll(keepingCapacity: true) + if !row.isEmpty { + records.append(row) + } + row.removeAll(keepingCapacity: true) + + if ch == "\r" { + let nextIndex = csv.index(after: i) + if nextIndex < csv.endIndex && csv[nextIndex] == "\n" { + i = csv.index(after: nextIndex) + continue + } + } + + i = csv.index(after: i) + continue + } + + field.append(ch) + i = csv.index(after: i) + } + + if !field.isEmpty || !row.isEmpty { + row.append(field) + records.append(row) + } + + return records + } +} diff --git a/HealthyLLM/HealthyLLM/OpenTSLM/EmbeddingLlamaModel.swift b/HealthyLLM/HealthyLLM/OpenTSLM/EmbeddingLlamaModel.swift new file mode 100644 index 0000000..1a17778 --- /dev/null +++ b/HealthyLLM/HealthyLLM/OpenTSLM/EmbeddingLlamaModel.swift @@ -0,0 +1,493 @@ +// +// EmbeddingLlamaModel.swift +// +// Vendored copy of mlx-swift-examples `Libraries/MLXLLM/Models/Llama.swift` +// (Copyright © 2024 Apple Inc.), adapted to accept pre-computed input embeddings +// (`inputs_embeds`) so OpenTSLM soft-prompt time-series embeddings can be fed +// directly into the Llama decoder. +// +// Why a full vendored copy rather than a subclass/extension: the upstream +// building blocks (`Attention`, `MLP`, `TransformerBlock`, `LlamaModelInner`) +// and `LlamaConfiguration`'s stored properties are `internal` to MLXLLM and +// therefore unreachable from this module. Copying is the only way to inject +// embeddings into the forward pass from app code. +// +// Differences from upstream (kept intentionally minimal): +// - Types renamed with an `EmbeddingLlama*` prefix to avoid colliding with the +// real `MLXLLM.LlamaModel` (which is still loaded for non-OpenTSLM chat). +// - `EmbeddingLlamaConfiguration` is a verbatim copy of `LlamaConfiguration` +// (its fields are internal upstream, so we need our own readable copy). +// - The inner/outer forward gains an optional `inputEmbedding` that, when +// present, replaces the token-embedding lookup. The token-only +// `callAsFunction(_:cache:)` (the `LLMModel` requirement used by normal +// chat / `TokenIterator`) is unchanged. +// - Conforms to the app's `EmbeddingPrimedLanguageModel` so +// `MLXEmbeddingGenerator` can drive embedding-primed decoding. +// - The unused `computeBaseFrequency` free function from upstream is omitted. +// +// SPDX-License-Identifier: MIT +// + +import Foundation +import MLX +import MLXFast +import MLXLLM +import MLXLMCommon +import MLXNN + +// MARK: - Rotary embedding (verbatim from upstream, retyped to our config) + +private class DynamicNTKScalingRoPE: Module { + let dims: Int + let maxPositionEmbeddings: Int + let traditional: Bool + var base: Float? + let scale: Float + let ropeType: String + let ropeScaling: [String: StringOrNumber]? + var freqs: MLXArray? + + init( + dims: Int, + maxPositionEmbeddings: Int?, + traditional: Bool = false, + base: Float = 10000, + scale: Float = 1.0, + ropeType: String = "default", + ropeScaling: [String: StringOrNumber]? = nil + ) { + self.dims = dims + self.maxPositionEmbeddings = maxPositionEmbeddings ?? 2048 + self.traditional = traditional + self.base = base + self.scale = scale + self.ropeType = ropeType + self.ropeScaling = ropeScaling + super.init() + computeFreqs() + } + + private func computeFreqs() { + if ropeType != "llama3" { + freqs = nil + return + } + + guard let ropeScaling = ropeScaling, + case .float(let factor) = ropeScaling["factor"], + case .float(let lowFreqFactor) = ropeScaling["low_freq_factor"] ?? .float(1.0), + case .float(let highFreqFactor) = ropeScaling["high_freq_factor"] ?? .float(4.0), + case .float(let oldContextLen) = ropeScaling["original_max_position_embeddings"] + ?? .float(8192), + let base + else { + freqs = nil + return + } + + let lowFreqWavelen = oldContextLen / lowFreqFactor + let highFreqWavelen = oldContextLen / highFreqFactor + + let indices = MLXArray(stride(from: 0, to: dims, by: 2)) + var frequencies = MLX.pow(base, indices / Float(dims)) + let wavelens = 2 * Float.pi * frequencies + + frequencies = MLX.where( + wavelens .> MLXArray(lowFreqWavelen), frequencies * factor, frequencies) + let isMediumFreq = MLX.logicalAnd( + wavelens .> MLXArray(highFreqWavelen), + wavelens .< MLXArray(lowFreqWavelen) + ) + let smoothFactors = + (oldContextLen / wavelens - lowFreqFactor) / (highFreqFactor - lowFreqFactor) + let smoothFreqs = frequencies / ((1 - smoothFactors) / factor + smoothFactors) + + freqs = MLX.where(isMediumFreq, smoothFreqs, frequencies) + self.base = nil + } + + func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray { + MLXFast.RoPE( + x, + dimensions: dims, + traditional: traditional, + base: base, + scale: scale, + offset: offset, + freqs: freqs + ) + } +} + +// MARK: - Attention / MLP / block (verbatim from upstream, retyped to our config) + +private class Attention: Module { + + let args: EmbeddingLlamaConfiguration + let scale: Float + + @ModuleInfo(key: "q_proj") var wq: Linear + @ModuleInfo(key: "k_proj") var wk: Linear + @ModuleInfo(key: "v_proj") var wv: Linear + @ModuleInfo(key: "o_proj") var wo: Linear + + let rope: DynamicNTKScalingRoPE + + init(_ args: EmbeddingLlamaConfiguration) { + self.args = args + + let dim = args.hiddenSize + let heads = args.attentionHeads + let kvHeads = args.kvHeads + + let headDim = args.resolvedHeadDimensions + self.scale = pow(Float(headDim), -0.5) + + self._wq.wrappedValue = Linear(dim, heads * headDim, bias: args.attentionBias) + self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias) + self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias) + self._wo.wrappedValue = Linear(heads * headDim, dim, bias: args.attentionBias) + + self.rope = DynamicNTKScalingRoPE( + dims: headDim, + maxPositionEmbeddings: args.maxPositionEmbeddings, + traditional: args.ropeTraditional, + base: args.ropeTheta, + scale: 1.0, + ropeType: { + if case .string(let value) = args.ropeScaling?["type"] { + return value + } else { + return "default" + } + }(), + ropeScaling: args.ropeScaling) + } + + func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? + ) -> MLXArray { + let (B, L) = (x.dim(0), x.dim(1)) + + var queries = wq(x) + var keys = wk(x) + var values = wv(x) + + // Prepare the queries, keys and values for the attention computation + queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) + values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) + + if let cache { + queries = rope(queries, offset: cache.offset) + keys = rope(keys, offset: cache.offset) + (keys, values) = cache.update(keys: keys, values: values) + } else { + queries = rope(queries) + keys = rope(keys) + } + + let output = MLXFast.scaledDotProductAttention( + queries: queries, keys: keys, values: values, scale: scale, mask: mask + ) + .transposed(0, 2, 1, 3) + .reshaped(B, L, -1) + + return wo(output) + } +} + +private class MLP: Module, UnaryLayer { + + @ModuleInfo(key: "gate_proj") var gate: Linear + @ModuleInfo(key: "down_proj") var down: Linear + @ModuleInfo(key: "up_proj") var up: Linear + + init(_ args: EmbeddingLlamaConfiguration) { + self._gate.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias) + self._down.wrappedValue = Linear(args.intermediateSize, args.hiddenSize, bias: args.mlpBias) + self._up.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias) + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let activation = silu(gate(x)) + return down(activation * up(x)) + } +} + +private class TransformerBlock: Module { + @ModuleInfo(key: "self_attn") var attention: Attention + @ModuleInfo(key: "mlp") var mlp: MLP + + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm + @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm + + init(_ args: EmbeddingLlamaConfiguration) { + self._attention.wrappedValue = Attention(args) + self._mlp.wrappedValue = MLP(args) + self._inputLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + self._postAttentionLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? + ) -> MLXArray { + var r = attention(inputLayerNorm(x), mask: mask, cache: cache) + let h = x + r + r = mlp(postAttentionLayerNorm(h)) + let out = h + r + return out + } +} + +// MARK: - Inner model (MODIFIED: accepts inputEmbedding) + +private class EmbeddingLlamaModelInner: Module { + + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + + let layers: [TransformerBlock] + let norm: RMSNorm + + init(_ args: EmbeddingLlamaConfiguration) { + precondition(args.vocabularySize > 0) + + self._embedTokens.wrappedValue = Embedding( + embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) + + self.layers = (0 ..< args.hiddenLayers).map { _ in TransformerBlock(args) } + self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + /// When `inputEmbedding` is supplied it is used directly as the hidden state, + /// skipping the token-embedding lookup. Llama (unlike Gemma) does not scale + /// the embeddings, so there is nothing else to reconcile here. + func callAsFunction( + _ inputs: MLXArray?, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil + ) -> MLXArray { + var h: MLXArray + if let inputEmbedding { + h = inputEmbedding + } else if let inputs { + h = embedTokens(inputs) + } else { + fatalError("EmbeddingLlamaModelInner requires either inputs or inputEmbedding") + } + + let mask: MLXArray? = createAttentionMask(h: h, cache: cache) + + for (i, layer) in layers.enumerated() { + h = layer(h, mask: mask, cache: cache?[i]) + } + + return norm(h) + } +} + +// MARK: - Outer model (MODIFIED: embedding-primed overload + protocol conformance) + +/// Drop-in for `MLXLLM.LlamaModel` whose forward pass can be primed with +/// pre-computed embeddings. Parameter graph is identical to upstream `LlamaModel` +/// (same `@ModuleInfo` keys, same nesting), so the factory's weight loading, +/// quantization, and `verify` all work unchanged. +class EmbeddingLlamaModel: Module, LLMModel, KVCacheDimensionProvider { + + let vocabularySize: Int + let kvHeads: [Int] + + fileprivate let model: EmbeddingLlamaModelInner + + @ModuleInfo(key: "lm_head") var lmHead: Linear? + + init(_ args: EmbeddingLlamaConfiguration) { + self.vocabularySize = args.vocabularySize + self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } + self.model = EmbeddingLlamaModelInner(args) + if !args.tieWordEmbeddings { + self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false) + } + } + + private func forward(_ inputs: MLXArray?, cache: [KVCache]?, inputEmbedding: MLXArray?) + -> MLXArray + { + let out = model(inputs, cache: cache, inputEmbedding: inputEmbedding) + if let lmHead { + return lmHead(out) + } else { + return model.embedTokens.asLinear(out) + } + } + + /// `LLMModel` requirement — token-only path used by normal chat / `TokenIterator`. + /// Unchanged behaviour vs. upstream `LlamaModel`. + func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray { + forward(inputs, cache: cache, inputEmbedding: nil) + } + + func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + // Remove unused precomputed rotary frequencies + weights.filter { + !$0.key.contains("self_attn.rotary_emb.inv_freq") + } + } + + /// Token-embedding lookup, exposed so callers can build interleaved + /// soft-prompt sequences (text token embeddings + projected time-series + /// embeddings) before priming the decoder via ``inputEmbedding``. + /// `ids` is an integer index array, e.g. shape `[B, T]`; result is `[B, T, hidden]`. + func tokenEmbeddings(_ ids: MLXArray) -> MLXArray { + model.embedTokens(ids) + } +} + +// MARK: - LoRA + +extension EmbeddingLlamaModel: LoRAModel { + func loraLinearLayers() -> LoRALinearLayers { + model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } + } +} + +// MARK: - Embedding-primed decoding (app protocol from MLXEmbeddingGenerator) + +extension EmbeddingLlamaModel: EmbeddingPrimedLanguageModel { + func makeCache() -> [KVCache] { + newCache(parameters: nil) + } + + func callAsFunction( + _ inputs: MLXArray?, cache: [KVCache], inputEmbedding: MLXArray? + ) throws -> MLXArray { + forward(inputs, cache: cache, inputEmbedding: inputEmbedding) + } +} + +// MARK: - Configuration (verbatim copy of upstream LlamaConfiguration) + +struct EmbeddingLlamaConfiguration: Codable, Sendable { + + var hiddenSize: Int + var hiddenLayers: Int + var intermediateSize: Int + var attentionHeads: Int + var headDimensions: Int? + var rmsNormEps: Float + var vocabularySize: Int + var kvHeads: Int + var maxPositionEmbeddings: Int? + var ropeTheta: Float = 10_000 + var ropeTraditional: Bool = false + var ropeScaling: [String: StringOrNumber]? + var tieWordEmbeddings: Bool = true + var attentionBias: Bool = false + var mlpBias: Bool = false + + var resolvedHeadDimensions: Int { + headDimensions ?? (hiddenSize / attentionHeads) + } + + enum CodingKeys: String, CodingKey { + case hiddenSize = "hidden_size" + case hiddenLayers = "num_hidden_layers" + case intermediateSize = "intermediate_size" + case attentionHeads = "num_attention_heads" + case headDimensions = "head_dim" + case rmsNormEps = "rms_norm_eps" + case vocabularySize = "vocab_size" + case kvHeads = "num_key_value_heads" + case maxPositionEmbeddings = "max_position_embeddings" + case ropeTheta = "rope_theta" + case ropeTraditional = "rope_traditional" + case ropeScaling = "rope_scaling" + case tieWordEmbeddings = "tie_word_embeddings" + case attentionBias = "attention_bias" + case mlpBias = "mlp_bias" + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + + hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) + hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers) + intermediateSize = try container.decode(Int.self, forKey: .intermediateSize) + attentionHeads = try container.decode(Int.self, forKey: .attentionHeads) + headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions) + rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps) + vocabularySize = try container.decode(Int.self, forKey: .vocabularySize) + kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? attentionHeads + maxPositionEmbeddings = try container.decodeIfPresent( + Int.self, forKey: .maxPositionEmbeddings) + if let ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) { + self.ropeTheta = ropeTheta + } + if let ropeTraditional = try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) + { + self.ropeTraditional = ropeTraditional + } + ropeScaling = try container.decodeIfPresent( + [String: StringOrNumber].self, forKey: .ropeScaling) + if let tieWordEmbeddings = try container.decodeIfPresent( + Bool.self, forKey: .tieWordEmbeddings) + { + self.tieWordEmbeddings = tieWordEmbeddings + } + if let attentionBias = try container.decodeIfPresent(Bool.self, forKey: .attentionBias) { + self.attentionBias = attentionBias + } + if let mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) { + self.mlpBias = mlpBias + } + + if let ropeScaling { + if ropeScaling["factor"] == nil { + throw DecodingError.dataCorruptedError( + forKey: .ropeScaling, in: container, + debugDescription: "rope_scaling must contain 'factor'") + } + if let ropeType = ropeScaling["type"] ?? ropeScaling["rope_type"] { + if case .string = ropeType { + let options = [ + StringOrNumber.string("linear"), StringOrNumber.string("dynamic"), + StringOrNumber.string("llama3"), + ] + if !options.contains(ropeType) { + throw DecodingError.dataCorruptedError( + forKey: .ropeScaling, in: container, + debugDescription: + "rope_scaling 'type' currently only supports 'linear', 'dynamic', or 'llama3'" + ) + } + } + } else { + throw DecodingError.dataCorruptedError( + forKey: .ropeScaling, in: container, + debugDescription: "rope_scaling must contain either 'type' or 'rope_type'") + } + } + } +} + +// MARK: - Factory registration + +/// Overrides the `"llama"` (and `"mistral"`) entries in the shared MLX model-type +/// registry so SpeziLLMLocal's `LLMModelFactory.shared.loadContainer(...)` +/// instantiates `EmbeddingLlamaModel` instead of the stock `LlamaModel`. +/// +/// Call once at app/Spezi configuration time, *before* the first model load. +/// The closure mirrors the upstream private `create(_:_:)` helper: the URL it +/// receives is the `config.json` file itself (not the directory), and it only +/// builds the bare model — the factory loads/quantizes weights afterwards. +enum EmbeddingLlamaModelRegistration { + static func register() { + let creator: @Sendable (URL) throws -> any LanguageModel = { url in + let config = try JSONDecoder().decode( + EmbeddingLlamaConfiguration.self, from: Data(contentsOf: url)) + return EmbeddingLlamaModel(config) + } + LLMModelFactory.shared.typeRegistry.registerModelType("llama", creator: creator) + LLMModelFactory.shared.typeRegistry.registerModelType("mistral", creator: creator) + } +} diff --git a/HealthyLLM/HealthyLLM/OpenTSLM/MLPProjector.swift b/HealthyLLM/HealthyLLM/OpenTSLM/MLPProjector.swift new file mode 100644 index 0000000..ef2d78f --- /dev/null +++ b/HealthyLLM/HealthyLLM/OpenTSLM/MLPProjector.swift @@ -0,0 +1,28 @@ +import Foundation +import MLX +import MLXNN + +/// 1:1 Swift port of Python MLX `MLPProjector`. +public class MLPProjector: Module, UnaryLayer { + + @ModuleInfo(key: "norm") var norm: LayerNorm + @ModuleInfo(key: "linear") var linear: Linear + + public init(inputDim: Int = 128, outputDim: Int = 2048) { + _norm.wrappedValue = LayerNorm(dimensions: inputDim) + _linear.wrappedValue = Linear(inputDim, outputDim) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + gelu(linear(norm(x))) + } +} + +extension MLPProjector { + + public func loadWeights(from url: URL) throws { + let arrays = try loadArrays(url: url) + let params = ModuleParameters.unflattened(arrays) + try update(parameters: params, verify: .noUnusedKeys) + } +} \ No newline at end of file diff --git a/HealthyLLM/HealthyLLM/OpenTSLM/MLXEmbeddingGenerator.swift b/HealthyLLM/HealthyLLM/OpenTSLM/MLXEmbeddingGenerator.swift new file mode 100644 index 0000000..7f67180 --- /dev/null +++ b/HealthyLLM/HealthyLLM/OpenTSLM/MLXEmbeddingGenerator.swift @@ -0,0 +1,84 @@ +// +// MLXEmbeddingGenerator.swift +// Custom MLX generation loop that accepts pre-computed embeddings (inputs_embeds) +// +// Note: This is a best-effort implementation that targets the MLX/MLXLLM API surface. +// If signatures differ in your installed MLX version, adapt calls accordingly. +// + +import Foundation +import MLX + +public enum MLXEmbeddingGeneratorError: Error { + case modelDoesNotSupportEmbeddingGeneration + case generationFailed(String) +} + +/// A small protocol describing the minimum functionality needed for embedding-primed generation. +/// This keeps the implementation decoupled from a specific MLXLLM version. +public protocol EmbeddingPrimedLanguageModel { + associatedtype Cache + + func makeCache() -> Cache + func callAsFunction(_ inputs: MLXArray?, cache: Cache, inputEmbedding: MLXArray?) throws -> MLXArray +} + +/// Minimal helper implementing a generation loop that uses pre-computed `inputs_embeds`. +/// The model supplies logits; token sampling is injected so this stays compatible with the +/// exact MLX version available in the app target. +public final class MLXEmbeddingGenerator { + private let model: Model + private let eosTokenIds: Set + private let tokenSampler: (MLXArray) throws -> Int + private let decodeTokens: ([Int]) throws -> String + + public init( + model: Model, + eosTokenIds: Set, + tokenSampler: @escaping (MLXArray) throws -> Int, + decodeTokens: @escaping ([Int]) throws -> String + ) { + self.model = model + self.eosTokenIds = eosTokenIds + self.tokenSampler = tokenSampler + self.decodeTokens = decodeTokens + } + + /// Generate text starting from pre-computed embeddings. + public func generate( + inputsEmbeds: MLXArray, + maxNewTokens: Int = 128, + temperature: Float = 1.0 + ) throws -> String { + let cache = model.makeCache() + var logits = try model.callAsFunction(nil, cache: cache, inputEmbedding: inputsEmbeds) + var generatedTokenIds: [Int] = [] + + for _ in 0..= 2 else { + throw MLXEmbeddingGeneratorError.generationFailed("Unexpected logits ndim: \(logits.ndim)") + } + + let seqLen = Int(logits.dim(1)) + guard seqLen > 0 else { + throw MLXEmbeddingGeneratorError.generationFailed("Logits sequence length is zero") + } + + let lastLogitsSlice = logits[0 ..< 1, seqLen - 1 ..< seqLen] + let lastLogits = lastLogitsSlice / Float(temperature) + let nextToken = try tokenSampler(lastLogits) + generatedTokenIds.append(nextToken) + + if eosTokenIds.contains(nextToken) { + break + } + + // Integer index array — the embedding lookup requires integer ids, + // not the float array produced by `MLXArray(converting:)`. + let inputIds = MLXArray([Int32(nextToken)], [1, 1]) + logits = try model.callAsFunction(inputIds, cache: cache, inputEmbedding: nil) + } + + return try decodeTokens(generatedTokenIds) + } +} diff --git a/HealthyLLM/HealthyLLM/OpenTSLM/OpenTSLMInferenceService.swift b/HealthyLLM/HealthyLLM/OpenTSLM/OpenTSLMInferenceService.swift new file mode 100644 index 0000000..9d67ee3 --- /dev/null +++ b/HealthyLLM/HealthyLLM/OpenTSLM/OpenTSLMInferenceService.swift @@ -0,0 +1,1068 @@ +// +// This source file is part of the HealthyLLM based on the Stanford Spezi Template Application project +// +// SPDX-FileCopyrightText: 2026 Stanford University +// +// SPDX-License-Identifier: MIT +// + +import Foundation +import MLX +import OSLog +import Spezi +import SpeziLLM +import SpeziLLMLocal + +@Observable +class OpenTSLMInferenceService: DefaultInitializable, Module, EnvironmentAccessible { + @ObservationIgnored private let logger = Logger(subsystem: "HealthyLLM", category: "OpenTSLMInferenceService") + + required init() { } + + func runSleepSampleInference( + split: SleepEDFDataset.Split = .test, + sampleIndex: Int = 0, + llmRunner: LLMRunner? = nil, + llmSession: LLMLocalSession? = nil + ) async throws -> String { + guard let csvURL = resolveAssetURL( + overridePath: Constants.openTSLMSleepCSVPath, + bundledName: Constants.openTSLMSleepCSVName, + fileExtension: "csv" + ) else { + throw NSError(domain: "OpenTSLMInferenceService", code: 1, userInfo: [NSLocalizedDescriptionKey: "Missing sleep CSV in bundle/OpenTSLM or override path"]) + } + + guard let encoderURL = resolveAssetURL( + overridePath: Constants.openTSLMEncoderCheckpointPath, + bundledName: Constants.openTSLMEncoderCheckpointName, + fileExtension: "safetensors" + ) else { + throw NSError(domain: "OpenTSLMInferenceService", code: 2, userInfo: [NSLocalizedDescriptionKey: "Missing encoder checkpoint in bundle/OpenTSLM or override path"]) + } + + guard let projectorURL = resolveAssetURL( + overridePath: Constants.openTSLMProjectorCheckpointPath, + bundledName: Constants.openTSLMProjectorCheckpointName, + fileExtension: "safetensors" + ) else { + throw NSError(domain: "OpenTSLMInferenceService", code: 3, userInfo: [NSLocalizedDescriptionKey: "Missing projector checkpoint in bundle/OpenTSLM or override path"]) + } + + logger.info("OpenTSLM assets: csv=\(csvURL.path), encoder=\(encoderURL.path), projector=\(projectorURL.path)") + + let dataset = try SleepEDFDataset(csvURL: csvURL, split: split) + guard dataset.count > 0 else { + throw NSError(domain: "OpenTSLMInferenceService", code: 4, userInfo: [NSLocalizedDescriptionKey: "Selected split has no samples"]) + } + + let safeIndex = min(max(sampleIndex, 0), dataset.count - 1) + let sample = cappedSample(dataset.sample(at: safeIndex)) + + let projected: [MLXArray] + do { + let pipeline = OpenTSLMSPPipeline(hiddenSize: 2048) + try pipeline.loadWeights(encoderURL: encoderURL, projectorURL: projectorURL) + projected = pipeline.projectSample(sample) + } + + guard let first = projected.first else { + throw NSError(domain: "OpenTSLMInferenceService", code: 5, userInfo: [NSLocalizedDescriptionKey: "Projection returned no tensors"]) + } + + eval(first) + GPU.clearCache() + logger.info("OpenTSLM encoder projection done: split=\(split.rawValue), sample=\(safeIndex)") + + if let llmRunner = llmRunner, let llmSession = llmSession { + let loraApplied = await applyLoRAIfAvailable(on: llmSession) + + GPU.clearCache() + + if Constants.openTSLMRunSampleLLMGeneration { + let openTSLMLLM = OpenTSLMLLM(llmRunner: llmRunner, session: llmSession) + let generatedText = try await openTSLMLLM.generate( + prePrompt: sample.prePrompt, + timeSeriesText: sample.timeSeriesText, + timeSeriesEmbeddings: projected, + postPrompt: sample.postPrompt, + maxTokens: 200 + ) + + let outputText = """ + **LLM-Generated Analysis:** + + \(generatedText) + + --- + Ground-truth label: \(sample.label) + Ground-truth answer: \(sample.answer) + """ + + return formatSampleReport( + title: "Sleep-EDF sample inference with LLM generation", + prePrompt: sample.prePrompt, + timeSeriesText: sample.timeSeriesText, + postPrompt: sample.postPrompt, + label: sample.label, + answer: outputText, + extraLines: openTSLMSampleExtraLines( + split: split, + sampleIndex: safeIndex, + sample: sample, + embeddingsShape: first.shape, + loraApplied: loraApplied + ) + ) + } + + let outputText = """ + **OpenTSLM encoder + LoRA (on-device decode skipped to avoid OOM)** + + Projected time-series embeddings: \(first.shape) + LoRA applied to Llama: \(loraApplied ? "yes" : "no") + Ground-truth label: \(sample.label) + Ground-truth answer: \(sample.answer) + + Set `HEALTHYLLM_OPEN_TSLM_RUN_LLM=1` in the scheme to attempt a short LLM decode (may exceed device memory). + """ + + return formatSampleReport( + title: "Sleep-EDF sample inference (encoder + LoRA)", + prePrompt: sample.prePrompt, + timeSeriesText: sample.timeSeriesText, + postPrompt: sample.postPrompt, + label: sample.label, + answer: outputText, + extraLines: openTSLMSampleExtraLines( + split: split, + sampleIndex: safeIndex, + sample: sample, + embeddingsShape: first.shape, + loraApplied: loraApplied + ) + ) + } else { + // Fallback: just describe the embeddings + let outputText = """ + **Embeddings Computed Successfully** + + The time series embeddings were computed (\(first.shape)) but no LLM session was provided for generation. + + To enable LLM generation with embeddings: + 1. Ensure the HealthDataInterpreter is initialized with a valid LLM session + 2. Pass the llmRunner and llmSession parameters to this method + + Ground-truth label: \(sample.label) + Ground-truth answer: \(sample.answer) + """ + + return formatSampleReport( + title: "Sleep-EDF sample inference (embeddings only)", + prePrompt: sample.prePrompt, + timeSeriesText: sample.timeSeriesText, + postPrompt: sample.postPrompt, + label: sample.label, + answer: outputText, + extraLines: [ + "split: \(split.rawValue)", + "sample_index: \(safeIndex)", + "series_count: \(sample.timeSeries.count)", + "embeddings_shape: \(first.shape)", + "llm_integration: no", + ] + ) + } + } + + func runECGSampleInference( + split: ECGQACoTDataset.Split = .test, + sampleIndex: Int = 0, + llmRunner: LLMRunner? = nil, + llmSession: LLMLocalSession? = nil + ) async throws -> String { + let loaded = try loadECGQACoTSample(split: split, sampleIndex: sampleIndex) + let sample = cappedSample(loaded.sample) + var metadata = loaded.metadata + metadata = ECGQACoTSampleMetadata( + source: metadata.source, + loader: metadata.loader, + split: metadata.split, + sampleIndex: metadata.sampleIndex, + ecgId: metadata.ecgId, + templateId: metadata.templateId, + questionType: metadata.questionType, + seriesCount: sample.timeSeries.count, + samplesPerLead: sample.timeSeries.first?.count ?? 0 + ) + + // Load ECG encoder + projector (per-task checkpoints — sleep weights would + // produce a meaningless projection for ECG inputs). + let projected = try loadPipelineAndProject( + sample, + encoderName: Constants.openTSLMECGEncoderCheckpointName, + projectorName: Constants.openTSLMECGProjectorCheckpointName + ) + guard let first = projected.first else { + throw NSError(domain: "OpenTSLMInferenceService", code: 5, userInfo: [NSLocalizedDescriptionKey: "Projection returned no tensors"]) + } + + GPU.clearCache() + + guard Constants.openTSLMRunSampleLLMGeneration, + let llmRunner, + let llmSession + else { + let outputText = """ + **OpenTSLM ECG encoder (Llama decode skipped)** + + Projected time-series embeddings: \(first.shape) + Dataset answer: \(sample.answer) + + Set `HEALTHYLLM_OPEN_TSLM_RUN_LLM=1` (and do not set `HEALTHYLLM_SKIP_LLM_LOAD=1`) to run Llama decode — likely OOM on physical iPhone. + """ + + return formatSampleReport( + title: "ECG-QA CoT sample inference (encoder only)", + prePrompt: sample.prePrompt, + timeSeriesText: sample.timeSeriesText, + postPrompt: sample.postPrompt, + label: sample.label, + answer: outputText, + extraLines: ecgQACoTSampleExtraLines( + metadata: metadata, + embeddingsShape: first.shape, + loraApplied: false + ) + [ + "llm_model: \(Constants.llmModelName)", + "llm_integration: encoder-only", + "llama_loaded: \(Constants.skipLLMLoad ? "no" : "yes")", + ] + ) + } + + // Llama loaded + generation requested + let loraApplied = await applyLoRAIfAvailable( + on: llmSession, + checkpointName: Constants.openTSLMECGLoRACheckpointName + ) + GPU.clearCache() + + let openTSLMLLM = OpenTSLMLLM(llmRunner: llmRunner, session: llmSession) + let generatedText = try await openTSLMLLM.generate( + prePrompt: sample.prePrompt, + timeSeriesText: sample.timeSeriesText, + timeSeriesEmbeddings: projected, + postPrompt: sample.postPrompt, + maxTokens: 200 + ) + + let outputText = """ + **LLM-Generated ECG Analysis:** + + \(generatedText) + + --- + Dataset answer: \(sample.answer) + """ + + return formatSampleReport( + title: "ECG-QA CoT sample inference with LLM generation", + prePrompt: sample.prePrompt, + timeSeriesText: sample.timeSeriesText, + postPrompt: sample.postPrompt, + label: sample.label, + answer: outputText, + extraLines: ecgQACoTSampleExtraLines( + metadata: metadata, + embeddingsShape: first.shape, + loraApplied: loraApplied + ) + [ + "llm_model: \(Constants.llmModelName)", + "llm_integration: generate", + ] + ) + } + + /// Runs OpenTSLM-SP on a real (e.g. HealthKit) ECG recording and returns the model's + /// analysis text. Unlike ``runECGSampleInference`` (a debug command using a hardcoded/JSON + /// sample and emitting a verbose report), this uses the supplied recording, requires a live + /// LLM session, and returns a clean user-facing answer. + func runECGInference( + voltages: [Double], + samplingFrequency: Double, + classification: String?, + symptomsStatus: String?, + averageHeartRate: Double?, + llmRunner: LLMRunner?, + llmSession: LLMLocalSession? + ) async throws -> String { + guard let llmRunner, let llmSession else { + throw NSError( + domain: "OpenTSLMInferenceService", code: 7, + userInfo: [NSLocalizedDescriptionKey: "An LLM session is required to analyze the ECG."]) + } + guard !voltages.isEmpty else { + throw NSError( + domain: "OpenTSLMInferenceService", code: 8, + userInfo: [NSLocalizedDescriptionKey: "The ECG recording has no voltage samples."]) + } + + logger.info("runECGInference: start voltages=\(voltages.count, privacy: .public) freq=\(samplingFrequency, privacy: .public)") + let ecg = ECGSample( + source: .healthkit, + samplingFrequency: samplingFrequency, + classification: classification, + symptomsStatus: symptomsStatus, + averageHeartRate: averageHeartRate, + voltages: voltages + ) + let sample = cappedSample(makeOpenTSLMSample(from: ecg)) + let projected = try loadPipelineAndProject( + sample, + encoderName: Constants.openTSLMECGEncoderCheckpointName, + projectorName: Constants.openTSLMECGProjectorCheckpointName + ) + logger.info("runECGInference: projected series=\(projected.count, privacy: .public) shape0=\(projected.first?.shape ?? [], privacy: .public)") + + let loraApplied = await applyLoRAIfAvailable( + on: llmSession, + checkpointName: Constants.openTSLMECGLoRACheckpointName + ) + GPU.clearCache() + logger.info("runECGInference: loraApplied=\(loraApplied, privacy: .public); calling generate") + + let openTSLMLLM = OpenTSLMLLM(llmRunner: llmRunner, session: llmSession) + let analysis = try await openTSLMLLM.generate( + prePrompt: sample.prePrompt, + timeSeriesText: sample.timeSeriesText, + timeSeriesEmbeddings: projected, + postPrompt: sample.postPrompt, + maxTokens: 200 + ) + logger.info("runECGInference: generate returned \(analysis.count, privacy: .public) chars") + let trimmed = analysis.trimmingCharacters(in: .whitespacesAndNewlines) + return trimmed.isEmpty ? "The ECG model did not produce an analysis." : trimmed + } + + /// Resolves the encoder/projector checkpoints, loads the SP pipeline, and projects the + /// sample's time series to LLM-hidden-size embeddings. + /// Defaults to the sleep checkpoints; the ECG paths pass the `.ecg` names. + private func loadPipelineAndProject( + _ sample: OpenTSLMSPSample, + encoderName: String = Constants.openTSLMEncoderCheckpointName, + projectorName: String = Constants.openTSLMProjectorCheckpointName + ) throws -> [MLXArray] { + guard let encoderURL = resolveAssetURL( + overridePath: Constants.openTSLMEncoderCheckpointPath, + bundledName: encoderName, + fileExtension: "safetensors" + ) else { + throw NSError(domain: "OpenTSLMInferenceService", code: 2, userInfo: [NSLocalizedDescriptionKey: "Missing encoder checkpoint '\(encoderName)' in bundle/OpenTSLM or override path"]) + } + guard let projectorURL = resolveAssetURL( + overridePath: Constants.openTSLMProjectorCheckpointPath, + bundledName: projectorName, + fileExtension: "safetensors" + ) else { + throw NSError(domain: "OpenTSLMInferenceService", code: 3, userInfo: [NSLocalizedDescriptionKey: "Missing projector checkpoint '\(projectorName)' in bundle/OpenTSLM or override path"]) + } + + let pipeline = OpenTSLMSPPipeline(hiddenSize: 2048) + try pipeline.loadWeights(encoderURL: encoderURL, projectorURL: projectorURL) + let projected = pipeline.projectSample(sample) + guard let first = projected.first else { + throw NSError(domain: "OpenTSLMInferenceService", code: 5, userInfo: [NSLocalizedDescriptionKey: "Projection returned no tensors"]) + } + eval(first) + GPU.clearCache() + return projected + } + + /// Defaults to the sleep LoRA checkpoint; the ECG paths pass `Constants.openTSLMECGLoRACheckpointName`. + private func applyLoRAIfAvailable( + on llmSession: LLMLocalSession, + checkpointName: String = Constants.openTSLMLoRACheckpointName + ) async -> Bool { + do { + try await OpenTSLMLoRA.applyIfNeeded(on: llmSession, checkpointName: checkpointName) + return true + } catch { + logger.warning("OpenTSLM LoRA apply failed: \(error.localizedDescription, privacy: .public)") + return false + } + } + + private func resolveAssetURL(overridePath: String, bundledName: String, fileExtension: String) -> URL? { + let fileManager = FileManager.default + + if !overridePath.isEmpty { + let expandedPath = NSString(string: overridePath).expandingTildeInPath + let overrideURL = URL(fileURLWithPath: expandedPath) + if fileManager.fileExists(atPath: overrideURL.path) { + return overrideURL + } + } + + if let bundledURL = Bundle.main.url( + forResource: bundledName, + withExtension: fileExtension, + subdirectory: Constants.openTSLMBundleSubdirectory + ) { + return bundledURL + } + + return Bundle.main.url(forResource: bundledName, withExtension: fileExtension) + } + + private func resolveLocalModelDirectory() -> URL? { + let fileManager = FileManager.default + + // Check for override path + if let overridePath = Constants.localModelSourcePathOverride, + let overrideURL = existingDirectoryURL(at: overridePath, fileManager: fileManager) { + return overrideURL + } + + // Check bundled model directory + if let bundledLocalModelURL = Bundle.main.resourceURL { + let bundledDirectory = bundledLocalModelURL.appendingPathComponent(Constants.localModelBundleSubdirectory, isDirectory: true) + if fileManager.fileExists(atPath: bundledDirectory.path) { + return bundledDirectory + } + + // Check if model files are directly in bundle root + let requiredModelFiles = ["config.json", "tokenizer.json", "model.safetensors"] + let bundledRootFiles = requiredModelFiles.allSatisfy { fileName in + fileManager.fileExists(atPath: bundledLocalModelURL.appendingPathComponent(fileName).path) + } + + if bundledRootFiles { + return bundledLocalModelURL + } + } + + // Fallback: detect a downloaded model snapshot from Hugging Face cache + let sanitizedRepoID = Constants.llmModelName.replacingOccurrences(of: "/", with: "--") + let hostSnapshotsPath = "\(Constants.hostHuggingFaceCacheRoot)/models--\(sanitizedRepoID)/snapshots" + let hostSnapshotsURL = URL(fileURLWithPath: hostSnapshotsPath, isDirectory: true) + + if let hostSnapshot = newestSnapshotDirectory(in: hostSnapshotsURL, fileManager: fileManager) { + return hostSnapshot + } + + let snapshotsPath = "~/.cache/huggingface/hub/models--\(sanitizedRepoID)/snapshots" + let snapshotsURL = URL(fileURLWithPath: NSString(string: snapshotsPath).expandingTildeInPath, isDirectory: true) + return newestSnapshotDirectory(in: snapshotsURL, fileManager: fileManager) + } + + private func newestSnapshotDirectory(in snapshotsURL: URL, fileManager: FileManager) -> URL? { + guard fileManager.fileExists(atPath: snapshotsURL.path) else { + return nil + } + + let directoryContents = try? fileManager.contentsOfDirectory( + at: snapshotsURL, + includingPropertiesForKeys: [.contentModificationDateKey], + options: [.skipsHiddenFiles] + ) + + return directoryContents? + .filter { url in + var isDirectory: ObjCBool = false + return fileManager.fileExists(atPath: url.path, isDirectory: &isDirectory) && isDirectory.boolValue + } + .sorted(by: { lhs, rhs in + let lhsDate = (try? lhs.resourceValues(forKeys: [.contentModificationDateKey]).contentModificationDate) ?? .distantPast + let rhsDate = (try? rhs.resourceValues(forKeys: [.contentModificationDateKey]).contentModificationDate) ?? .distantPast + return lhsDate > rhsDate + }) + .first + } + + private func existingDirectoryURL(at rawPath: String, fileManager: FileManager) -> URL? { + let expandedPath = NSString(string: rawPath).expandingTildeInPath + let url = URL(fileURLWithPath: expandedPath, isDirectory: true) + var isDirectory: ObjCBool = false + + guard fileManager.fileExists(atPath: url.path, isDirectory: &isDirectory), isDirectory.boolValue else { + return nil + } + + return url + } + + private func cappedSample(_ sample: OpenTSLMSPSample) -> OpenTSLMSPSample { + let cap = Constants.openTSLMMaxTimeSeriesLength + guard cap > 0 else { return sample } + + let cappedSeries = sample.timeSeries.map { series in + series.count > cap ? Array(series.prefix(cap)) : series + } + guard cappedSeries != sample.timeSeries else { return sample } + + return OpenTSLMSPSample( + prePrompt: sample.prePrompt, + timeSeriesText: sample.timeSeriesText, + timeSeries: cappedSeries, + postPrompt: sample.postPrompt, + label: sample.label, + answer: sample.answer + ) + } + + private func openTSLMSampleExtraLines( + split: SleepEDFDataset.Split, + sampleIndex: Int, + sample: OpenTSLMSPSample, + embeddingsShape: [Int], + loraApplied: Bool + ) -> [String] { + [ + "split: \(split.rawValue)", + "sample_index: \(sampleIndex)", + "series_count: \(sample.timeSeries.count)", + "max_series_length: \(Constants.openTSLMMaxTimeSeriesLength)", + "embeddings_shape: \(embeddingsShape)", + "lora_checkpoint_found: \(OpenTSLMLoRA.resolveLoRAURL() != nil)", + "lora_applied: \(loraApplied ? "yes" : "no")", + "llm_model: \(Constants.llmModelName)", + "llm_integration: \(Constants.openTSLMRunSampleLLMGeneration ? "generate" : "encoder+lora-only")", + ] + } + + private func ecgQACoTSampleExtraLines( + metadata: ECGQACoTSampleMetadata, + embeddingsShape: [Int], + loraApplied: Bool + ) -> [String] { + [ + "source: \(metadata.source)", + "loader: \(metadata.loader)", + "split: \(metadata.split)", + "sample_index: \(metadata.sampleIndex)", + "ecg_id: \(metadata.ecgId)", + "template_id: \(metadata.templateId)", + "question_type: \(metadata.questionType)", + "series_count: \(metadata.seriesCount)", + "samples_per_lead: \(metadata.samplesPerLead)", + "max_series_length: \(Constants.openTSLMMaxTimeSeriesLength)", + "skip_llm_load: \(Constants.skipLLMLoad ? "yes" : "no")", + "run_llm: \(Constants.openTSLMRunSampleLLMGeneration ? "yes" : "no")", + "embeddings_shape: \(embeddingsShape)", + "lora_checkpoint_found: \(OpenTSLMLoRA.resolveLoRAURL(checkpointName: Constants.openTSLMECGLoRACheckpointName) != nil)", + "lora_applied: \(loraApplied ? "yes" : "no")", + ] + } + + private struct LoadedECGQACoTSample { + let sample: OpenTSLMSPSample + let metadata: ECGQACoTSampleMetadata + } + + private struct ECGQACoTSampleMetadata { + let source: String + let loader: String + let split: String + let sampleIndex: Int + let ecgId: String + let templateId: String + let questionType: String + let seriesCount: Int + let samplesPerLead: Int + } + + /// Loads one ECG-QA CoT row via ``ECGQACoTDataset`` (CSV metadata + PTB-XL waveform sidecar). + private func loadECGQACoTSample( + split: ECGQACoTDataset.Split, + sampleIndex: Int + ) throws -> LoadedECGQACoTSample { + if let csvURL = resolveECGCoTCSVURL(split: split), + let waveformsDirectory = resolveECGWaveformsDirectoryURL() { + let dataset = try ECGQACoTDataset( + csvURL: csvURL, + waveformsDirectory: waveformsDirectory, + templateAnswersURL: resolveECGTemplateAnswersURL(), + split: split, + maxRows: max(sampleIndex + 1, 1) + ) + guard dataset.count > 0 else { + throw NSError(domain: "OpenTSLMInferenceService", code: 15, userInfo: [NSLocalizedDescriptionKey: "ECG-QA CoT CSV has no rows"]) + } + + let safeIndex = min(max(sampleIndex, 0), dataset.count - 1) + let sample = try dataset.sample(at: safeIndex) + let row = dataset.rowMetadata(at: safeIndex) + + return LoadedECGQACoTSample( + sample: sample, + metadata: ECGQACoTSampleMetadata( + source: "ecg_qa_cot", + loader: "ECGQACoTDataset", + split: row.split.rawValue, + sampleIndex: safeIndex, + ecgId: String(row.ecgId), + templateId: String(row.templateId), + questionType: row.questionType, + seriesCount: sample.timeSeries.count, + samplesPerLead: sample.timeSeries.first?.count ?? 0 + ) + ) + } + + // Legacy fallback: monolithic JSON exported by inference_ecg.py --export-json. + let legacy = try loadECGQACoTFormattedSample() + return LoadedECGQACoTSample( + sample: legacy.sample, + metadata: ECGQACoTSampleMetadata( + source: legacy.info.source, + loader: "formatted_json", + split: legacy.info.split, + sampleIndex: legacy.info.sampleIndex, + ecgId: legacy.info.ecgId, + templateId: legacy.info.templateId, + questionType: "unknown", + seriesCount: legacy.info.seriesCount, + samplesPerLead: legacy.info.samplesPerLead + ) + ) + } + + private func resolveECGCoTCSVURL(split: ECGQACoTDataset.Split) -> URL? { + if !Constants.openTSLMECGCoTCSVPath.isEmpty { + let expandedPath = NSString(string: Constants.openTSLMECGCoTCSVPath).expandingTildeInPath + let overrideURL = URL(fileURLWithPath: expandedPath) + if FileManager.default.fileExists(atPath: overrideURL.path) { + return overrideURL + } + } + + if split == .test { + return resolveAssetURL( + overridePath: "", + bundledName: Constants.openTSLMECGCoTTestCSVName, + fileExtension: "csv" + ) + } + + return resolveAssetURL( + overridePath: "", + bundledName: split.csvBaseName, + fileExtension: "csv" + ) + } + + private func resolveECGWaveformsDirectoryURL() -> URL? { + if !Constants.openTSLMECGWaveformsPath.isEmpty { + let expandedPath = NSString(string: Constants.openTSLMECGWaveformsPath).expandingTildeInPath + var isDirectory: ObjCBool = false + if FileManager.default.fileExists(atPath: expandedPath, isDirectory: &isDirectory), isDirectory.boolValue { + return URL(fileURLWithPath: expandedPath, isDirectory: true) + } + } + + if let bundledURL = Bundle.main.url( + forResource: Constants.openTSLMECGWaveformsDirectoryName, + withExtension: nil, + subdirectory: Constants.openTSLMBundleSubdirectory + ) { + return bundledURL + } + + return Bundle.main.url(forResource: Constants.openTSLMECGWaveformsDirectoryName, withExtension: nil) + } + + private func resolveECGTemplateAnswersURL() -> URL? { + return resolveAssetURL( + overridePath: "", + bundledName: Constants.openTSLMECGTemplateAnswersName, + fileExtension: "json" + ) + } + + private struct LoadedECGQACoTFormattedSample { + let sample: OpenTSLMSPSample + let info: ECGQACoTFormattedSampleInfo + } + + private struct ECGQACoTFormattedSampleInfo { + let source: String + let split: String + let sampleIndex: Int + let ecgId: String + let templateId: String + let seriesCount: Int + let samplesPerLead: Int + } + + /// Legacy formatted JSON export from ``inference_ecg.py --export-json``. + private func loadECGQACoTFormattedSample() throws -> LoadedECGQACoTFormattedSample { + guard let url = resolveECGQACoTSampleURL() else { + throw NSError( + domain: "OpenTSLMInferenceService", + code: 9, + userInfo: [ + NSLocalizedDescriptionKey: + "Missing ECG-QA CoT sample JSON. Bundle \(Constants.openTSLMECGQACoTSampleName).json " + + "or set HEALTHYLLM_OPEN_TSLM_ECG_JSON to the file exported by inference_ecg.py --export-json.", + ] + ) + } + + let data = try Data(contentsOf: url) + let decoded = try JSONSerialization.jsonObject(with: data) + guard let object = decoded as? [String: Any] else { + throw NSError(domain: "OpenTSLMInferenceService", code: 10, userInfo: [NSLocalizedDescriptionKey: "ECG-QA JSON root must be an object"]) + } + + guard let prePrompt = object["pre_prompt"] as? String, + let postPrompt = object["post_prompt"] as? String, + let timeSeriesText = object["time_series_text"] as? [String], + let label = object["label"] as? String, + let answer = object["answer"] as? String + else { + throw NSError( + domain: "OpenTSLMInferenceService", + code: 11, + userInfo: [NSLocalizedDescriptionKey: "ECG-QA JSON missing pre_prompt/post_prompt/time_series_text/label/answer"] + ) + } + + guard let rawSeries = object["time_series"] as? [[Any]], !rawSeries.isEmpty else { + throw NSError(domain: "OpenTSLMInferenceService", code: 12, userInfo: [NSLocalizedDescriptionKey: "ECG-QA JSON missing time_series"]) + } + + let timeSeries: [[Float]] = try rawSeries.map { lead in + guard !lead.isEmpty else { + throw NSError(domain: "OpenTSLMInferenceService", code: 13, userInfo: [NSLocalizedDescriptionKey: "ECG-QA JSON has empty lead"]) + } + return try lead.map { value in + if let f = value as? Float { return f } + if let d = value as? Double { return Float(d) } + if let i = value as? Int { return Float(i) } + throw NSError(domain: "OpenTSLMInferenceService", code: 14, userInfo: [NSLocalizedDescriptionKey: "Invalid time_series value"]) + } + } + + let samplesPerLead = timeSeries.first?.count ?? 0 + let info = ECGQACoTFormattedSampleInfo( + source: object["source"] as? String ?? "ecg_qa_cot", + split: object["split"] as? String ?? "unknown", + sampleIndex: object["sample_idx"] as? Int ?? -1, + ecgId: String(describing: object["ecg_id"] ?? "unknown"), + templateId: String(describing: object["template_id"] ?? "unknown"), + seriesCount: timeSeries.count, + samplesPerLead: samplesPerLead + ) + + let sample = OpenTSLMSPSample( + prePrompt: prePrompt, + timeSeriesText: timeSeriesText, + timeSeries: timeSeries, + postPrompt: postPrompt, + label: label, + answer: answer + ) + return LoadedECGQACoTFormattedSample(sample: sample, info: info) + } + + private func resolveECGQACoTSampleURL() -> URL? { + if let override = resolveECGJSONURL() { + return override + } + + return resolveAssetURL( + overridePath: "", + bundledName: Constants.openTSLMECGQACoTSampleName, + fileExtension: "json" + ) + } + + private func loadECGSample() throws -> ECGSample { + if let url = resolveECGJSONURL() { + return try ECGSample.load(from: url) + } + + return ECGSample.hardcoded(sampleLength: Constants.hardcodedECGSampleLength) + } + + private func resolveECGJSONURL() -> URL? { + let overridePath = Constants.openTSLMECGJSONPath + guard !overridePath.isEmpty else { + return nil + } + + let fileManager = FileManager.default + let expandedPath = NSString(string: overridePath).expandingTildeInPath + let overrideURL = URL(fileURLWithPath: expandedPath) + guard fileManager.fileExists(atPath: overrideURL.path) else { + return nil + } + + return overrideURL + } + + // PTB-XL / ECG-QA standard 12-lead order — must match what the encoder was + // trained to receive in the per-lead text labels. + private static let ecgLeadNames = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"] + + private static let ecgTargetSamplingRate: Double = 100 // Hz — matches ECG-QA training (`[::5]` from 500 Hz) + /// Window length per lead. Capped to the global series-length budget so we don't + /// generate samples we'd only truncate; ECG-QA training was on 10 s windows but + /// 12 × 1000 over-runs the iOS 6 GB process limit at prefill time. + private static var ecgSamplesPerLead: Int { Constants.openTSLMMaxTimeSeriesLength } + + private func makeOpenTSLMSample(from ecg: ECGSample) -> OpenTSLMSPSample { + // Downsample to 100 Hz and window to 1000 samples (10 s) to match ECG-QA training. + let downsampled = Self.downsampleAndWindow( + ecg.voltages, + inputRate: ecg.samplingFrequency, + outputRate: Self.ecgTargetSamplingRate, + count: Self.ecgSamplesPerLead + ) + let stats = Self.statistics(of: downsampled) + let normalizedLead = downsampled.map { Float(($0 - stats.mean) / stats.std) } + + // Apple Watch is hardware-single-lead — replicate Lead I across all 12 channels + // so the encoder receives the 12-series shape it was trained on. The per-lead + // mean/std in each label are therefore identical, by construction; that's an + // acknowledged distribution shift vs. real PTB-XL multi-lead recordings. + var timeSeries: [[Float]] = [] + timeSeries.reserveCapacity(Self.ecgLeadNames.count) + var timeSeriesText: [String] = [] + timeSeriesText.reserveCapacity(Self.ecgLeadNames.count) + for name in Self.ecgLeadNames { + timeSeries.append(normalizedLead) + timeSeriesText.append("This is ECG Lead \(name), it has mean \(String(format: "%.4f", stats.mean)) and std \(String(format: "%.4f", stats.std)):") + } + + return OpenTSLMSPSample( + prePrompt: Self.ecgPrePrompt( + clinicalContext: ecg.clinicalContext, + question: ecg.question + ), + timeSeriesText: timeSeriesText, + timeSeries: timeSeries, + postPrompt: Self.ecgPostPrompt(possibleAnswers: ecg.possibleAnswers), + label: ecg.classification ?? "unknown", + answer: ecg.summary + ) + } + + /// Verbatim from OpenTSLM `src/time_series_datasets/ecg_qa/ECGQACoTQADataset.py:_get_pre_prompt`, + /// with the runtime-supplied `clinical_context` and `question` substituted. + private static func ecgPrePrompt(clinicalContext: String, question: String) -> String { + """ + You are an expert cardiologist analyzing an ECG (electrocardiogram). + + Clinical Context: \(clinicalContext) + + Your task is to examine the ECG signal and answer the following medical question: + + Question: \(question) + + Instructions: + - Begin by analyzing the time series without assuming a specific answer. + - Think step-by-step about what the observed patterns suggest regarding the cardiac condition. + - Write your rationale as a single, natural paragraph — do not use bullet points, numbered steps, or section headings. + - Do **not** mention any final answer until the very end. + - Consider the ECG morphology, intervals, and any abnormalities that relate to the question. + """ + } + + /// Verbatim from `ECGQACoTQADataset.py:_get_post_prompt`. Two branches: + /// - if `possibleAnswers` is non-empty → the templated branch, which lists the closed + /// answer set the model was trained to pick from. + /// - otherwise → the open-ended branch. + /// Both end with literal `"Answer: ` (open quote + trailing space, no closing quote) — + /// the model continues from there with its rationale and concludes `Answer: