mirror of
https://github.com/kyutai-labs/delayed-streams-modeling.git
synced 2025-12-22 19:09:57 +00:00
Workaround for the mlx kv-cache bug.
This commit is contained in:
@@ -76,6 +76,9 @@ def main():
|
||||
moshi_weights = hf_get(moshi_name, args.hf_repo)
|
||||
tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
|
||||
lm_config = models.LmConfig.from_config_dict(raw_config)
|
||||
# There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
|
||||
# The following line gets around it for now.
|
||||
lm_config.transformer.max_seq_len = lm_config.transformer.context
|
||||
model = models.Lm(lm_config)
|
||||
model.set_dtype(mx.bfloat16)
|
||||
|
||||
|
||||
@@ -205,6 +205,9 @@ def main():
|
||||
moshi_weights = hf_get(moshi_name, args.hf_repo)
|
||||
tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
|
||||
lm_config = models.LmConfig.from_config_dict(raw_config)
|
||||
# There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
|
||||
# The following line gets around it for now.
|
||||
lm_config.transformer.max_seq_len = lm_config.transformer.context
|
||||
model = models.Lm(lm_config)
|
||||
model.set_dtype(mx.bfloat16)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user