Training API
The Python client provides two methods for ML training workflows: training_dataset() exports a Lance dataset from a version tag, and create_training_snapshot() creates a version tag and returns its timestamp.
training_dataset()
Python
db.training_dataset(tag: str) -> str
Exports the table data at the given version tag as a Lance dataset on disk and returns the absolute path to the dataset directory.
- tag must name a version tag created with FOR TRAINING
- Returns the path to a Lance dataset directory
- The path is under <database>/training_exports/<tag>_<ts>/
- Repeat calls with the same tag overwrite the previous export
Python
import galaxdb
db = galaxdb.Database("./data")
# First create a training snapshot via SQL
db.execute("""
CREATE VERSION TAG 'train-v1'
FOR TRAINING
WITH TRAINING PRECISION 'float32'
TRAINING SEED 42
""")
# Export as Lance dataset
path = db.training_dataset("train-v1")
print(path) # /absolute/path/to/data/training_exports/train-v1_1715385600000000/
create_training_snapshot()
Python
db.create_training_snapshot(name: str, seed: int | None = None) -> int
Creates a FOR TRAINING version tag pinned at the latest committed row and returns the version timestamp (uint64). This is a convenience wrapper around CREATE VERSION TAG ... FOR TRAINING.
Python
import galaxdb
db = galaxdb.Database("./data")
# Insert some data
db.execute("INSERT INTO docs (id, body) VALUES (1, 'hello world')")
# Create training snapshot
ts = db.create_training_snapshot("train-v1", seed=42)
print(f"Snapshot timestamp: {ts}") # e.g., 1715385600000000
# Export the dataset
path = db.training_dataset("train-v1")
print(f"Dataset at: {path}")
Tip
Use create_training_snapshot when you want to create a snapshot and immediately export it in Python, without writing SQL. Use CREATE VERSION TAG in SQL when you need more control over precision and seed options.
PyTorch Integration
The Lance dataset can be loaded into PyTorch with zero-copy memory-mapped access:
Python
import galaxdb
import lance
import torch
from torch.utils.data import DataLoader
db = galaxdb.Database("./data")
# Create snapshot
ts = db.create_training_snapshot("train-v1", seed=42)
path = db.training_dataset("train-v1")
# Load with Lance
ds = lance.dataset(path)
print(f"Rows: {ds.count_rows()}")
print(f"Schema: {ds.schema}")
# Convert to PyTorch IterableDataset
pytorch_ds = ds.to_pytorch()
# Create DataLoader
loader = DataLoader(
pytorch_ds,
batch_size=64,
num_workers=4,
pin_memory=True, # faster GPU transfer
)
# Training loop
# NOTE(review): MyModel and criterion are not defined anywhere in this snippet.
# Supply your own model class and a loss function before running, e.g.
# criterion = torch.nn.CrossEntropyLoss() for integer class labels -- confirm
# the right loss for your task.
# NOTE(review): the indentation of the two nested loops was lost in this
# rendering; the statements after each `for` line are the loop bodies.
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(10):
for batch in loader:
# Each batch is indexed by column name.
embeddings = batch['body'] # float32 tensor [batch_size, 384]
labels = batch['label'] # int tensor [batch_size]
# Clear gradients accumulated by the previous step before the new backward pass.
optimizer.zero_grad()
output = model(embeddings)
loss = criterion(output, labels)
loss.backward()
optimizer.step()
Examples
Full training pipeline
Python
import galaxdb
import lance
import torch
def prepare_training_data(db_path: str, tag: str, seed: int = 42) -> str:
    """Prepare a training dataset from a GalaxDB database.

    Prints the total and de-duplicated row counts of ``training_data``,
    creates a training snapshot named ``tag``, and exports that snapshot
    as a Lance dataset.

    Args:
        db_path: Path handed to ``galaxdb.Database``.
        tag: Name of the training snapshot (version tag) to create and export.
        seed: Seed forwarded to ``create_training_snapshot``.

    Returns:
        The dataset path returned by ``db.training_dataset(tag)``.
    """
    db = galaxdb.Database(db_path)

    # Check if we have enough data
    count = db.execute("SELECT COUNT(*) FROM training_data")[0]['count']
    print(f"Total rows: {count}")

    unique = db.execute(
        "SELECT COUNT(*) FROM training_data WHERE NOT DUPLICATE"
    )[0]['count']
    total_rows = int(count)
    # An empty table has no duplicates; guard the ratio so reporting on it
    # does not raise ZeroDivisionError.
    duplicate_pct = (1 - int(unique) / total_rows) * 100 if total_rows else 0.0
    print(f"Unique rows: {unique} ({duplicate_pct:.1f}% duplicates removed)")

    # Create snapshot
    ts = db.create_training_snapshot(tag, seed=seed)
    print(f"Snapshot created at: {ts}")

    # Export
    path = db.training_dataset(tag)
    print(f"Lance dataset at: {path}")
    return path
# Usage
path = prepare_training_data("./data", "experiment-001", seed=42)
ds = lance.dataset(path)
loader = torch.utils.data.DataLoader(ds.to_pytorch(), batch_size=32)
Comparing dataset versions
Python
import galaxdb
import lance
db = galaxdb.Database("./data")
# Export two versions
path_v1 = db.training_dataset("train-v1")
path_v2 = db.training_dataset("train-v2")
ds_v1 = lance.dataset(path_v1)
ds_v2 = lance.dataset(path_v2)
# Count each dataset once and reuse the result below.
rows_v1 = ds_v1.count_rows()
rows_v2 = ds_v2.count_rows()
print(f"v1: {rows_v1} rows")
print(f"v2: {rows_v2} rows")
# Signed format so a shrinking dataset prints "-N" rather than "+-N".
print(f"Delta: {rows_v2 - rows_v1:+d} rows")