01 Data Extraction
Jupyter notebook from the Co-fitness Predicts Co-inheritance in Bacterial Pangenomes project.
NB01: Data Extraction for Co-fitness Co-inheritance Analysis
Extract genome × gene cluster presence matrices, co-fitness pairs, gene coordinates, and phylogenetic distances for 11 target organisms.
Requires BERDL JupyterHub — get_spark_session() must be available.
This notebook is the interactive equivalent of src/extract_data.py.
For batch extraction, run the script directly.
import pandas as pd
import numpy as np
from pathlib import Path
# Import Spark session — works on JupyterHub and locally
try:
    get_spark_session
except NameError:
    from berdl_notebook_utils.setup_spark_session import get_spark_session
spark = get_spark_session()
print(f"Spark version: {spark.version}")
DATA_DIR = Path('../data')
CONS_DIR = Path('../../conservation_vs_fitness/data')
for subdir in ['genome_cluster_matrices', 'cofit', 'gene_coords', 'phylo_distances']:
    (DATA_DIR / subdir).mkdir(parents=True, exist_ok=True)
# Target organisms and their species clades
TARGET_ORGANISMS = {
    'Koxy': 's__Klebsiella_michiganensis--RS_GCF_002925905.1',
    'Btheta': 's__Bacteroides_thetaiotaomicron--RS_GCF_000011065.1',
    'Smeli': 's__Sinorhizobium_meliloti--RS_GCF_017876815.1',
    'RalstoniaUW163': 's__Ralstonia_solanacearum--RS_GCF_002251695.1',
    'Putida': 's__Pseudomonas_E_alloputida--RS_GCF_021282585.1',
    'SyringaeB728a': 's__Pseudomonas_E_syringae_M--RS_GCF_009176725.1',
    'Korea': 's__Sphingomonas_koreensis--RS_GCF_002797435.1',
    'RalstoniaGMI1000': 's__Ralstonia_pseudosolanacearum--RS_GCF_024925465.1',
    'Phaeo': 's__Phaeobacter_inhibens--RS_GCF_000473105.1',
    'Ddia6719': 's__Dickeya_dianthicola--RS_GCF_000365305.1',
    'pseudo3_N2E3': 's__Pseudomonas_E_fluorescens_E--RS_GCF_001307155.1',
}
# Load shared data
link = pd.read_csv(CONS_DIR / 'fb_pangenome_link.tsv', sep='\t')
link = link[link['orgId'] != 'Dyella79']
org_mapping = pd.read_csv(CONS_DIR / 'organism_mapping.tsv', sep='\t')
print(f"Link table: {len(link):,} rows, {link['orgId'].nunique()} organisms")
print(f"Target organisms: {len(TARGET_ORGANISMS)}")
Spark version: 4.0.1
Link table: 173,582 rows, 43 organisms
Target organisms: 11
Step 0: Verify Target Species Pangenome Stats
clade_ids = list(TARGET_ORGANISMS.values())
clade_str = "','".join(clade_ids)
stats = spark.sql(f"""
SELECT p.gtdb_species_clade_id,
s.GTDB_species,
p.no_genomes,
p.no_core,
p.no_aux_genome,
p.no_singleton_gene_clusters,
p.no_gene_clusters,
s.mean_intra_species_ANI
FROM kbase_ke_pangenome.pangenome p
JOIN kbase_ke_pangenome.gtdb_species_clade s
ON p.gtdb_species_clade_id = s.gtdb_species_clade_id
WHERE p.gtdb_species_clade_id IN ('{clade_str}')
ORDER BY p.no_genomes DESC
""").toPandas()
# Add orgId column
clade_to_org = {v: k for k, v in TARGET_ORGANISMS.items()}
stats['orgId'] = stats['gtdb_species_clade_id'].map(clade_to_org)
# Count FB genes per organism
fb_counts = link.groupby('orgId').agg(
    n_fb_genes=('locusId', 'nunique'),
    n_aux_fb=('is_auxiliary', lambda x: (x == True).sum())
).reset_index()
stats = stats.merge(fb_counts, on='orgId', how='left')
print(stats[['orgId', 'GTDB_species', 'no_genomes', 'mean_intra_species_ANI',
             'no_gene_clusters', 'n_fb_genes', 'n_aux_fb']].to_string(index=False))
orgId GTDB_species no_genomes mean_intra_species_ANI no_gene_clusters n_fb_genes n_aux_fb
Koxy s__Klebsiella_michiganensis 399 98.57 61735 4965 822
Btheta s__Bacteroides_thetaiotaomicron 287 98.44 65634 4727 1632
Smeli s__Sinorhizobium_meliloti 241 98.93 58199 6123 1365
RalstoniaUW163 s__Ralstonia_solanacearum 141 96.27 23007 4303 867
Putida s__Pseudomonas_E_alloputida 128 97.49 42747 5470 1372
SyringaeB728a s__Pseudomonas_E_syringae_M 126 98.70 29917 5031 723
Korea s__Sphingomonas_koreensis 72 98.13 7633 4116 637
RalstoniaGMI1000 s__Ralstonia_pseudosolanacearum 70 95.97 21463 4812 932
Ddia6719 s__Dickeya_dianthicola 66 99.47 9248 4030 767
Phaeo s__Phaeobacter_inhibens 43 97.75 10948 3823 508
pseudo3_N2E3 s__Pseudomonas_E_fluorescens_E 40 99.66 10861 5535 140
Step 1: Extract Genome × Gene Cluster Presence Matrices
For each species, build a binary matrix: rows = genomes, columns = gene clusters.
Only include clusters that FB genes map to (from fb_pangenome_link.tsv).
Performance note: Each organism requires joining gene_genecluster_junction (~1B rows)
with gene (~1B rows). BROADCAST hints on the small filter tables reduce join cost.
Expect ~3-5 min per organism, ~45 min total. Already-cached matrices are skipped.
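The extraction below pivots the long-format presence rows returned by Spark into a binary genome × cluster matrix. A minimal sketch of that pivot on toy data (all names are illustrative):

```python
import pandas as pd

# Toy long-format presence rows, as the Spark query would return them;
# note the duplicated (g2, c1) pair, which a paralog expansion could produce
presence = pd.DataFrame({
    'genome_id': ['g1', 'g1', 'g2', 'g2', 'g2'],
    'gene_cluster_id': ['c1', 'c2', 'c1', 'c1', 'c3'],
})
presence['present'] = 1

# Pivot to a binary matrix: aggfunc='max' collapses duplicate pairs to 1,
# fill_value=0 marks genome/cluster combinations with no row as absent
matrix = presence.pivot_table(index='genome_id', columns='gene_cluster_id',
                              values='present', fill_value=0, aggfunc='max')
print(matrix)
```

This mirrors the `pivot_table` call in the extraction loop below; the same pattern scales to the full per-organism matrices.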
import time
matrix_summary = []
for orgId, clade_id in TARGET_ORGANISMS.items():
    outpath = DATA_DIR / 'genome_cluster_matrices' / f'{orgId}_presence.tsv'
    if outpath.exists() and outpath.stat().st_size > 0:
        cached = pd.read_csv(outpath, sep='\t', index_col=0)
        matrix_summary.append({'orgId': orgId, 'genomes': cached.shape[0],
                               'clusters': cached.shape[1], 'status': 'cached'})
        print(f" [{orgId}] Cached: {cached.shape[0]} genomes x {cached.shape[1]} clusters")
        continue

    print(f" [{orgId}] Extracting...", flush=True)
    t0 = time.time()

    # Get target cluster IDs
    org_clusters = link[link['gtdb_species_clade_id'] == clade_id]['gene_cluster_id'].unique()
    if len(org_clusters) == 0:
        org_clusters = link[link['orgId'] == orgId]['gene_cluster_id'].unique()
    print(f" Target clusters: {len(org_clusters)}")
    if len(org_clusters) == 0:
        print(f" WARNING: No clusters for {orgId}, skipping")
        continue

    # Register small filter tables for BROADCAST joins
    cluster_df = spark.createDataFrame([(c,) for c in org_clusters], ['gene_cluster_id'])
    cluster_df.createOrReplaceTempView('target_clusters')
    genome_ids = spark.sql(f"""
        SELECT genome_id FROM kbase_ke_pangenome.genome
        WHERE gtdb_species_clade_id = '{clade_id}'
    """).toPandas()['genome_id'].tolist()
    genome_df = spark.createDataFrame([(g,) for g in genome_ids], ['genome_id'])
    genome_df.createOrReplaceTempView('target_genomes')

    # Use BROADCAST hints on small tables to avoid shuffle joins
    presence = spark.sql("""
        SELECT /*+ BROADCAST(tc), BROADCAST(tg) */
               DISTINCT g.genome_id, j.gene_cluster_id
        FROM kbase_ke_pangenome.gene_genecluster_junction j
        JOIN target_clusters tc ON j.gene_cluster_id = tc.gene_cluster_id
        JOIN kbase_ke_pangenome.gene g ON j.gene_id = g.gene_id
        JOIN target_genomes tg ON g.genome_id = tg.genome_id
    """).toPandas()
    elapsed = time.time() - t0
    print(f" Raw presence rows: {len(presence):,} ({elapsed:.0f}s)")
    if len(presence) == 0:
        print(f" WARNING: No data for {orgId}")
        continue

    presence['present'] = 1
    matrix = presence.pivot_table(
        index='genome_id', columns='gene_cluster_id',
        values='present', fill_value=0, aggfunc='max'
    )
    matrix_summary.append({'orgId': orgId, 'genomes': matrix.shape[0],
                           'clusters': matrix.shape[1], 'status': 'extracted'})
    print(f" Matrix: {matrix.shape[0]} genomes x {matrix.shape[1]} clusters")
    matrix.to_csv(outpath, sep='\t')

print("\n=== MATRIX SUMMARY ===")
print(pd.DataFrame(matrix_summary).to_string(index=False))
[Koxy] Cached: 399 genomes x 4942 clusters
[Btheta] Cached: 287 genomes x 4649 clusters
[Smeli] Cached: 241 genomes x 6004 clusters
[RalstoniaUW163] Cached: 141 genomes x 4413 clusters
[Putida] Cached: 128 genomes x 5409 clusters
[SyringaeB728a] Extracting...
Target clusters: 4999
Raw presence rows: 558,868 (218s)
Matrix: 126 genomes x 4999 clusters
[Korea] Extracting...
Target clusters: 4075
Raw presence rows: 254,177 (211s)
Matrix: 72 genomes x 4075 clusters
[RalstoniaGMI1000] Extracting...
Target clusters: 4723
Raw presence rows: 285,319 (220s)
Matrix: 70 genomes x 4723 clusters
[Phaeo] Extracting...
Target clusters: 3790
Raw presence rows: 145,722 (209s)
Matrix: 43 genomes x 3790 clusters
[Ddia6719] Extracting...
Target clusters: 4694
Raw presence rows: 256,026 (209s)
Matrix: 66 genomes x 4694 clusters
[pseudo3_N2E3] Extracting...
Target clusters: 5513
Raw presence rows: 214,589 (211s)
Matrix: 40 genomes x 5513 clusters
=== MATRIX SUMMARY ===
orgId genomes clusters status
Koxy 399 4942 cached
Btheta 287 4649 cached
Smeli 241 6004 cached
RalstoniaUW163 141 4413 cached
Putida 128 5409 cached
SyringaeB728a 126 4999 extracted
Korea 72 4075 extracted
RalstoniaGMI1000 70 4723 extracted
Phaeo 43 3790 extracted
Ddia6719 66 4694 extracted
pseudo3_N2E3 40 5513 extracted
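Once extracted, a presence matrix can be summarized by per-cluster prevalence, e.g. to separate near-core from auxiliary clusters. A sketch on a toy matrix (the 0.95 threshold here is illustrative, not this pipeline's core definition):

```python
import pandas as pd

# Toy binary presence matrix: rows = genomes, columns = gene clusters
matrix = pd.DataFrame(
    [[1, 1, 0],
     [1, 0, 1],
     [1, 1, 0],
     [1, 1, 0]],
    index=['g1', 'g2', 'g3', 'g4'],
    columns=['c1', 'c2', 'c3'],
)

# Fraction of genomes carrying each cluster
prevalence = matrix.mean(axis=0)

# Split clusters at an (illustrative) 95% prevalence cutoff
core = prevalence[prevalence >= 0.95].index.tolist()
aux = prevalence[prevalence < 0.95].index.tolist()
print(core, aux)
```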
Step 2: Extract Co-fitness Pairs
Pull each organism's precomputed co-fitness pairs (locusId, hitId, rank, cofit) from kescience_fitnessbrowser.cofit. Cached files are skipped.
cofit_summary = []
for orgId in TARGET_ORGANISMS:
    outpath = DATA_DIR / 'cofit' / f'{orgId}_cofit.tsv'
    if outpath.exists() and outpath.stat().st_size > 0:
        cached = pd.read_csv(outpath, sep='\t')
        cofit_summary.append({'orgId': orgId, 'pairs': len(cached), 'status': 'cached'})
        print(f" [{orgId}] Cached: {len(cached):,} pairs")
        continue

    print(f" [{orgId}] Extracting...", end='', flush=True)
    cofit = spark.sql(f"""
        SELECT orgId, locusId, hitId,
               CAST(rank AS INT) as rank,
               CAST(cofit AS FLOAT) as cofit
        FROM kescience_fitnessbrowser.cofit
        WHERE orgId = '{orgId}'
        ORDER BY locusId, CAST(rank AS INT)
    """).toPandas()
    print(f" {len(cofit):,} pairs")
    cofit.to_csv(outpath, sep='\t', index=False)
    cofit_summary.append({'orgId': orgId, 'pairs': len(cofit), 'status': 'extracted'})

print("\n=== COFIT SUMMARY ===")
print(pd.DataFrame(cofit_summary).to_string(index=False))
[Koxy] Extracting... 423,936 pairs
[Btheta] Extracting... 328,455 pairs
[Smeli] Extracting... 528,699 pairs
[RalstoniaUW163] Extracting... 0 pairs
[Putida] Extracting... 458,688 pairs
[SyringaeB728a] Extracting... 371,004 pairs
[Korea] Extracting... 230,724 pairs
[RalstoniaGMI1000] Extracting... 0 pairs
[Phaeo] Extracting... 192,138 pairs
[Ddia6719] Extracting... 250,488 pairs
[pseudo3_N2E3] Extracting... 507,828 pairs
=== COFIT SUMMARY ===
orgId pairs status
Koxy 423936 extracted
Btheta 328455 extracted
Smeli 528699 extracted
RalstoniaUW163 0 extracted
Putida 458688 extracted
SyringaeB728a 371004 extracted
Korea 230724 extracted
RalstoniaGMI1000 0 extracted
Phaeo 192138 extracted
Ddia6719 250488 extracted
pseudo3_N2E3 507828 extracted
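The cofit table lists, for each gene, its top co-fit partners, so most pairs appear in both directions (A→B and B→A). Downstream counting may want each unordered pair once; a sketch of that deduplication on toy rows (the data is illustrative):

```python
import pandas as pd

# Toy cofit rows mimicking the extracted schema: reciprocal hits give the
# same pair in both directions
cofit = pd.DataFrame({
    'locusId': ['a', 'b', 'a', 'c'],
    'hitId':   ['b', 'a', 'c', 'a'],
    'rank':    [1, 1, 2, 1],
    'cofit':   [0.9, 0.9, 0.7, 0.7],
})

# Canonicalize each pair as a sorted tuple, then keep one row per pair
pairs = cofit.assign(
    pair=cofit.apply(lambda r: tuple(sorted((r['locusId'], r['hitId']))), axis=1)
).drop_duplicates('pair')
print(len(pairs))
```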
Step 3: Extract Gene Coordinates
Extract per-gene coordinates (scaffoldId, begin, end, strand) from kescience_fitnessbrowser.gene for each organism.
for orgId in TARGET_ORGANISMS:
    outpath = DATA_DIR / 'gene_coords' / f'{orgId}_coords.tsv'
    if outpath.exists() and outpath.stat().st_size > 0:
        print(f" [{orgId}] Cached")
        continue

    print(f" [{orgId}] Extracting...", end='', flush=True)
    coords = spark.sql(f"""
        SELECT orgId, locusId, scaffoldId,
               CAST(begin AS INT) as begin,
               CAST(end AS INT) as end,
               strand
        FROM kescience_fitnessbrowser.gene
        WHERE orgId = '{orgId}'
        ORDER BY scaffoldId, CAST(begin AS INT)
    """).toPandas()
    print(f" {len(coords):,} genes")
    coords.to_csv(outpath, sep='\t', index=False)
[Koxy] Extracting... 5,586 genes
[Btheta] Extracting... 4,902 genes
[Smeli] Extracting... 6,281 genes
[RalstoniaUW163] Extracting... 5,006 genes
[Putida] Extracting... 5,661 genes
[SyringaeB728a] Extracting... 5,216 genes
[Korea] Extracting... 4,245 genes
[RalstoniaGMI1000] Extracting... 5,204 genes
[Phaeo] Extracting... 3,944 genes
[Ddia6719] Extracting... 4,338 genes
[pseudo3_N2E3] Extracting... 5,854 genes
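The coordinate tables support proximity checks downstream (e.g. whether co-fit genes are adjacent on the chromosome). A sketch computing the intergenic gap between consecutive genes on each scaffold; the column names follow the extracted schema, but the values are toy data:

```python
import pandas as pd

# Toy coordinate table in the extracted schema (values are illustrative)
coords = pd.DataFrame({
    'scaffoldId': ['s1', 's1', 's1', 's2'],
    'locusId':    ['g1', 'g2', 'g3', 'g4'],
    'begin':      [100, 600, 2000, 50],
    'end':        [500, 1500, 2900, 400],
    'strand':     ['+', '+', '-', '+'],
})

# Sort genes within each scaffold, then measure the gap to the previous
# gene's end; the first gene on each scaffold has no previous gene (NaN)
coords = coords.sort_values(['scaffoldId', 'begin']).reset_index(drop=True)
coords['prev_end'] = coords.groupby('scaffoldId')['end'].shift(1)
coords['gap'] = coords['begin'] - coords['prev_end']
```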
Step 4: Extract Phylogenetic Distances
For each species with a phylogenetic tree, extract pairwise branch distances between genomes, then save the reference-genome mapping.
clade_str = "','".join(TARGET_ORGANISMS.values())
tree_mapping = spark.sql(f"""
SELECT gtdb_species_clade_id, phylogenetic_tree_id
FROM kbase_ke_pangenome.phylogenetic_tree
WHERE gtdb_species_clade_id IN ('{clade_str}')
""").toPandas()
clade_to_org = {v: k for k, v in TARGET_ORGANISMS.items()}
print(f"Species with phylogenetic trees: {len(tree_mapping)}/{len(TARGET_ORGANISMS)}")
for _, row in tree_mapping.iterrows():
    clade_id = row['gtdb_species_clade_id']
    tree_id = row['phylogenetic_tree_id']
    orgId = clade_to_org.get(clade_id)
    if orgId is None:
        continue
    outpath = DATA_DIR / 'phylo_distances' / f'{orgId}_phylo_distances.tsv'
    if outpath.exists() and outpath.stat().st_size > 0:
        print(f" [{orgId}] Cached")
        continue

    print(f" [{orgId}] Extracting...", end='', flush=True)
    distances = spark.sql(f"""
        SELECT genome1_id, genome2_id, branch_distance
        FROM kbase_ke_pangenome.phylogenetic_tree_distance_pairs
        WHERE phylogenetic_tree_id = '{tree_id}'
    """).toPandas()
    print(f" {len(distances):,} pairs")
    distances.to_csv(outpath, sep='\t', index=False)
# Save reference genome mapping
ref_genomes = org_mapping[
    org_mapping['orgId'].isin(TARGET_ORGANISMS.keys())
][['orgId', 'gtdb_species_clade_id', 'pg_genome_id']].drop_duplicates()
ref_genomes.to_csv(DATA_DIR / 'phylo_distances' / 'reference_genomes.tsv',
                   sep='\t', index=False)
print(f"\nReference genomes saved: {len(ref_genomes)} rows")
Species with phylogenetic trees: 9/11
[Koxy] Extracting... 79,401 pairs
[Btheta] Extracting... 41,041 pairs
[Smeli] Extracting... 28,920 pairs
[RalstoniaGMI1000] Extracting... 2,415 pairs
[RalstoniaUW163] Extracting... 9,870 pairs
[Putida] Extracting... 8,128 pairs
[SyringaeB728a] Extracting... 7,875 pairs
[Ddia6719] Extracting... 2,145 pairs
[Korea] Extracting... 2,556 pairs

Reference genomes saved: 95 rows
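Downstream analyses typically need the long-format branch distances as a symmetric genome × genome matrix. A sketch, assuming (as the extraction suggests) each unordered genome pair appears once in the pairs file:

```python
import pandas as pd

# Toy branch-distance pairs, one row per unordered genome pair
dist = pd.DataFrame({
    'genome1_id': ['gA', 'gA', 'gB'],
    'genome2_id': ['gB', 'gC', 'gC'],
    'branch_distance': [0.1, 0.3, 0.2],
})

# Mirror the pairs so both orderings are present, then pivot; the missing
# diagonal (self-distances) is filled with zero
mirrored = pd.concat([
    dist,
    dist.rename(columns={'genome1_id': 'genome2_id', 'genome2_id': 'genome1_id'}),
])
mat = mirrored.pivot_table(index='genome1_id', columns='genome2_id',
                           values='branch_distance').fillna(0.0)
print(mat)
```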
print('=' * 60)
print('NB01 SUMMARY: Data Extraction')
print('=' * 60)
print(f'Target organisms: {len(TARGET_ORGANISMS)}')
print(f'Species with phylogenetic trees: {len(tree_mapping)}')
print(f'\nOutput directories:')
for subdir in ['genome_cluster_matrices', 'cofit', 'gene_coords', 'phylo_distances']:
    files = list((DATA_DIR / subdir).glob('*.tsv'))
    print(f' {subdir}/: {len(files)} files')
print('=' * 60)
============================================================
NB01 SUMMARY: Data Extraction
============================================================
Target organisms: 11
Species with phylogenetic trees: 9

Output directories:
 genome_cluster_matrices/: 11 files
 cofit/: 11 files
 gene_coords/: 11 files
 phylo_distances/: 10 files
============================================================