ML · Python · Training · May 2024

GPT Trainer

A reusable framework for training GPT-2 style transformers on any PDF or text dataset. Built from scratch in PyTorch with tiktoken tokenization, loss tracking, and checkpoint management. Demonstrated on Bram Stoker's Dracula.

PyTorch · Transformers · Python · tiktoken · AdamW

Overview

GPT Trainer is a command-line tool for training GPT-2 style language models on arbitrary text or PDF datasets. The goal was to build the full training pipeline from scratch — not a wrapper around a pretrained model — to develop genuine understanding of how transformers learn.

The framework emphasizes reproducibility: fixed tokenization with tiktoken, configurable context window size, checkpointing at every milestone, and visual loss curves that make it easy to diagnose training behavior. It was validated on Bram Stoker's Dracula, which later became the foundation for the Dracula AI Agent.

This project is the engine underneath the Dracula AI Agent. It was extracted as a standalone reusable tool so any text dataset can be dropped in and trained with a single command.

Architecture

🧠
Transformer Model
  • GPT-2 style with token + positional embeddings
  • Multi-head self-attention blocks
  • Cross-entropy loss, AdamW optimizer
  • Configurable depth, heads, and context size
📊
Data Pipeline
  • PDF and plaintext ingestion
  • tiktoken GPT-2 tokenization
  • Context-window chunking with stride
  • Train/validation split via DataLoader
🧪
Training + Eval
  • Multi-epoch loop with configurable steps
  • Periodic validation loss evaluation
  • Token throughput tracking
  • Qualitative sampling between epochs
💾
Checkpointing
  • Save at milestones with torch.save
  • Resume training from any checkpoint
  • Inference-only load path
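The data pipeline's "context-window chunking with stride" step can be sketched as follows. This is an illustrative sketch, not the framework's actual code: `chunk_tokens` is a hypothetical helper, and the toy integer token stream stands in for real ids, which the pipeline would produce with tiktoken's GPT-2 encoding.

```python
def chunk_tokens(token_ids, context_size, stride):
    """Slice a token stream into overlapping (input, target) pairs.

    Targets are the inputs shifted one position right: the standard
    next-token-prediction setup used for GPT-style training.
    """
    inputs, targets = [], []
    for start in range(0, len(token_ids) - context_size, stride):
        inputs.append(token_ids[start : start + context_size])
        targets.append(token_ids[start + 1 : start + context_size + 1])
    return inputs, targets

# Toy token stream; a real run would tokenize text first, e.g. with
# tiktoken.get_encoding("gpt2").encode(text).
ids = list(range(10))
xs, ys = chunk_tokens(ids, context_size=4, stride=2)
# xs[0] is [0, 1, 2, 3]; ys[0] is the same window shifted by one: [1, 2, 3, 4]
```

A stride smaller than the context size makes consecutive windows overlap, which trades more training examples per document against some repetition between batches.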

Features

📚
Any text or PDF dataset
Drop a PDF or .txt file into data/ and the pipeline handles everything — extraction, tokenization, batching.
🔠
GPT-2 compatible tokenization
tiktoken ensures the vocabulary is consistent with GPT-2, making the model interoperable with standard tooling.
🔁
Flexible training loop
Pure PyTorch — no Trainer API abstraction. Full control over learning rate, batch size, epochs, and evaluation frequency.
📉
Loss visualization
Matplotlib plots train and validation loss side-by-side, with tokens-per-second throughput tracked throughout.
💾
Checkpoint management
Save at any point. Resume training or load for inference-only. Each checkpoint includes model weights and optimizer state.

Workflow

Add dataset
Place PDF or text file in data/. Supports single and multi-document datasets.
Run training
Execute python gpt_train.py. Configure epochs, batch size, and context window via command-line flags or config file.
Watch loss curves
Matplotlib plots update after each eval pass; training and validation loss are printed to stdout along with token throughput.
Sample output
Between epochs, the model generates text from a seed prompt. Useful for qualitative evaluation of training progress.
Save checkpoint
Checkpoint saved automatically at configurable intervals. Resume or load for inference at any point.
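The shape of the training step in the workflow above (cross-entropy loss, AdamW optimizer, next-token targets) can be sketched with a tiny stand-in model. This is a minimal illustration, not the framework's `gpt_train.py`: the embedding-plus-linear model replaces the real GPT-2 style transformer, and the random batches replace the DataLoader built over the chunked dataset.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Tiny stand-in model; the real framework trains a full multi-head
# attention transformer, but the loop shape is identical.
vocab_size, context_size = 50, 8
model = nn.Sequential(nn.Embedding(vocab_size, 32), nn.Linear(32, vocab_size))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

losses = []
for step in range(20):
    # Random (input, target) token ids stand in for real chunked batches.
    x = torch.randint(0, vocab_size, (4, context_size))
    y = torch.randint(0, vocab_size, (4, context_size))
    logits = model(x)  # (batch, context, vocab)
    # Cross-entropy over every position: flatten batch and context dims.
    loss = F.cross_entropy(logits.flatten(0, 1), y.flatten())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

The recorded `losses` list is what the Matplotlib step plots; periodic validation passes run the same forward computation under `torch.no_grad()` without the optimizer update.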

Tech Stack

🐍
Python · Core language
🔦
PyTorch · Deep learning
🔡
tiktoken · GPT-2 tokenization
🧠
Custom Transformer · Built from scratch
🧪
AdamW · Optimizer
📈
Matplotlib · Loss visualization
📁
File I/O · PDF + text ingestion
💾
torch.save/load · Checkpointing