01 Data Extraction
Jupyter notebook from the Within-Species AMR Strain Variation project.
NB01: Data Extraction — Genome × AMR Presence/Absence Matrices¶
Goal: For each eligible species (≥10 genomes, ≥5 AMR genes, ≥1 non-core AMR), extract genome × AMR gene cluster presence/absence matrices and per-genome metadata.
Compute: Spark (BERDL) — estimated 4-8 hours
Outputs:
data/genome_amr_matrices/{species_id}.tsv — binary presence/absence matrices
data/genome_metadata.csv — per-genome environment metadata
data/eligible_species.csv — species that passed selection criteria
In [ ]:
import os
import sys
import pandas as pd
import numpy as np
from pathlib import Path
from time import time
# --- Directory layout ---
# The notebook runs from a subdirectory of the project, so the project root
# is the parent of the current working directory.
PROJECT_DIR = Path(os.getcwd()).parent
# Sibling project that holds the precomputed atlas summary CSVs read below.
ATLAS_DIR = PROJECT_DIR.parent / 'amr_pangenome_atlas'
DATA_DIR = PROJECT_DIR / 'data'
# One TSV per species: genome x AMR gene cluster presence/absence matrix.
MATRIX_DIR = DATA_DIR / 'genome_amr_matrices'
MATRIX_DIR.mkdir(parents=True, exist_ok=True)
print(f"Project: {PROJECT_DIR}")
print(f"Atlas: {ATLAS_DIR}")
print(f"Output: {MATRIX_DIR}")
1. Spark Session¶
In [ ]:
# Reuse a get_spark_session helper already defined in this kernel if present;
# otherwise load it from the shared scripts directory two levels up.
try:
    spark = get_spark_session()
except NameError:
    sys.path.insert(0, str(PROJECT_DIR.parent.parent / 'scripts'))
    from get_spark_session import get_spark_session
    spark = get_spark_session()
