API Reference
Detailed documentation for all public classes and functions in engram-peft. Defaults and behavior are aligned with the DeepSeek Engram paper and the official implementation.
Configuration
EngramConfig
engram_peft.config.EngramConfig
Configuration class for the Engram PEFT module. Inherits from transformers.PretrainedConfig. All default values exactly match the specifications in Appendix A, Table 5 of the Engram paper.
Parameters:
- engram_vocab_size_per_ngram (List[int], default: [1131200, 1131200]): Total engram vocabulary size split per N-gram order.
- ngram_sizes (List[int], default: [2, 3]): List of N-gram orders to use (e.g., [2, 3] means 2-grams and 3-grams).
- n_head_per_ngram (int, default: 8): Number of hash heads per N-gram order.
- embedding_dim (int, default: 1280): Dimension of the Engram retrieval embedding.
- enable_tokenizer_compression (bool, default: True): Whether to apply NFKC and lowercase normalization for token grouping.
- target_layers (List[int], default: [2, 15]): Transformer layers where Engram modules are injected.
- target_modules (Optional[Union[List[str], str]], default: None): Specific module names or regex patterns to target for injection.
- hc_mult (int, default: 4): Multi-head hyper-connection expansion factor.
- combine_mhc (bool, default: True): Whether to combine multi-head hyper-connections.
- conv_kernel_size (int, default: 4): Convolution kernel size for short-term context.
- conv_dilation (Optional[int], default: None): Convolution dilation (defaults to max(ngram_sizes)).
- conv_zero_init (bool, default: True): Initialize convolution weights to zero to ensure identity mapping at start.
- learning_rate_multiplier (float, default: 5.0): LR multiplier for sparse embedding parameters.
- tokenizer_name_or_path (Optional[str], default: None): Tokenizer used for precomputing hashes. Recommended to set explicitly (e.g., "deepseek-ai/DeepSeek-V3").
- seed (int, default: 0): Random seed for deterministic hashing primes.
- weight_decay (float, default: 0.0): Weight decay for Engram parameters.
- gating_zero_init (bool, default: False): Whether to initialize gating parameters with zeros.
- hidden_size (Optional[int], default: None): The hidden dimension of the base model. Auto-detected if not provided.
- pad_id (Optional[int], default: None): The padding token ID. Auto-detected if not provided.
- compressed_vocab_size (Optional[int], default: None): Resolved size of the hashing vocabulary. Automatically set and saved after first initialization.
- layer_container_path (Optional[str], default: None): Explicit dot-separated path to the nn.ModuleList containing the transformer layers (e.g., "model.layers"). If provided, it bypasses automatic architecture discovery (see the second example below).
Example Usage:
from engram_peft import EngramConfig
config = EngramConfig(
    target_layers=[2, 11, 20],
    embedding_dim=1024,
    learning_rate_multiplier=5.0,
)
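For architectures whose layer container is not auto-discovered, layer_container_path can be set explicitly, as noted above. A minimal sketch; the path "transformer.blocks" is hypothetical and only for illustration:
from engram_peft import EngramConfig

config = EngramConfig(
    target_layers=[2, 15],
    layer_container_path="transformer.blocks",  # hypothetical module path for a custom architecture
    tokenizer_name_or_path="deepseek-ai/DeepSeek-V3",
)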
Model Wrapping
get_engram_model
engram_peft.model.get_engram_model(model, config, tokenizer=None, wrap_peft=False, train_mode=None)
Injects Engram layers into a base Transformer model and configures which backbone parameters remain trainable.
Args:
- model (Union[PreTrainedModel, nn.Module]): The base model to wrap. Supports standard Hugging Face models and custom torch.nn.Module architectures.
- config (EngramConfig): Engram configuration.
- tokenizer (Optional[PreTrainedTokenizer]): Tokenizer for vocabulary/compression.
- wrap_peft (bool, default: False): Backward-compatible alias for train_mode="preserve_trainable".
- train_mode (Literal["engram_only", "preserve_trainable", "full_finetune"], optional): Controls backbone trainability.
- engram_only: Freeze the backbone and train only Engram.
- preserve_trainable: Preserve parameters that were already trainable before wrapping (e.g., LoRA), then add trainable Engram layers.
- full_finetune: Train the full backbone together with Engram.
Returns:
- EngramModel: The wrapped model with injected forward hooks.
Examples:
from engram_peft.model import get_engram_model

# Pure Engram PEFT
model = get_engram_model(base_model, config, tokenizer, train_mode="engram_only")
# LoRA + Engram
model = get_engram_model(model, config, tokenizer, train_mode="preserve_trainable")
# Full finetuning + Engram
model = get_engram_model(base_model, config, tokenizer, train_mode="full_finetune")
EngramModel
engram_peft.model.EngramModel
The wrapper class around the base model. Handles dynamic hook management and weight serialization. A usage sketch follows the methods list below.
Methods:
- print_trainable_parameters(): Prints trainable counts for backbone, Engram, and total parameters.
- add_adapter(adapter_name: str, config: EngramConfig): Adds a new set of Engram weights with its own configuration.
- set_adapter(adapter_name: str): Switches the active knowledge pack to the specified adapter.
- create_optimizer(base_learning_rate: float, **optimizer_kwargs): Returns a MixedOptimizer with configurable backbone/Engram optimizer groups.
- create_scheduler(optimizer, num_steps, warmup_steps): Returns the paper-aligned Step Decay scheduler.
- save_pretrained(save_directory: str): Saves ONLY the active Engram weights and configuration.
- from_pretrained(base_model, engram_path): Loads Engram weights onto a base model.
- unload_engram(): Dynamically removes all PEFT hooks (reverts to base model).
- load_engram(engram_path=None): Re-installs hooks and optionally loads weights.
- load_weights_flexible(checkpoint_path, source_config_path=None, layer_mapping=None, reuse_structural=False): Loads weights from a checkpoint even if configurations (layers, buckets, n-grams) differ.
- remap_from_corpus(corpus, checkpoint_path, source_config_path=None, layer_mapping=None, tokenizer=None, batch_size=1024): "Best-effort" remapping for cases where seeds or tokenizers differ, using a reference corpus to align indices.
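A condensed sketch of a typical lifecycle using the methods above. It assumes base_model, tokenizer, and config already exist (as in the earlier examples); the learning rate, step counts, adapter name, and save path are illustrative, and the training loop is omitted:
from engram_peft.model import get_engram_model, EngramModel

model = get_engram_model(base_model, config, tokenizer, train_mode="engram_only")
model.print_trainable_parameters()

# Paper-aligned optimizer and scheduler for the wrapped model (values are illustrative).
optimizer = model.create_optimizer(base_learning_rate=4e-4)
scheduler = model.create_scheduler(optimizer, num_steps=10_000, warmup_steps=500)

# ... training loop ...

# Persist only the active Engram weights, then restore them onto a base model.
model.save_pretrained("./engram-knowledge-pack")
restored = EngramModel.from_pretrained(base_model, "./engram-knowledge-pack")

# Multiple knowledge packs can coexist; "legal" is a hypothetical adapter name.
restored.add_adapter("legal", config)
restored.set_adapter("legal")

# Temporarily revert to the plain base model, then re-install the hooks.
restored.unload_engram()
restored.load_engram("./engram-knowledge-pack")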
Data Utilities
EngramDataCollator
engram_peft.collator.EngramDataCollator
A data collator that precomputes multi-head hash indices on the CPU during data loading, so the GPU stays dedicated to training.
Args:
- tokenizer: Hugging Face tokenizer.
- config: EngramConfig instance.
- compressor: CompressedTokenizer instance (optional).
Example Usage:
from engram_peft import EngramDataCollator
from transformers import Trainer
collator = EngramDataCollator(tokenizer=tokenizer, config=config)
trainer = Trainer(..., data_collator=collator)
EngramTrainer
engram_peft.trainer.EngramTrainer
Trainer subclass that handles sparse gradient clipping and can build Engram's mixed optimizer automatically.
Notable Args:
- optimizer_kwargs (Optional[Dict[str, Any]]): Extra keyword arguments forwarded to model.create_optimizer(...) / get_optimizer(...). Use this to configure layered optimizer behavior when relying on the trainer's default optimizer creation path.
Example Usage:
from engram_peft.trainer import EngramTrainer

trainer = EngramTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    optimizer_kwargs={
        "backbone_learning_rate": 5e-5,
        "engram_dense_learning_rate": 4e-4,
        "engram_sparse_learning_rate": 2e-3,
    },
)
Optimization
get_optimizer
engram_peft.utils.get_optimizer(model, base_learning_rate=4e-4, ...)
Creates a MixedOptimizer with separate optimizer groups for:
- Engram sparse embeddings
- Engram dense parameters
- Backbone parameters
Supported optimizer specs (see the examples below):
- Built-in strings: "adam", "adamw", "sgd", "sparse_adam"
- Optimizer classes
- Custom builder callables
Example Usage:
from engram_peft.utils import get_optimizer

optimizer = get_optimizer(
    model,
    backbone_learning_rate=5e-5,
    engram_dense_learning_rate=4e-4,
    engram_sparse_learning_rate=2e-3,
    backbone_optimizer="adamw",
    engram_dense_optimizer="adam",
    engram_sparse_optimizer="sparse_adam",
)
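Optimizer classes can also be passed in place of the built-in strings. A sketch assuming the same keyword parameters as the example above (the exact signature expected of custom builder callables is not documented here):
import torch
from engram_peft.utils import get_optimizer

optimizer = get_optimizer(
    model,
    backbone_learning_rate=5e-5,
    engram_dense_learning_rate=4e-4,
    engram_sparse_learning_rate=2e-3,
    backbone_optimizer=torch.optim.AdamW,        # optimizer class instead of a built-in string
    engram_dense_optimizer="adam",               # built-in string
    engram_sparse_optimizer=torch.optim.SparseAdam,
)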
get_trainable_param_groups
engram_peft.utils.get_trainable_param_groups(model)
Returns a dictionary with three trainable parameter lists:
- backbone
- engram_dense
- engram_sparse
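A small inspection sketch, assuming each returned list holds torch Parameters:
from engram_peft.utils import get_trainable_param_groups

groups = get_trainable_param_groups(model)
for group_name, params in groups.items():
    # Count trainable elements per group (backbone, engram_dense, engram_sparse).
    print(group_name, sum(p.numel() for p in params))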
get_scheduler
engram_peft.utils.get_scheduler(optimizer, num_training_steps, warmup_steps=0)
Returns a LambdaLR scheduler implementing the Step Decay schedule from the DeepSeek paper (decay at 80% and 90% progress).
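A usage sketch with illustrative step counts:
from engram_peft.utils import get_scheduler

# Warmup for the first 500 steps, then step decay at 80% and 90% of training progress.
scheduler = get_scheduler(optimizer, num_training_steps=10_000, warmup_steps=500)

for step in range(10_000):
    # ... forward/backward pass and optimizer.step() ...
    scheduler.step()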
Core Components
EngramLayer
engram_peft.layer.EngramLayer
The core torch.nn.Module containing the retrieval embeddings, context-aware gating, and short-term convolutions. Instances are typically created and managed automatically via get_engram_model.
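Although the layers are injected automatically, they can be located on a wrapped model for inspection. A sketch assuming model was wrapped as in the examples above:
from engram_peft.layer import EngramLayer

# Collect the injected EngramLayer modules by their module path (e.g., for debugging).
engram_layers = {
    name: module
    for name, module in model.named_modules()
    if isinstance(module, EngramLayer)
}
print(sorted(engram_layers))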