
Commit 1d6854b

fix gemma3 (#3772)

* fix gemma3
* add comment
* fix gemma3
* fix transformers>=4.54.0

1 parent 4450cd9 commit 1d6854b

4 files changed, 12 additions and 5 deletions

lmdeploy/pytorch/configurations/gemma.py (2 additions, 0 deletions)

@@ -31,5 +31,7 @@ def build(cls, hf_config, model_path: str = None, **kwargs):
         """Build gemma."""
         hf_config.text_config.architectures = ['Gemma3ForCausalLM']
         cfg = DefaultModelConfigBuilder.build(hf_config.text_config, model_path, **kwargs)
+        # gemma 3 does not enable sliding window on every layer
+        cfg.sliding_window = -1
         cfg.hf_config = hf_config
         return cfg
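
The window is disabled at the model-config level presumably because Gemma 3 applies it only to a subset of layers, and a single cache-wide window would drop tokens that the remaining full-attention layers still need. A minimal sketch of that reasoning (illustrative only; `effective_cache_window` is not an lmdeploy function):

```python
# Hypothetical helper (not lmdeploy code): decide whether the whole KV cache
# can share one sliding window.
def effective_cache_window(layer_windows: list[int]) -> int:
    """layer_windows uses -1 for full-attention layers, >0 for sliding layers."""
    if any(w == -1 for w in layer_windows):
        # A full-attention layer needs every past token, so the shared cache
        # must not evict anything: disable the cache-wide window.
        return -1
    return max(layer_windows)

# Gemma 3 text interleaves sliding and full-attention layers (example values).
pattern = [1024, 1024, 1024, 1024, 1024, -1] * 2
assert effective_cache_window(pattern) == -1  # matches cfg.sliding_window = -1
```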

lmdeploy/pytorch/engine/engine.py (1 addition, 1 deletion)

@@ -1189,7 +1189,7 @@ async def async_loop(self):
                 has_runable_event=has_runable_event,
                 inputs_maker=inputs_maker)
         except Exception as e:
-            logger.error(f'exception happened: {type(e)} {e}')
+            logger.exception(f'exception happened: {type(e)} {e}')
         finally:
             self._loop_finally()
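
`logger.exception` logs at ERROR level like `logger.error`, but it also records the traceback of the exception currently being handled, which is the useful part when the engine loop dies. A standalone snippet showing the difference (plain `logging`, nothing lmdeploy-specific):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('demo')

try:
    1 / 0
except Exception as e:
    logger.error(f'exception happened: {type(e)} {e}')      # message only
    logger.exception(f'exception happened: {type(e)} {e}')  # message + traceback
```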

lmdeploy/pytorch/engine/executor/base.py (4 additions, 0 deletions)

@@ -27,6 +27,10 @@ def __init__(self,
                  device_type: str = 'cuda'):
         """Initialize Executor."""
         cache_config.window_size = model_config.sliding_window
+        if cache_config.window_size is not None and cache_config.window_size > 0:
+            # do not support sliding window prefix caching
+            logger.warning('Sliding window prefix caching is not supported.')
+            cache_config.enable_prefix_caching = False
         self.model_config = model_config
         self.cache_config = cache_config
         self.backend_config = backend_config
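
The new guard simply refuses the combination: whenever a positive window size reaches the executor, prefix caching is switched off with a warning. A self-contained sketch of the same check, using a stand-in `DemoCacheConfig` dataclass instead of lmdeploy's real `CacheConfig`:

```python
import logging
from dataclasses import dataclass
from typing import Optional

logger = logging.getLogger('executor_demo')


@dataclass
class DemoCacheConfig:
    window_size: Optional[int] = None
    enable_prefix_caching: bool = True


def apply_window_guard(cache_config: DemoCacheConfig, sliding_window: Optional[int]) -> DemoCacheConfig:
    """Mirror the executor logic: a positive window disables prefix caching."""
    cache_config.window_size = sliding_window
    if cache_config.window_size is not None and cache_config.window_size > 0:
        logger.warning('Sliding window prefix caching is not supported.')
        cache_config.enable_prefix_caching = False
    return cache_config


assert apply_window_guard(DemoCacheConfig(), 4096).enable_prefix_caching is False
assert apply_window_guard(DemoCacheConfig(), -1).enable_prefix_caching is True  # e.g. gemma3 after the config fix
```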

lmdeploy/pytorch/models/gemma.py (5 additions, 4 deletions)

@@ -56,7 +56,8 @@ def __init__(self,
         if hasattr(config, 'query_pre_attn_scalar'):
             self.scaling = config.query_pre_attn_scalar**-0.5
         if self.model_type == 'gemma3_text':
-            is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
+            sliding_window_pattern = getattr(config, 'sliding_window_pattern', 6)
+            is_sliding = bool((layer_idx + 1) % sliding_window_pattern)
             self.sliding_window = (getattr(config, 'sliding_window', -1) if is_sliding else -1)
         else:
             self.sliding_window = (getattr(config, 'sliding_window', -1) if not bool(layer_idx % 2) else -1)
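
Reading `sliding_window_pattern` through `getattr` (with 6 as the default when the config no longer carries the field, which the commit message ties to transformers>=4.54.0) preserves the layer layout: every sixth layer uses full attention, the rest use the sliding window. A small illustration of that pattern (the `sliding_window` value below is made up):

```python
sliding_window_pattern = 6  # default used when the config lacks the field
sliding_window = 1024       # example value, not read from a real config

for layer_idx in range(12):
    is_sliding = bool((layer_idx + 1) % sliding_window_pattern)
    window = sliding_window if is_sliding else -1
    print(layer_idx, 'sliding' if is_sliding else 'global', window)
# layers 5 and 11 (0-based) print 'global'; all others print 'sliding'
```
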
@@ -388,7 +389,7 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
             emb_type = RopeType.DynamicNTKScaling
         else:
             raise RuntimeError(f'Unsupported rope type: {rope_type}')
-        scaling_factor = rope_scaling.get('scaling_factor', scaling_factor)
+        scaling_factor = rope_scaling.get('scaling_factor', rope_scaling.get('factor', scaling_factor))

         rope_dim = config.head_dim
         rope_max_pos_emb = config.max_position_embeddings
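
The nested `get` accepts either spelling of the scaling key: `scaling_factor` if the config provides it, otherwise `factor` (the key used in Hugging Face `rope_scaling` dicts), otherwise the previously computed default. A tiny sketch of the lookup (the helper name is illustrative):

```python
def resolve_scaling_factor(rope_scaling: dict, default: float = 1.0) -> float:
    # Try 'scaling_factor' first, then 'factor', then keep the default.
    return rope_scaling.get('scaling_factor', rope_scaling.get('factor', default))


assert resolve_scaling_factor({'scaling_factor': 8.0}) == 8.0
assert resolve_scaling_factor({'rope_type': 'linear', 'factor': 8.0}) == 8.0
assert resolve_scaling_factor({}) == 1.0
```
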
@@ -406,8 +407,8 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
             rope_dim,
             rope_max_pos_emb,
             config.rope_local_base_freq,
-            scaling_factor,
-            emb_type=emb_type,
+            1.0,
+            emb_type=RopeType.LinearScaling,
         )

    def forward(
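
For the local (sliding-window) rotary embedding the factor is pinned to 1.0 with `RopeType.LinearScaling`, so the global `rope_scaling` settings no longer affect the local layers. A minimal sketch of why that pair amounts to plain RoPE, assuming linear scaling divides position ids by the factor as position interpolation does (`scale_positions` is illustrative, not lmdeploy code):

```python
import torch


def scale_positions(position_ids: torch.Tensor, scaling_factor: float) -> torch.Tensor:
    # Linear (position-interpolation) scaling stretches the position axis.
    return position_ids.float() / scaling_factor


pos = torch.arange(8)
assert torch.equal(scale_positions(pos, 1.0), pos.float())  # factor 1.0 is the identity
print(scale_positions(pos, 8.0))  # what a larger factor would have done
```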
