
Training API

The Python client provides two methods for ML training workflows: training_dataset() exports a Lance dataset from a version tag, and create_training_snapshot() creates a version tag and returns its timestamp.

training_dataset()

Python
db.training_dataset(tag: str) -> str

Exports the table data at the given version tag as a Lance dataset on disk and returns the absolute path to the dataset directory.

  • tag must name a version tag created with FOR TRAINING, either in SQL or with create_training_snapshot()
  • Returns the absolute path to a Lance dataset directory
  • The path has the form <database>/training_exports/<tag>_<ts>/
  • Repeated calls with the same tag overwrite the previous export
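The layout in the last two bullets can be reproduced in plain Python, for example to find earlier exports on disk. This is a minimal sketch, not part of the galaxdb client: `export_dir` and `parse_export_dir` are hypothetical helpers, and they assume `<ts>` is an integer, as in the example path in the next code sample. Tags such as `train-v1` may also contain underscores, so the parser splits on the last underscore only.

```python
from pathlib import Path


def export_dir(database: str, tag: str, ts: int) -> Path:
    """Build <database>/training_exports/<tag>_<ts>/ (hypothetical helper)."""
    return Path(database).resolve() / "training_exports" / f"{tag}_{ts}"


def parse_export_dir(path: str) -> tuple:
    """Split an export directory name back into (tag, ts).

    Splits on the LAST underscore, so tags containing '_' survive.
    """
    name = Path(path).name  # Path drops a trailing slash before .name
    tag, sep, ts = name.rpartition("_")
    if not sep or not tag or not ts.isdigit():
        raise ValueError(f"not an export directory name: {name!r}")
    return tag, int(ts)


if __name__ == "__main__":
    p = export_dir("./data", "train-v1", 1715385600000000)
    print(p.name)                          # train-v1_1715385600000000
    print(parse_export_dir(str(p) + "/"))  # ('train-v1', 1715385600000000)
```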
Python
import galaxdb

db = galaxdb.Database("./data")

# First create a training snapshot via SQL
db.execute("""
    CREATE VERSION TAG 'train-v1'
    FOR TRAINING
    WITH TRAINING PRECISION 'float32'
    TRAINING SEED 42
""")

# Export as Lance dataset
path = db.training_dataset("train-v1")
print(path)  # /absolute/path/to/data/training_exports/train-v1_1715385600000000/

create_training_snapshot()

Python
db.create_training_snapshot(name: str, seed: int | None = None) -> int

Creates a FOR TRAINING version tag pinned at the latest committed row and returns the version timestamp (uint64). This is a convenience wrapper around CREATE VERSION TAG ... FOR TRAINING.

Python
import galaxdb

db = galaxdb.Database("./data")

# Insert some data
db.execute("INSERT INTO docs (id, body) VALUES (1, 'hello world')")

# Create training snapshot
ts = db.create_training_snapshot("train-v1", seed=42)
print(f"Snapshot timestamp: {ts}")  # e.g., 1715385600000000

# Export the dataset
path = db.training_dataset("train-v1")
print(f"Dataset at: {path}")
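This page describes the value returned by create_training_snapshot() only as a uint64 version timestamp, without stating its unit. If it counts microseconds since the Unix epoch, which the 16-digit example values suggest but which is an assumption to check against your GalaxDB version, the standard library can turn it into a readable UTC time. `snapshot_time` is a hypothetical helper.

```python
from datetime import datetime, timedelta, timezone

_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)
_UINT64_MAX = 2**64 - 1


def snapshot_time(ts: int) -> datetime:
    """Convert a snapshot timestamp to an aware UTC datetime.

    ASSUMPTION: ts is microseconds since the Unix epoch. Using timedelta
    keeps the arithmetic in integers, so 16-digit values are not rounded.
    """
    if not 0 <= ts <= _UINT64_MAX:
        raise ValueError(f"not a uint64: {ts}")
    return _EPOCH + timedelta(microseconds=ts)


if __name__ == "__main__":
    print(snapshot_time(1715385600000000).isoformat())  # 2024-05-11T00:00:00+00:00
```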

Tip

Use create_training_snapshot() to create a snapshot from Python and export it immediately, without writing SQL; it accepts a seed. Use CREATE VERSION TAG in SQL when you also need to set options such as WITH TRAINING PRECISION.
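To keep the SQL route convenient from Python, the CREATE VERSION TAG statement from the first example can be assembled programmatically. This is an illustrative sketch: `training_tag_sql` is a hypothetical helper, it only emits the clauses shown on this page (FOR TRAINING, WITH TRAINING PRECISION, TRAINING SEED), and the statement create_training_snapshot() runs internally is not documented here. Single quotes inside values are doubled, the usual SQL string-literal escape; confirm that GalaxDB follows it.

```python
from __future__ import annotations


def _sql_str(value: str) -> str:
    """Quote a value as a SQL string literal, doubling embedded quotes."""
    return "'" + value.replace("'", "''") + "'"


def training_tag_sql(name: str, seed: int | None = None,
                     precision: str | None = None) -> str:
    """Build a CREATE VERSION TAG ... FOR TRAINING statement (hypothetical)."""
    parts = [f"CREATE VERSION TAG {_sql_str(name)}", "FOR TRAINING"]
    if precision is not None:
        parts.append(f"WITH TRAINING PRECISION {_sql_str(precision)}")
    if seed is not None:
        # Reject bools and strings so nothing unexpected is spliced into SQL
        if isinstance(seed, bool) or not isinstance(seed, int):
            raise TypeError("seed must be an int")
        parts.append(f"TRAINING SEED {seed}")
    return "\n".join(parts)


if __name__ == "__main__":
    print(training_tag_sql("train-v1", seed=42, precision="float32"))
```

The returned string can be passed to db.execute(), as in the first example.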

PyTorch Integration

The exported directory is a standard Lance dataset, so it can be opened with the lance Python package and fed to a PyTorch DataLoader:

Python
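import galaxdb
import lance
import torch
from torch.utils.data import DataLoader

db = galaxdb.Database("./data")

# Create a snapshot and export it
db.create_training_snapshot("train-v1", seed=42)
path = db.training_dataset("train-v1")

# Load with Lance
ds = lance.dataset(path)
print(f"Rows: {ds.count_rows()}")
print(f"Schema: {ds.schema}")

# Convert to PyTorch IterableDataset
pytorch_ds = ds.to_pytorch()

# Create DataLoader. With num_workers > 1, an IterableDataset is copied
# into each worker; make sure it splits rows so none are seen twice.
loader = DataLoader(
    pytorch_ds,
    batch_size=64,
    num_workers=4,
    pin_memory=True,  # faster host-to-GPU copies
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder model and loss; replace with your own. Assumes the export
# has a 384-dim float32 embedding column 'body' and an integer 'label'
# column with 10 classes.
model = torch.nn.Linear(384, 10).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Training loop
model.train()
for epoch in range(10):
    for batch in loader:
        embeddings = batch['body'].to(device, non_blocking=True)     # float32 [batch_size, 384]
        labels = batch['label'].to(device, non_blocking=True).long()  # CrossEntropyLoss needs int64

        optimizer.zero_grad()
        output = model(embeddings)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()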
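When a DataLoader with num_workers > 1 wraps an IterableDataset, PyTorch gives every worker its own copy of the dataset, so each row is yielded once per worker unless the dataset divides the work. This page does not say whether `to_pytorch()` handles that. A common remedy is for each worker to read a disjoint slice of rows. The sketch below shows only the partitioning arithmetic in plain Python; `worker_row_range` is a hypothetical helper, and a contiguous split is one choice among several (round-robin striding also works).

```python
def worker_row_range(num_rows: int, worker_id: int, num_workers: int) -> range:
    """Return the contiguous block of row indices owned by one worker.

    Block sizes differ by at most one row, and together the blocks
    cover 0..num_rows-1 exactly once.
    """
    if num_workers < 1 or not 0 <= worker_id < num_workers:
        raise ValueError("need 0 <= worker_id < num_workers")
    base, extra = divmod(num_rows, num_workers)
    # The first `extra` workers each take one additional row.
    start = worker_id * base + min(worker_id, extra)
    stop = start + base + (1 if worker_id < extra else 0)
    return range(start, stop)


if __name__ == "__main__":
    for w in range(4):
        print(w, list(worker_row_range(10, w, 4)))
```

Inside a custom IterableDataset's `__iter__`, `torch.utils.data.get_worker_info()` returns None in the main process (treat that as worker 0 of 1) and otherwise exposes `id` and `num_workers`; the resulting index range could then be read in chunks, for example with Lance's `LanceDataset.take()`.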

Examples

Full training pipeline

Python
import galaxdb
import lance
import torch

def prepare_training_data(db_path: str, tag: str, seed: int = 42) -> str:
    """Snapshot and export a training dataset from a GalaxDB database.

    Returns the path to the exported Lance dataset.
    """
    db = galaxdb.Database(db_path)

    # Check that there is data to train on
    count = int(db.execute("SELECT COUNT(*) FROM training_data")[0]['count'])
    print(f"Total rows: {count}")
    if count == 0:
        raise ValueError("training_data is empty; nothing to snapshot")

    unique = int(db.execute(
        "SELECT COUNT(*) FROM training_data WHERE NOT DUPLICATE"
    )[0]['count'])
    print(f"Unique rows: {unique} ({(1 - unique / count) * 100:.1f}% duplicates removed)")

    # Create snapshot
    ts = db.create_training_snapshot(tag, seed=seed)
    print(f"Snapshot created at: {ts}")

    # Export
    path = db.training_dataset(tag)
    print(f"Lance dataset at: {path}")

    return path

# Usage
path = prepare_training_data("./data", "experiment-001", seed=42)
ds = lance.dataset(path)
loader = torch.utils.data.DataLoader(ds.to_pytorch(), batch_size=32)
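The two row counts in prepare_training_data() can also drive a stricter pre-flight check than printing them. A minimal sketch in plain Python, operating on the two COUNT(*) results: `DedupStats`, `check_enough_data`, and the `min_unique` threshold are hypothetical, and a sensible threshold depends on your model and task.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DedupStats:
    total: int   # result of SELECT COUNT(*)
    unique: int  # result of SELECT COUNT(*) ... WHERE NOT DUPLICATE

    @property
    def duplicate_fraction(self) -> float:
        """Share of rows that are duplicates; 0.0 for an empty table."""
        if self.total == 0:
            return 0.0
        return 1 - self.unique / self.total


def check_enough_data(stats: DedupStats, min_unique: int) -> None:
    """Raise if the counts are inconsistent or too few rows are unique."""
    if not 0 <= stats.unique <= stats.total:
        raise ValueError(f"inconsistent counts: {stats}")
    if stats.unique < min_unique:
        raise ValueError(
            f"only {stats.unique} unique rows; need at least {min_unique}"
        )


if __name__ == "__main__":
    stats = DedupStats(total=1000, unique=940)
    check_enough_data(stats, min_unique=500)
    print(f"{stats.duplicate_fraction * 100:.1f}% duplicates")  # 6.0% duplicates
```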

Comparing dataset versions

Python
import galaxdb
import lance

db = galaxdb.Database("./data")

# Export two versions (assumes tags 'train-v1' and 'train-v2' both exist)
path_v1 = db.training_dataset("train-v1")
path_v2 = db.training_dataset("train-v2")

rows_v1 = lance.dataset(path_v1).count_rows()
rows_v2 = lance.dataset(path_v2).count_rows()

print(f"v1: {rows_v1} rows")
print(f"v2: {rows_v2} rows")
print(f"Delta: {rows_v2 - rows_v1:+d} rows")  # signed, so shrinkage shows as negative
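A row-count delta can hide churn: if v2 drops 10 rows and adds 10 others, the delta is zero. Comparing primary keys shows what actually changed. The sketch below assumes an `id` column, as in the docs table used earlier, and works on plain Python iterables; with Lance, such a list can be read with `ds.to_table(columns=["id"]).column("id").to_pylist()`. `diff_ids` is a hypothetical helper.

```python
from __future__ import annotations

from collections.abc import Hashable, Iterable


def diff_ids(old_ids: Iterable[Hashable], new_ids: Iterable[Hashable]) -> dict[str, set]:
    """Classify primary keys as added, removed, or kept between two exports."""
    old, new = set(old_ids), set(new_ids)
    return {
        "added": new - old,
        "removed": old - new,
        "kept": old & new,
    }


if __name__ == "__main__":
    d = diff_ids([1, 2, 3, 4], [3, 4, 5, 6])  # same row count, different rows
    print({k: sorted(v) for k, v in d.items()})
    # {'added': [5, 6], 'removed': [1, 2], 'kept': [3, 4]}
```

This detects inserted and deleted rows only; a row whose id is kept but whose values changed still lands in "kept".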