# Smoke test: confirm the session can execute SQL before doing real work.
spark.sql("SELECT 1 AS test").show()
print("Spark session ready.")
2. Species Selection¶
Criteria: no_genomes >= 10, n_amr >= 5, at least 1 non-core AMR gene.
In [ ]:
# Load atlas summaries
# species_summary: one row per species clade (columns used below include
# no_genomes, n_amr, n_aux_amr, n_sing_amr, phylum, gtdb_species_clade_id).
# amr_census: one row per AMR gene cluster with its species clade ID.
species_summary = pd.read_csv(ATLAS_DIR / 'data' / 'amr_species_summary.csv')
amr_census = pd.read_csv(ATLAS_DIR / 'data' / 'amr_census.csv')
print(f"Species summary: {len(species_summary)} species")
print(f"AMR census: {len(amr_census)} gene clusters")
In [ ]:
# Apply selection criteria
# Eligibility: >= 10 genomes, >= 5 AMR gene clusters, and at least one
# non-core AMR cluster (auxiliary + singleton counts combined).
eligible = species_summary[
    (species_summary['no_genomes'] >= 10) &
    (species_summary['n_amr'] >= 5) &
    ((species_summary['n_aux_amr'] + species_summary['n_sing_amr']) >= 1)
].copy()
print(f"Eligible species: {len(eligible)} (from {len(species_summary)} total)")
print(f"Genome range: {eligible['no_genomes'].min()}-{eligible['no_genomes'].max()}")
print(f"AMR range: {eligible['n_amr'].min()}-{eligible['n_amr'].max()}")
print(f"\nPhylum distribution:")
print(eligible['phylum'].value_counts().head(10))
# Persist the eligible set as a notebook output artifact.
eligible.to_csv(DATA_DIR / 'eligible_species.csv', index=False)
3. Utility: Chunked Query¶
In [ ]:
def chunked_query(spark, ids, query_template, chunk_size=5000):
    """Run `query_template` once per chunk of `ids`, concatenating results.

    The template must contain an `{id_list}` placeholder; each chunk is
    formatted as a quoted, comma-separated SQL list ('a','b',...) so the
    template can use it directly inside an IN (...) clause. Chunking keeps
    individual queries under engine size limits. Returns one pandas
    DataFrame (empty if `ids` is empty).
    """
    frames = []
    for start in range(0, len(ids), chunk_size):
        batch = ids[start:start + chunk_size]
        quoted = "','".join(str(v) for v in batch)
        sql = query_template.format(id_list=f"'{quoted}'")
        frames.append(spark.sql(sql).toPandas())
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)
4. Build AMR Cluster Lookup per Species¶
In [ ]:
# Map species -> AMR gene cluster IDs
# Group the census by species clade so each species maps to the list of
# AMR gene cluster IDs observed in that clade.
species_amr_clusters = (
    amr_census
    .groupby('gtdb_species_clade_id')['gene_cluster_id']
    .apply(list)
    .to_dict()
)
# Filter to eligible species only
eligible_ids = set(eligible['gtdb_species_clade_id'])
species_amr_clusters = {k: v for k, v in species_amr_clusters.items() if k in eligible_ids}
print(f"Species with AMR clusters: {len(species_amr_clusters)}")
5. Extract Genome × AMR Matrices¶
For each species:
- Get genome IDs from the `genome` table
- Run a double-filtered query on `gene` + `gene_genecluster_junction`
- Pivot to a binary presence/absence matrix
- Cache to TSV
In [ ]:
def extract_species_matrix(spark, species_id, amr_cluster_ids, matrix_dir):
    """Extract a binary genome x AMR-cluster presence/absence matrix for one species.

    Parameters
    ----------
    spark : active Spark session (anything exposing .sql(...).toPandas()).
    species_id : GTDB species clade ID; '/' is replaced with '_' for the filename.
    amr_cluster_ids : gene_cluster_id values that become the matrix columns.
    matrix_dir : pathlib.Path directory where the TSV is cached.

    Returns
    -------
    (outpath, status): outpath is the TSV path (None on 'no_genomes'/'no_hits');
    status is one of 'cached', 'no_genomes', 'no_hits', or 'ok (...)'.
    """
    outpath = matrix_dir / f"{species_id.replace('/', '_')}.tsv"
    # Skip the queries entirely if a non-trivial matrix file is already cached.
    if outpath.exists() and outpath.stat().st_size > 100:
        return outpath, 'cached'
    t0 = time()
    # Get all genome IDs for this species clade (single-ID lookup, so no
    # chunking helper is needed here).
    genomes = spark.sql(f"""
        SELECT genome_id
        FROM kbase_ke_pangenome.genome
        WHERE gtdb_species_clade_id = '{species_id}'
    """).toPandas()
    if len(genomes) == 0:
        return None, 'no_genomes'
    genome_ids = genomes['genome_id'].tolist()
    # Double-filtered join: gene (genome filter) x junction (AMR cluster filter).
    # Genomes are processed in chunks to keep the IN clause size manageable.
    all_hits = []
    genome_chunk_size = 500
    amr_list = "','".join(str(x) for x in amr_cluster_ids)
    for gi in range(0, len(genome_ids), genome_chunk_size):
        g_chunk = genome_ids[gi:gi + genome_chunk_size]
        g_list = "','".join(str(x) for x in g_chunk)
        query = f"""
            SELECT DISTINCT g.genome_id, j.gene_cluster_id
            FROM kbase_ke_pangenome.gene g
            JOIN kbase_ke_pangenome.gene_genecluster_junction j
                ON g.gene_id = j.gene_id
            WHERE g.genome_id IN ('{g_list}')
              AND j.gene_cluster_id IN ('{amr_list}')
        """
        all_hits.append(spark.sql(query).toPandas())
    hits_df = pd.concat(all_hits, ignore_index=True) if all_hits else pd.DataFrame()
    if len(hits_df) == 0:
        return None, 'no_hits'
    # Pivot (genome, cluster) hit pairs to a binary matrix, then reindex once
    # so every requested genome (row) and AMR cluster (column) is present,
    # with absences filled as 0. reindex replaces the original per-row
    # `matrix.loc[gid] = 0` / per-column insertion loops, which were O(n^2)
    # and fragmented the frame; it also yields a deterministic row/column
    # order (input order) in the cached TSV.
    matrix = (
        hits_df.assign(present=1)
        .pivot_table(index='genome_id', columns='gene_cluster_id',
                     values='present', fill_value=0, aggfunc='max')
        .reindex(index=genome_ids, columns=list(amr_cluster_ids), fill_value=0)
        .astype(int)
    )
    matrix.to_csv(outpath, sep='\t')
    elapsed = time() - t0
    return outpath, f'ok ({len(matrix)} genomes x {len(matrix.columns)} AMR, {elapsed:.0f}s)'
