3.2 Instruction Tuning & Alignment
Learning Objectives
- Understand the multi-stage LLM training pipeline
- Learn about pre-training data sources and composition
- Explore instruction tuning and RLHF processes
- Analyze data quality, scale, and ethical considerations
LLM Training Pipeline
Complete Training Process
Stage 1: Pre-training
Goal: Learn general language understanding from massive text data
- Data: Hundreds of billions to trillions of tokens from diverse sources
- Objective: Next-token prediction (autoregressive)
- Duration: Weeks to months on thousands of GPUs
- Output: Base model with broad knowledge
GPT-3 example: 175B parameters and 96 layers, trained on ~300B tokens at an estimated ~$4.6M compute cost
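A minimal sketch of the next-token prediction objective in plain Python follows. Each position predicts the following token from its prefix, and the loss is the average negative log-likelihood of the true next tokens. The `predict_probs` callable and the two toy models are hypothetical stand-ins for a real network's softmax output.

```python
import math


def next_token_loss(token_ids, predict_probs):
    """Average negative log-likelihood of each token given the tokens before it.

    predict_probs(prefix) returns a probability list over the vocabulary
    (a hypothetical stand-in for a language model's softmax output).
    """
    if len(token_ids) < 2:
        raise ValueError("need at least two tokens to form a (context, target) pair")
    total = 0.0
    # Position t sees only tokens[0..t] (causal) and must predict tokens[t + 1].
    for t in range(len(token_ids) - 1):
        probs = predict_probs(token_ids[: t + 1])
        total += -math.log(probs[token_ids[t + 1]])
    return total / (len(token_ids) - 1)


VOCAB_SIZE = 4


def uniform_model(prefix):
    """Toy model that knows nothing: equal probability for every token."""
    return [1.0 / VOCAB_SIZE] * VOCAB_SIZE


def counting_model(prefix):
    """Toy model that bets 90% on 'previous token + 1' and spreads the rest evenly."""
    guess = (prefix[-1] + 1) % VOCAB_SIZE
    probs = [0.1 / (VOCAB_SIZE - 1)] * VOCAB_SIZE
    probs[guess] = 0.9
    return probs


sequence = [0, 1, 2, 3]
print(next_token_loss(sequence, uniform_model))   # ln(4), about 1.386
print(next_token_loss(sequence, counting_model))  # -ln(0.9), about 0.105
```

Lower loss means the model puts more probability on the tokens that actually come next. Pre-training minimizes this quantity over every position in the corpus.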
Stage 2: Instruction Tuning (SFT)
Goal: Teach the model to follow instructions and behave helpfully
- Data: High-quality instruction-response pairs
- Size: 10K-100K examples (much smaller than pre-training)
- Method: Supervised fine-tuning on human demonstrations
- Output: Model that can follow instructions
Example: "Explain quantum physics" → Detailed, helpful explanation
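A common SFT detail is that the loss covers only the response tokens. This teaches the model to produce answers rather than to reproduce prompts. The sketch below builds one training example under that assumption. The token IDs are made up, and `IGNORE_INDEX = -100` mirrors the default `ignore_index` of PyTorch's `CrossEntropyLoss`.

```python
IGNORE_INDEX = -100  # Same value as the default ignore_index in torch.nn.CrossEntropyLoss


def build_sft_example(prompt_ids, response_ids, eos_id):
    """Concatenate prompt and response; mask the prompt out of the loss.

    Returns (input_ids, labels) of equal length. Labels stay aligned with the
    inputs; the causal LM is expected to shift them by one position itself.
    """
    input_ids = prompt_ids + response_ids + [eos_id]
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids + [eos_id]
    return input_ids, labels


# Hypothetical token IDs for a prompt like "Explain quantum physics" and its answer; EOS = 2
prompt = [101, 102, 103]
response = [201, 202, 203, 204]
input_ids, labels = build_sft_example(prompt, response, eos_id=2)
print(input_ids)  # [101, 102, 103, 201, 202, 203, 204, 2]
print(labels)     # [-100, -100, -100, 201, 202, 203, 204, 2]
```

Hugging Face causal LM heads, for example, shift `labels` internally, which is why they are left aligned here. Keeping the EOS token in the labels teaches the model where to stop. Masking the prompt is a design choice, and some recipes train on the full sequence.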
Stage 3: RLHF (Reinforcement Learning from Human Feedback)
Goal: Align model outputs with human preferences and values
- Reward Model: Train a model to predict which responses humans prefer
- PPO Training: Use reinforcement learning (Proximal Policy Optimization) to optimize the model toward outputs the reward model scores highly
- Safety: Reduce harmful, biased, or unhelpful responses
- Output: Aligned model ready for deployment
Key Innovation: Makes models more helpful, harmless, and honest
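The Stage 3 objective is often written as shown below (this follows the InstructGPT-style formulation). It combines the reward model's score with a KL penalty that keeps the policy close to the SFT model. The symbols are:

- π_θ: the policy being trained
- π_ref: the frozen SFT model
- r_φ: the reward model
- D: the prompt distribution
- β: the KL coefficient

This is a standard formulation, not necessarily the exact objective of any particular deployed model.

```latex
\max_{\theta}\;
\mathbb{E}_{x \sim \mathcal{D},\; y \sim \pi_\theta(\cdot \mid x)}
\left[
  r_\phi(x, y)
  \;-\; \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}
\right]
```

Averaging the log-ratio over y sampled from π_θ gives the KL divergence between π_θ and π_ref. The bracketed term is therefore a per-sample estimate of "reward minus β times KL", and PPO is the algorithm used to maximize it.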
Stage 4: Deployment & Monitoring
Goal: Serve the model to users while continuously improving
- Infrastructure: Scalable serving with load balancing
- Monitoring: Track performance, safety, and user satisfaction
- Updates: Regular fine-tuning with new data and feedback
- Safety: Content filtering and abuse detection
Pre-training Data Sources & Composition
Typical LLM Training Data Mix
Web Content
- CommonCrawl: Petabyte-scale web scrapes
- Wikipedia: High-quality encyclopedic content
- Forums: Reddit, Stack Overflow discussions
- Quality: Filtered for language, duplicates, toxic content
- Scale: ~1T tokens, 50+ languages
Books & Literature
- BookCorpus: 10K+ books
- Project Gutenberg: Public domain literature
- OpenLibrary: Diverse literary works
- Benefit: Long-form reasoning, narrative understanding
- Scale: 100B+ tokens, high quality
Academic Content
- ArXiv: Scientific papers and preprints
- PubMed: Medical literature
- Academic journals: Peer-reviewed research
- Value: Technical knowledge, formal reasoning
- Scale: 50B+ tokens, expert-level content
Code & Programming
- GitHub: Open source repositories
- GitLab, Bitbucket: Additional code sources
- Documentation: API docs, tutorials
- Impact: Programming abilities, structured thinking
- Scale: 25B+ tokens, 100+ programming languages
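The approximate token counts above can be turned into a sampling mix. This sketch treats the lower bounds (1T, 100B, 50B, 25B) as exact counts. It uses temperature scaling, one common heuristic for upweighting smaller sources. It is illustrative only, not the recipe of any specific model.

```python
import random

# Approximate token counts from the figures above (treated as exact for illustration).
SOURCE_TOKENS = {
    "web": 1_000e9,
    "books": 100e9,
    "academic": 50e9,
    "code": 25e9,
}


def mixture_weights(token_counts, temperature=1.0):
    """Sampling probability per source, proportional to count ** (1 / temperature).

    temperature=1 samples in proportion to size; larger values flatten the
    mix so small, high-value sources are seen more often.
    """
    scaled = {name: count ** (1.0 / temperature) for name, count in token_counts.items()}
    total = sum(scaled.values())
    return {name: value / total for name, value in scaled.items()}


def sample_sources(weights, k, seed=0):
    """Draw k source names (e.g. one per training batch) according to the weights."""
    rng = random.Random(seed)
    names = list(weights)
    return rng.choices(names, weights=[weights[n] for n in names], k=k)


for temp in (1.0, 2.0):
    w = mixture_weights(SOURCE_TOKENS, temperature=temp)
    print(temp, {name: round(p, 3) for name, p in w.items()})
# 1.0 {'web': 0.851, 'books': 0.085, 'academic': 0.043, 'code': 0.021}
# 2.0 {'web': 0.589, 'books': 0.186, 'academic': 0.132, 'code': 0.093}

print(sample_sources(mixture_weights(SOURCE_TOKENS, 2.0), k=8))
```

At temperature 1, web text dominates almost completely. Raising the temperature shifts probability toward books, papers, and code, at the cost of repeating those smaller sources more often.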
Data Quality Challenges
- Noise: Web content contains errors, spam, low-quality text
- Bias: Training data reflects societal biases and stereotypes
- Privacy: Personal information may be inadvertently included
- Copyright: Legal concerns around use of copyrighted content
- Duplication: The same content appears multiple times, skewing the data mix and encouraging memorization
RLHF: Reinforcement Learning from Human Feedback
RLHF Process Flow
1. Collect Comparisons: Humans rank model outputs
2. Train Reward Model: Predict human preferences
3. PPO Training: Optimize for high rewards
This iterative process aligns the model with human values and preferences
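Step 2 is usually trained with a pairwise (Bradley-Terry style) loss: for each human comparison, the reward model should score the preferred response above the rejected one. The minimal sketch below computes that loss on scalar scores. It assumes the reward model has already produced one score per response, and the scores used here are hypothetical.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    if x > 0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))


def pairwise_loss(reward_chosen, reward_rejected):
    """-log sigmoid(r_chosen - r_rejected), computed as softplus(r_rejected - r_chosen)."""
    return softplus(-(reward_chosen - reward_rejected))


def batch_pairwise_loss(pairs):
    """Mean loss over (chosen_score, rejected_score) pairs from human comparisons."""
    return sum(pairwise_loss(c, r) for c, r in pairs) / len(pairs)


# Hypothetical reward-model scores for three comparisons
print(round(pairwise_loss(2.0, 2.0), 4))  # 0.6931: no preference learned yet (ln 2)
print(round(pairwise_loss(3.0, 0.0), 4))  # 0.0486: correct ordering, small loss
print(round(pairwise_loss(0.0, 3.0), 4))  # 3.0486: wrong ordering, large loss
```

Only the difference between the two scores enters the loss. Reward-model outputs are therefore defined only up to an additive constant, which is one reason some implementations normalize rewards before the PPO stage.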
| RLHF Component | Purpose | Data Requirements | Key Challenges |
|---|---|---|---|
| Human Annotations | Provide preference signals | 10K-100K comparisons | Consistency, cost, scalability |
| Reward Model | Score outputs by human preference | Human comparison data; often initialized from the SFT model | Reward hacking, generalization |
| PPO Policy | Optimize for reward while staying close to SFT model | Prompts; responses are sampled from the policy and scored by the reward model | Training stability, KL divergence control |
| Safety Filtering | Prevent harmful outputs | Adversarial prompts, red team data | Balancing helpfulness vs safety |
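RLHF Pseudocode
```python
import copy

import torch

# Illustrative sketch. The models are assumed to be torch.nn.Module objects that
# also expose generate(), logprobs(), and score(); `ppo` is a hypothetical helper
# holding an optimizer and a compute_loss() method. None of these is a specific
# library's API.


class RLHFTrainer:
    def __init__(self, sft_model, reward_model, ppo, kl_coeff=0.1):
        self.policy = copy.deepcopy(sft_model)           # Trainable policy, starts as the SFT model
        self.reference_model = copy.deepcopy(sft_model)  # Frozen copy for the KL penalty
        self.reference_model.eval()
        self.reference_model.requires_grad_(False)
        self.reward_model = reward_model
        self.ppo = ppo
        self.kl_coeff = kl_coeff                         # KL penalty strength (beta)

    def train_step(self, prompts):
        with torch.no_grad():
            # Generate responses from the current policy
            responses = self.policy.generate(prompts)
            # Get one scalar reward per (prompt, response) pair
            rewards = self.reward_model.score(prompts, responses)             # [batch]
            # Per-sample KL estimate: sum over response tokens of log pi - log pi_ref
            old_logprobs = self.policy.logprobs(prompts, responses)           # [batch, seq]
            ref_logprobs = self.reference_model.logprobs(prompts, responses)  # [batch, seq]
            kl = (old_logprobs - ref_logprobs).sum(dim=-1)                    # [batch]
            # Penalize each sample for drifting away from the reference model
            adjusted_rewards = rewards - self.kl_coeff * kl

        # PPO update: recompute log-probs with gradients enabled
        new_logprobs = self.policy.logprobs(prompts, responses)
        policy_loss = self.ppo.compute_loss(new_logprobs, old_logprobs, adjusted_rewards)
        self.ppo.optimizer.zero_grad()
        policy_loss.backward()
        self.ppo.optimizer.step()

        return {
            "reward": rewards.mean().item(),
            "kl_div": kl.mean().item(),
            "policy_loss": policy_loss.item(),
        }


# Training loop (sft_model, reward_model, ppo, dataloader, num_epochs, and
# log_metrics are assumed to be defined elsewhere)
trainer = RLHFTrainer(sft_model, reward_model, ppo)
for epoch in range(num_epochs):
    for batch_prompts in dataloader:
        metrics = trainer.train_step(batch_prompts)
        log_metrics(metrics)
```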
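The `compute_loss` call in the pseudocode above hides the core of PPO. Below is a plain-Python version of the standard clipped surrogate loss. It assumes per-sample log-probabilities and advantages are already available. In practice, advantages come from a learned value function (for example via GAE), which is omitted here.

```python
import math


def ppo_clipped_loss(new_logprobs, old_logprobs, advantages, clip_eps=0.2):
    """Clipped PPO surrogate objective, negated so it can be minimized.

    new_logprobs: log-probs under the policy being updated
    old_logprobs: log-probs from the policy that generated the samples
    advantages:   how much better each sample was than expected
    """
    losses = []
    for new_lp, old_lp, adv in zip(new_logprobs, old_logprobs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new / pi_old
        clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Take the pessimistic (smaller) objective so large policy jumps earn no extra credit.
        objective = min(ratio * adv, clipped_ratio * adv)
        losses.append(-objective)
    return sum(losses) / len(losses)


# Unchanged policy: ratio = 1, so the loss is just -mean(advantage).
print(ppo_clipped_loss([-1.0, -2.0], [-1.0, -2.0], [1.0, 0.5]))  # -0.75
# Policy doubled a good sample's probability: the gain is capped at 1 + clip_eps.
print(ppo_clipped_loss([math.log(2)], [0.0], [1.0]))  # -1.2
```

The clipping asymmetry is deliberate. For a good sample, pushing the ratio past 1 + clip_eps earns no further reward. For a bad sample, the unclipped (larger) penalty still applies.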
Data Processing & Quality Control
Content Filtering
- Language Detection: Filter non-target languages
- Quality Scoring: Remove low-quality, spam content
- Toxicity Detection: Filter harmful, offensive content
- Privacy Scrubbing: Remove PII, sensitive data
Tools: Classifiers, regex patterns, blocklists
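Two of the tools listed above, regex patterns and blocklists, can be sketched in a few lines. The patterns are deliberately simple assumptions that cover e-mail addresses and US-style phone numbers only. The blocklist term is a placeholder. Production pipelines typically pair rules like these with trained classifiers.

```python
import re

# Illustrative patterns only: real PII scrubbing needs broader, locale-aware rules.
EMAIL_RE = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
PHONE_RE = re.compile(r"(?<!\d)(?:\+1[ .-]?)?(?:\(\d{3}\)|\d{3})[ .-]?\d{3}[ .-]\d{4}(?!\d)")


def scrub_pii(text):
    """Replace e-mail addresses and US-style phone numbers with placeholder tags."""
    text = EMAIL_RE.sub("[EMAIL]", text)
    return PHONE_RE.sub("[PHONE]", text)


def contains_blocked_term(text, blocklist):
    """True if any single-word blocklist term appears as a whole word (case-insensitive)."""
    words = set(re.findall(r"[a-z0-9']+", text.lower()))
    return not words.isdisjoint(blocklist)


print(scrub_pii("Contact jane.doe@example.com or (555) 123-4567."))
# Contact [EMAIL] or [PHONE].
print(contains_blocked_term("This line has a BadWord in it", {"badword"}))  # True
```

Replacing PII with tags, rather than deleting it, keeps sentences grammatical. The model still sees that a contact detail was present, without learning the detail itself.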
Deduplication
- Exact Matching: Remove identical documents
- Near-Duplicate Detection: Fuzzy matching algorithms
- Sentence-Level: Remove repeated sentences
- Impact: Reduces memorization and improves generalization
Techniques: MinHash, LSH, Jaccard similarity
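The exact and near-duplicate techniques listed above can be sketched with only the standard library:

- Exact matching hashes whitespace- and case-normalized text.
- Near-duplicate detection compares MinHash signatures of word 5-gram shingles to estimate Jaccard similarity.

The shingle size, the 128 hash functions (simulated by salting BLAKE2b with a seed), and the normalization are illustrative assumptions. The LSH banding step that avoids all-pairs comparison at scale is omitted.

```python
import hashlib


def exact_dedup(docs):
    """Keep the first copy of each document, comparing whitespace/case-normalized text."""
    seen, unique = set(), []
    for doc in docs:
        key = hashlib.sha256(" ".join(doc.lower().split()).encode()).hexdigest()
        if key not in seen:
            seen.add(key)
            unique.append(doc)
    return unique


def shingles(text, k=5):
    """Set of overlapping k-word shingles (never empty)."""
    words = text.lower().split()
    if len(words) < k:
        return {" ".join(words)}
    return {" ".join(words[i:i + k]) for i in range(len(words) - k + 1)}


def jaccard(a, b):
    """Exact Jaccard similarity: size of the intersection over size of the union."""
    return len(a & b) / len(a | b)


def minhash_signature(shingle_set, num_perm=128):
    """For each seeded hash function, keep the minimum hash over all shingles."""
    return [
        min(
            int.from_bytes(hashlib.blake2b(f"{seed}:{s}".encode(), digest_size=8).digest(), "big")
            for s in shingle_set
        )
        for seed in range(num_perm)
    ]


def estimated_jaccard(sig_a, sig_b):
    """Fraction of matching signature slots approximates the Jaccard similarity."""
    return sum(x == y for x, y in zip(sig_a, sig_b)) / len(sig_a)


a = "the quick brown fox jumps over the lazy dog near the river bank today"
b = "the quick brown fox jumps over the lazy dog near the river bank tonight"
sa, sb = shingles(a), shingles(b)
print(jaccard(sa, sb))  # 9 shared of 11 distinct shingles, about 0.818
print(estimated_jaccard(minhash_signature(sa), minhash_signature(sb)))  # close to 0.818
```

The signature has a fixed length regardless of document size, so comparing two documents costs 128 integer comparisons. Pairs whose estimate exceeds a chosen threshold (for example 0.8) would be treated as near-duplicates.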
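Tokenization
- Subword Tokenization: BPE, SentencePiece
- Vocabulary Size: 32K-100K tokens
- Special Tokens: Markers such as beginning/end of sequence and unknown token (names vary by tokenizer; GPT-2 uses <|endoftext|>)
- Efficiency: Balance compression (fewer tokens per text) against vocabulary size (larger embedding and output layers)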
Goal: ~3-4 characters per token for English
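Here is a toy, character-level version of the BPE merge-learning loop, which makes the subword idea concrete: repeatedly merge the most frequent adjacent symbol pair. The word counts are made up, and ties are broken alphabetically. Real tokenizers (GPT-2's byte-level BPE, SentencePiece) add pre-tokenization, byte fallback, and special tokens that this sketch leaves out.

```python
from collections import Counter


def merge_symbols(symbols, pair):
    """Replace each adjacent occurrence of `pair` in `symbols` with one merged symbol."""
    out, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            out.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            out.append(symbols[i])
            i += 1
    return tuple(out)


def learn_bpe(word_freqs, num_merges):
    """Learn up to num_merges merge rules from a {word: count} dictionary."""
    vocab = {tuple(word): freq for word, freq in word_freqs.items()}
    merges = []
    for _ in range(num_merges):
        pair_counts = Counter()
        for symbols, freq in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += freq
        if not pair_counts:
            break
        # Most frequent pair; ties broken by alphabetical order for determinism.
        best = min(pair_counts, key=lambda p: (-pair_counts[p], p))
        merges.append(best)
        new_vocab = Counter()
        for symbols, freq in vocab.items():
            new_vocab[merge_symbols(symbols, best)] += freq
        vocab = dict(new_vocab)
    return merges


def encode_word(word, merges):
    """Tokenize a word by applying the learned merges in order."""
    symbols = tuple(word)
    for pair in merges:
        symbols = merge_symbols(symbols, pair)
    return list(symbols)


corpus = {"low": 5, "lower": 2, "newest": 6, "widest": 3}
merges = learn_bpe(corpus, num_merges=4)
print(merges)  # [('e', 's'), ('es', 't'), ('l', 'o'), ('lo', 'w')]
print(encode_word("lowest", merges))  # ['low', 'est']
```

The word "lowest" never appears in the toy corpus, yet it is encoded as two learned subwords. This is how subword tokenizers avoid most out-of-vocabulary failures.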
Ethical Considerations
- Consent: Use data with appropriate permissions
- Bias Mitigation: Balance representation across groups
- Copyright Respect: Avoid unauthorized copyrighted content
- Transparency: Document data sources and processing
Standards: Data governance frameworks, ethical AI guidelines
Data Processing Pipeline Example
```python
import re

from transformers import AutoTokenizer


class DataProcessor:
    def __init__(self, tokenizer_name="gpt2"):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.min_chars = 100         # Minimum document length, in characters
        self.max_length = 2048       # Chunk (sequence) length, in tokens
        self.min_chunk_tokens = 100  # Drop chunks shorter than this, in tokens

    def filter_quality(self, text):
        """Return True if the document passes simple heuristic quality checks."""
        if len(text) < self.min_chars:
            return False
        # Average characters per word: very low values suggest fragmented or garbled text
        words = text.split()
        if not words or len(text) / len(words) < 3:
            return False
        # Excessive repetition: too few distinct lines
        lines = text.split('\n')
        if len(set(lines)) / len(lines) < 0.3:
            return False
        return True

    def clean_text(self, text):
        """Basic text cleaning that preserves line structure."""
        # Drop very long lines first (likely formatting artifacts)
        lines = [line for line in text.split('\n') if len(line) < 1000]
        # Collapse runs of spaces/tabs within each line, then trim each line
        lines = [re.sub(r'[ \t]+', ' ', line).strip() for line in lines]
        # Collapse runs of blank lines into a single blank line
        return re.sub(r'\n{3,}', '\n\n', '\n'.join(lines)).strip()

    def tokenize_and_chunk(self, text):
        """Tokenize and split into fixed-length chunks with 50% overlap."""
        tokens = self.tokenizer.encode(text)
        chunks = []
        stride = self.max_length // 2  # 50% overlap between consecutive chunks
        for i in range(0, len(tokens), stride):
            chunk = tokens[i:i + self.max_length]
            if len(chunk) >= self.min_chunk_tokens:
                chunks.append(chunk)
            if i + self.max_length >= len(tokens):
                break  # This chunk reached the end; later ones would be redundant subsets
        return chunks

    def process_dataset(self, raw_texts):
        """Process a dataset of raw texts"""
        processed_chunks = []
        for text in raw_texts:
            # Quality filtering
            if not self.filter_quality(text):
                continue
            # Text cleaning
            cleaned_text = self.clean_text(text)
            # Tokenization and chunking
            chunks = self.tokenize_and_chunk(cleaned_text)
            processed_chunks.extend(chunks)
        return processed_chunks


# Usage example (raw_documents: a list of raw text strings, assumed to be loaded elsewhere)
processor = DataProcessor()
training_data = processor.process_dataset(raw_documents)
print(f"Processed {len(training_data)} training chunks")
```
LLM Training Best Practices
Data Quality:
- Rigorous filtering and deduplication
- Balanced representation across domains
- Consistent preprocessing and tokenization
- Ethical sourcing and bias consideration
Training Process:
- Clear objectives for each training stage
- Careful scaling of data, model, and compute
- Iterative refinement with human feedback
- Safety and alignment throughout the process
Key Insight: The quality of training data is often more important than quantity. Modern LLMs succeed through careful curation of diverse, high-quality datasets combined with sophisticated training techniques like RLHF.