Skip to content

Commit e65c288

Browse files
authored
Configure llama temperature only for non-greedy samplers (#18)
1 parent ed77ac6 commit e65c288

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,6 @@ import Foundation
308308
}
309309
defer { llama_sampler_free(sampler) }
310310

311-
// Use temperature from options if provided, otherwise use model's default
312-
let effectiveTemperature = options.temperature.map { Float($0) } ?? temperature
313-
llama_sampler_chain_add(sampler, llama_sampler_init_temp(effectiveTemperature))
314-
315311
// Use sampling parameters from options if provided
316312
if let sampling = options.sampling {
317313
switch sampling.mode {
@@ -321,12 +317,18 @@ import Foundation
321317
case .topK(let k, let seed):
322318
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(Int32(k)))
323319
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(1.0, 1))
320+
if let temperature = options.temperature {
321+
llama_sampler_chain_add(sampler, llama_sampler_init_temp(Float(temperature)))
322+
}
324323
if let seed = seed {
325324
llama_sampler_chain_add(sampler, llama_sampler_init_dist(UInt32(seed)))
326325
}
327326
case .nucleus(let threshold, let seed):
328327
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(0)) // Disable top-k
329328
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(Float(threshold), 1))
329+
if let temperature = options.temperature {
330+
llama_sampler_chain_add(sampler, llama_sampler_init_temp(Float(temperature)))
331+
}
330332
if let seed = seed {
331333
llama_sampler_chain_add(sampler, llama_sampler_init_dist(UInt32(seed)))
332334
}
@@ -450,10 +452,6 @@ import Foundation
450452
}
451453
defer { llama_sampler_free(sampler) }
452454

453-
// Use temperature from options if provided, otherwise use model's default
454-
let effectiveTemperature = options.temperature.map { Float($0) } ?? self.temperature
455-
llama_sampler_chain_add(sampler, llama_sampler_init_temp(effectiveTemperature))
456-
457455
// Use sampling parameters from options if provided
458456
if let sampling = options.sampling {
459457
switch sampling.mode {
@@ -463,12 +461,18 @@ import Foundation
463461
case .topK(let k, let seed):
464462
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(Int32(k)))
465463
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(1.0, 1))
464+
if let temperature = options.temperature {
465+
llama_sampler_chain_add(sampler, llama_sampler_init_temp(Float(temperature)))
466+
}
466467
if let seed = seed {
467468
llama_sampler_chain_add(sampler, llama_sampler_init_dist(UInt32(seed)))
468469
}
469470
case .nucleus(let threshold, let seed):
470471
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(0)) // Disable top-k
471472
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(Float(threshold), 1))
473+
if let temperature = options.temperature {
474+
llama_sampler_chain_add(sampler, llama_sampler_init_temp(Float(temperature)))
475+
}
472476
if let seed = seed {
473477
llama_sampler_chain_add(sampler, llama_sampler_init_dist(UInt32(seed)))
474478
}

Tests/AnyLanguageModelTests/LlamaLanguageModelTests.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,5 +125,19 @@ import Testing
125125
// Response should be limited by max tokens
126126
#expect(!response.content.isEmpty)
127127
}
128+
129+
@Test func greedySamplingWithTemperature() async throws {
130+
let session = LanguageModelSession(model: model)
131+
let options = GenerationOptions(
132+
sampling: .greedy,
133+
temperature: 0.7,
134+
maximumResponseTokens: 50
135+
)
136+
let response = try await session.respond(
137+
to: "Tell me a fact",
138+
options: options
139+
)
140+
#expect(!response.content.isEmpty)
141+
}
128142
}
129143
#endif // Llama

0 commit comments

Comments (0)