In [ ]:
# Run extraction for all eligible species
results = {}
total = len(species_amr_clusters)
# sorted() gives a deterministic processing order across reruns.
for idx, (species_id, amr_ids) in enumerate(sorted(species_amr_clusters.items())):
    # Human-readable name: take the part before '--' and drop the 's__' prefix.
    short_name = species_id.split('--')[0].replace('s__', '')
    print(f"[{idx+1}/{total}] {short_name} ({len(amr_ids)} AMR clusters)...", end=' ')
    try:
        path, status = extract_species_matrix(spark, species_id, amr_ids, MATRIX_DIR)
        results[species_id] = status
        print(status)
    except Exception as e:
        # Record per-species failures and keep going; the run is long and
        # one bad species should not abort the whole extraction.
        results[species_id] = f'ERROR: {e}'
        print(f'ERROR: {e}')
# Summary
# Collapse detailed statuses like 'ok (N genomes x M AMR, Ts)' to their
# prefix ('ok', 'cached', ...) before counting.
status_counts = pd.Series(results).apply(lambda x: x.split('(')[0].strip()).value_counts()
print(f"\nExtraction summary:")
print(status_counts)
6. Extract Per-Genome Environment Metadata¶
In [ ]:
# Collect all genome IDs across eligible species
all_genome_ids = set()
for species_id in species_amr_clusters:
    matrix_path = MATRIX_DIR / f"{species_id.replace('/', '_')}.tsv"
    if matrix_path.exists():
        # usecols=[0] reads only the first (genome_id) column, skipping the
        # potentially wide presence/absence columns.
        idx_df = pd.read_csv(matrix_path, sep='\t', usecols=[0])
        all_genome_ids.update(idx_df.iloc[:, 0].tolist())
print(f"Total genomes across all species: {len(all_genome_ids)}")
In [ ]:
# Extract metadata from ncbi_env table (EAV format: accession, attribute_name, content)
# Strategy: genome -> ncbi_biosample_id -> ncbi_env attributes -> pivot wide
metadata_path = DATA_DIR / 'genome_metadata.csv'
# Reuse the cached CSV if it exists and is non-trivial in size.
if metadata_path.exists() and metadata_path.stat().st_size > 100:
    print(f"Metadata already cached: {metadata_path}")
    genome_meta = pd.read_csv(metadata_path)
else:
    # Step 1: Get genome_id -> ncbi_biosample_id mapping
    biosample_df = chunked_query(spark, list(all_genome_ids), """
        SELECT genome_id, ncbi_biosample_id
        FROM kbase_ke_pangenome.genome
        WHERE genome_id IN ({id_list})
        AND ncbi_biosample_id IS NOT NULL
    """)
    print(f"Genomes with biosample ID: {len(biosample_df)}")
    # Step 2: Query ncbi_env (EAV table) for relevant attributes
    biosample_ids = biosample_df['ncbi_biosample_id'].dropna().unique().tolist()
    target_attrs = ['isolation_source', 'collection_date', 'geo_loc_name', 'host']
    attr_list = "','".join(target_attrs)
    # {{id_list}} stays escaped in this f-string so chunked_query can fill it
    # per chunk; attr_list is interpolated now because it is constant.
    env_long = chunked_query(spark, biosample_ids, f"""
        SELECT accession, attribute_name, content
        FROM kbase_ke_pangenome.ncbi_env
        WHERE accession IN ({{id_list}})
        AND attribute_name IN ('{attr_list}')
    """)
    print(f"EAV rows retrieved: {len(env_long)}")
    # Step 3: Pivot from long (EAV) to wide format
    # aggfunc='first' keeps a single value when a biosample repeats an attribute.
    env_wide = env_long.pivot_table(
        index='accession', columns='attribute_name',
        values='content', aggfunc='first'
    ).reset_index().rename(columns={'accession': 'ncbi_biosample_id'})
    # Step 4: Join back to genome_id
    genome_meta = biosample_df.merge(env_wide, on='ncbi_biosample_id', how='left')
    # Ensure expected columns exist
    # (an attribute absent from every retrieved biosample would otherwise
    # be missing and break the column selection below)
    for col in target_attrs:
        if col not in genome_meta.columns:
            genome_meta[col] = np.nan
    genome_meta = genome_meta[['genome_id', 'isolation_source', 'collection_date',
                               'geo_loc_name', 'host']]
    genome_meta.to_csv(metadata_path, index=False)
