diff --git a/Package.resolved b/Package.resolved index 7ed371b..029c27b 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,15 +1,6 @@ { - "originHash" : "f2f0ba1d1b9625bd5147b2fbd7b82236dac35ee1baa399fcf5c76b22fd428bb8", + "originHash" : "04f1ba4db85c992fd8da1384e5a32a383eaef36b465ed07a132a745108b72fdf", "pins" : [ - { - "identity" : "async-http-client", - "kind" : "remoteSourceControl", - "location" : "https://github.com/swift-server/async-http-client.git", - "state" : { - "revision" : "2fc4652fb4689eb24af10e55cabaa61d8ba774fd", - "version" : "1.32.0" - } - }, { "identity" : "eventsource", "kind" : "remoteSourceControl", @@ -29,57 +20,48 @@ } }, { - "identity" : "partialjsondecoder", + "identity" : "llama.swift", "kind" : "remoteSourceControl", - "location" : "https://github.com/mattt/PartialJSONDecoder.git", + "location" : "https://github.com/mattt/llama.swift", "state" : { - "revision" : "e4d389e6bcc6771bb988d1a8a17695d8bfa97172", - "version" : "1.0.0" + "revision" : "995e96d8373e7a503463b698002924196e1926df", + "version" : "2.8808.0" } }, { - "identity" : "swift-algorithms", + "identity" : "mlx-swift", "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-algorithms.git", + "location" : "https://github.com/ml-explore/mlx-swift", "state" : { - "revision" : "87e50f483c54e6efd60e885f7f5aa946cee68023", - "version" : "1.2.1" + "revision" : "61b9e011e09a62b489f6bd647958f1555bdf2896", + "version" : "0.31.3" } }, { - "identity" : "swift-asn1", + "identity" : "mlx-swift-lm", "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-asn1.git", + "location" : "https://github.com/ml-explore/mlx-swift-lm", "state" : { - "revision" : "810496cf121e525d660cd0ea89a758740476b85f", - "version" : "1.5.1" + "revision" : "1c05248bb0899e2a7a4962b84d319cf12f4e12aa", + "version" : "3.31.3" } }, { - "identity" : "swift-async-algorithms", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-async-algorithms.git", - "state" : { - "revision" : "9d349bcc328ac3c31ce40e746b5882742a0d1272", - "version" : "1.1.3" - } - }, - { - "identity" : "swift-atomics", + "identity" : "partialjsondecoder", "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-atomics.git", + "location" : "https://github.com/mattt/PartialJSONDecoder.git", "state" : { - "revision" : "b601256eab081c0f92f059e12818ac1d4f178ff7", - "version" : "1.3.0" + "revision" : "e4d389e6bcc6771bb988d1a8a17695d8bfa97172", + "version" : "1.0.0" } }, { - "identity" : "swift-certificates", + "identity" : "swift-asn1", "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-certificates.git", + "location" : "https://github.com/apple/swift-asn1.git", "state" : { - "revision" : "24ccdeeeed4dfaae7955fcac9dbf5489ed4f1a25", - "version" : "1.18.0" + "revision" : "9f542610331815e29cc3821d3b6f488db8715517", + "version" : "1.6.0" } }, { @@ -91,148 +73,67 @@ "version" : "1.3.0" } }, - { - "identity" : "swift-configuration", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-configuration.git", - "state" : { - "revision" : "be76c4ad929eb6c4bcaf3351799f2adf9e6848a9", - "version" : "1.2.0" - } - }, { "identity" : "swift-crypto", "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-crypto.git", "state" : { - "revision" : "6f70fa9eab24c1fd982af18c281c4525d05e3095", - "version" : "4.2.0" - } - }, - { - "identity" : "swift-distributed-tracing", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-distributed-tracing.git", - "state" : { - "revision" : "e109d8b5308d0e05201d9a1dd1c475446a946a11", - "version" : "1.4.0" - } - }, - { - "identity" : "swift-http-structured-headers", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-http-structured-headers.git", - "state" : { - "revision" : "76d7627bd88b47bf5a0f8497dd244885960dde0b", - "version" : "1.6.0" + "revision" : "bb4ba815dab96d4edc1e0b86d7b9acf9ff973a84", + "version" : "4.3.1" } }, { - "identity" : "swift-http-types", + "identity" : "swift-huggingface", "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-http-types.git", + "location" : "https://github.com/huggingface/swift-huggingface.git", "state" : { - "revision" : "45eb0224913ea070ec4fba17291b9e7ecf4749ca", - "version" : "1.5.1" + "revision" : "b721959445b617d0bf03910b2b4aced345fd93bf", + "version" : "0.9.0" } }, { - "identity" : "swift-log", + "identity" : "swift-jinja", "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-log.git", + "location" : "https://github.com/huggingface/swift-jinja.git", "state" : { - "revision" : "bbd81b6725ae874c69e9b8c8804d462356b55523", - "version" : "1.10.1" - } - }, - { - "identity" : "swift-nio", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-nio.git", - "state" : { - "revision" : "e932d3c4d8f77433c8f7093b5ebcbf91463948a0", - "version" : "2.95.0" - } - }, - { - "identity" : "swift-nio-extras", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-nio-extras.git", - "state" : { - "revision" : "3df009d563dc9f21a5c85b33d8c2e34d2e4f8c3b", - "version" : "1.32.1" - } - }, - { - "identity" : "swift-nio-http2", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-nio-http2.git", - "state" : { - "revision" : "b6571f3db40799df5a7fc0e92c399aa71c883edd", - "version" : "1.40.0" - } - }, - { - "identity" : "swift-nio-ssl", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-nio-ssl.git", - "state" : { - "revision" : "173cc69a058623525a58ae6710e2f5727c663793", - "version" : "2.36.0" - } - }, - { - "identity" : "swift-nio-transport-services", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-nio-transport-services.git", - "state" : { - "revision" : "60c3e187154421171721c1a38e800b390680fb5d", - "version" : "1.26.0" + "revision" : "0aeefadec459ce8e11a333769950fb86183aca43", + "version" : "2.3.5" } }, { "identity" : "swift-numerics", "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-numerics.git", + "location" : "https://github.com/apple/swift-numerics", "state" : { "revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2", "version" : "1.1.1" } }, { - "identity" : "swift-service-context", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-service-context.git", - "state" : { - "revision" : "d0997351b0c7779017f88e7a93bc30a1878d7f29", - "version" : "1.3.0" - } - }, - { - "identity" : "swift-service-lifecycle", + "identity" : "swift-syntax", "kind" : "remoteSourceControl", - "location" : "https://github.com/swift-server/swift-service-lifecycle", + "location" : "https://github.com/swiftlang/swift-syntax.git", "state" : { - "revision" : "89888196dd79c61c50bca9a103d8114f32e1e598", - "version" : "2.10.1" + "revision" : "0687f71944021d616d34d922343dcef086855920", + "version" : "600.0.1" } }, { - "identity" : "swift-syntax", + "identity" : "swift-transformers", "kind" : "remoteSourceControl", - "location" : "https://github.com/swiftlang/swift-syntax.git", + "location" : "https://github.com/huggingface/swift-transformers", "state" : { - "revision" : "0687f71944021d616d34d922343dcef086855920", - "version" : "600.0.1" + "revision" : "b38443e44d93eca770f2eb68e2a4d0fa100f9aa2", + "version" : "1.3.0" } }, { - "identity" : "swift-system", + "identity" : "yyjson", "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-system", + "location" : "https://github.com/ibireme/yyjson.git", "state" : { - "revision" : "7c6ad0fc39d0763e0b699210e4124afd5041c5df", - "version" : "1.6.4" + "revision" : "8b4a38dc994a110abaec8a400615567bd996105f", + "version" : "0.12.0" } } ], diff --git a/Package.swift b/Package.swift index 933f837..c14c1c2 100644 --- a/Package.swift +++ b/Package.swift @@ -42,7 +42,7 @@ let package = Package( .package(url: "https://github.com/mattt/llama.swift", .upToNextMajor(from: "2.7484.0")), .package(url: "https://github.com/mattt/PartialJSONDecoder", from: "1.0.0"), // mlx-swift-lm must be >= 2.25.5 for ToolSpec/tool calls and UserInput(chat:processing:tools:). - .package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "2.25.5"), + .package(url: "https://github.com/ml-explore/mlx-swift-lm", from: "3.31.3"), .package(url: "https://github.com/swiftlang/swift-syntax", from: "600.0.0"), .package(url: "https://github.com/swift-server/async-http-client.git", from: "1.24.0"), ], diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 0ef37ef..6d740d6 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -17,6 +17,7 @@ import Foundation import MLXVLM import Tokenizers import Hub + import HuggingFace /// Wrapper to store model availability state in NSCache. private final class CachedModelState: NSObject, @unchecked Sendable { @@ -172,6 +173,93 @@ import Foundation /// Shared cache across MLXLanguageModel instances. private nonisolated(unsafe) let modelCache = ModelContextCache(countLimit: 3) + /// Downloader bridge for HuggingFace.HubClient. + private struct HubClientDownloader: MLXLMCommon.Downloader { + enum Error: Swift.Error { + case invalidRepositoryID(String) + } + + let client: HubClient + + init(client: HubClient = .default) { + self.client = client + } + + func download( + id: String, + revision: String?, + matching patterns: [String], + useLatest: Bool, + progressHandler: @Sendable @escaping (Progress) -> Void + ) async throws -> URL { + guard let repoID = Repo.ID(rawValue: id) else { + throw Error.invalidRepositoryID(id) + } + + // HubClient resolves to cache by default; `useLatest` is currently not exposed. + return try await client.downloadSnapshot( + of: repoID, + revision: revision ?? "main", + matching: patterns, + progressHandler: { @MainActor progress in + progressHandler(progress) + } + ) + } + } + + /// Tokenizer loader for local-directory MLX models when Hugging Face macros are unavailable. + private struct LocalTokenizersLoader: MLXLMCommon.TokenizerLoader { + func load(from directory: URL) async throws -> any MLXLMCommon.Tokenizer { + let upstream = try await Tokenizers.AutoTokenizer.from(modelFolder: directory) + return TokenizerBridge(upstream) + } + + private struct TokenizerBridge: MLXLMCommon.Tokenizer { + let upstream: any Tokenizers.Tokenizer + + init(_ upstream: any Tokenizers.Tokenizer) { + self.upstream = upstream + } + + func encode(text: String, addSpecialTokens: Bool) -> [Int] { + upstream.encode(text: text, addSpecialTokens: addSpecialTokens) + } + + func decode(tokenIds: [Int], skipSpecialTokens: Bool) -> String { + upstream.decode(tokens: tokenIds, skipSpecialTokens: skipSpecialTokens) + } + + func convertTokenToId(_ token: String) -> Int? { + upstream.convertTokenToId(token) + } + + func convertIdToToken(_ id: Int) -> String? { + upstream.convertIdToToken(id) + } + + var bosToken: String? { upstream.bosToken } + var eosToken: String? { upstream.eosToken } + var unknownToken: String? { upstream.unknownToken } + + func applyChatTemplate( + messages: [[String: any Sendable]], + tools: [[String: any Sendable]]?, + additionalContext: [String: any Sendable]? + ) throws -> [Int] { + do { + return try upstream.applyChatTemplate( + messages: messages, + tools: tools, + additionalContext: additionalContext + ) + } catch Tokenizers.TokenizerError.missingChatTemplate { + throw MLXLMCommon.TokenizerError.missingChatTemplate + } + } + } + } + // MARK: - MLXLanguageModel /// A language model that runs locally using MLX. @@ -680,10 +768,10 @@ import Foundation return try await modelCache.context(for: key) { if let directory { - return try await loadModel(directory: directory) + return try await loadModel(from: directory, using: LocalTokenizersLoader()) } - return try await loadModel(hub: hub ?? HubApi(), id: modelId) + return try await loadModel(from: HubClientDownloader(), using: LocalTokenizersLoader(), id: modelId) } } @@ -1714,7 +1802,7 @@ import Foundation private struct MLXTokenBackend: TokenBackend { let model: any MLXLMCommon.LanguageModel - let tokenizer: any Tokenizer + let tokenizer: any MLXLMCommon.Tokenizer var state: MLXLMCommon.LMOutput.State? var cache: [MLXLMCommon.KVCache] var processor: MLXLMCommon.LogitProcessor? @@ -1795,7 +1883,7 @@ import Foundation private static func buildEndTokens( eosTokenId: Int, - tokenizer: any Tokenizer, + tokenizer: any MLXLMCommon.Tokenizer, configuration: ModelConfiguration ) -> Set { var tokens: Set = [eosTokenId] @@ -1816,13 +1904,13 @@ import Foundation func isSpecialToken(_ token: Int) -> Bool { // Use swift-transformers' own special token registry (skipSpecialTokens) instead of guessing. - let raw = tokenizer.decode(tokens: [token], skipSpecialTokens: false) + let raw = tokenizer.decode(tokenIds: [token], skipSpecialTokens: false) guard !raw.isEmpty else { return false } - let filtered = tokenizer.decode(tokens: [token], skipSpecialTokens: true) + let filtered = tokenizer.decode(tokenIds: [token], skipSpecialTokens: true) return filtered.isEmpty } - private static func buildTokensExcludedFromRepetitionPenalty(tokenizer: any Tokenizer) -> Set { + private static func buildTokensExcludedFromRepetitionPenalty(tokenizer: any MLXLMCommon.Tokenizer) -> Set { let excludedTexts = ["{", "}", "[", "]", ",", ":", "\""] var excluded = Set() excluded.reserveCapacity(excludedTexts.count * 2) @@ -1842,7 +1930,7 @@ import Foundation } func tokenText(_ token: Int) -> String? { - let decoded = tokenizer.decode(tokens: [token], skipSpecialTokens: false) + let decoded = tokenizer.decode(tokenIds: [token], skipSpecialTokens: false) return decoded.isEmpty ? nil : decoded }