print(f"Metadata: {len(genome_meta)} genomes")
print(f" isolation_source: {genome_meta['isolation_source'].notna().sum()} non-null")
print(f" collection_date: {genome_meta['collection_date'].notna().sum()} non-null")
print(f" geo_loc_name: {genome_meta['geo_loc_name'].notna().sum()} non-null")
print(f" host: {genome_meta['host'].notna().sum()} non-null")
7. Validation: Spot-Check¶
In [ ]:
# Spot-check 3 species: compare matrix gene counts against direct Spark query
spot_check_species = list(species_amr_clusters.keys())[:3]
for species_id in spot_check_species:
    short_name = species_id.split('--')[0].replace('s__', '')
    matrix_path = MATRIX_DIR / f"{species_id.replace('/', '_')}.tsv"
    if not matrix_path.exists():
        print(f"{short_name}: SKIP (no matrix)")
        continue
    matrix = pd.read_csv(matrix_path, sep='\t', index_col=0)
    # Count AMR genes per genome from matrix
    matrix_counts = matrix.sum(axis=1).describe()
    # Direct Spark count for first genome
    # (re-runs the same double-filtered join used in extraction, limited to
    # one genome, and compares against that genome's row sum in the matrix)
    first_genome = matrix.index[0]
    amr_ids = species_amr_clusters[species_id]
    amr_list = "','".join(str(x) for x in amr_ids)
    direct = spark.sql(f"""
        SELECT COUNT(DISTINCT j.gene_cluster_id) as n_amr
        FROM kbase_ke_pangenome.gene g
        JOIN kbase_ke_pangenome.gene_genecluster_junction j ON g.gene_id = j.gene_id
        WHERE g.genome_id = '{first_genome}'
        AND j.gene_cluster_id IN ('{amr_list}')
    """).toPandas()
    matrix_val = matrix.loc[first_genome].sum()
    direct_val = direct['n_amr'].iloc[0]
    match = 'MATCH' if matrix_val == direct_val else 'MISMATCH'
    print(f"{short_name}: matrix={matrix_val}, direct={direct_val} -> {match}")
    print(f" Matrix shape: {matrix.shape}, AMR/genome stats: mean={matrix_counts['mean']:.1f}, std={matrix_counts['std']:.1f}")
In [ ]:
# Summary statistics
# Aggregate matrix dimensions across every cached species TSV.
matrix_files = list(MATRIX_DIR.glob('*.tsv'))
print(f"\nTotal species matrices: {len(matrix_files)}")
sizes = []
for f in matrix_files:
    df = pd.read_csv(f, sep='\t', index_col=0)
    # f.stem is the sanitized species_id used as the filename.
    sizes.append({'species': f.stem, 'genomes': df.shape[0], 'amr_genes': df.shape[1]})
size_df = pd.DataFrame(sizes)
print(f"Genomes: {size_df['genomes'].sum()} total, {size_df['genomes'].median():.0f} median/species")
print(f"AMR genes: {size_df['amr_genes'].median():.0f} median/species")
print(f"\nTop 10 by genome count:")
print(size_df.nlargest(10, 'genomes').to_string(index=False))