01 Data Extraction
Jupyter notebook from the Ecotype Correlation Analysis project.
Scale Up Ecotype Analysis to More Species
Goal
Expand the environment-ecotype analysis to many more species by:
- Finding species with good environmental embedding coverage (≥20 embeddings, ≥30% coverage)
- Downsampling large species to maximize phylogenetic diversity
- Extracting data for the expanded species set
Selection Criteria
| Parameter | Value | Rationale |
|---|---|---|
| Minimum genomes with embeddings | 20 | Basic statistical power |
| Embedding coverage | ≥30% | Representative sample |
| Maximum genomes per species | 250 | Tractable pairwise computations |
| Downsampling method | Maximize diversity using ANI distances | Preserve full genetic spread |
STEP 1: Find Species with Good Embedding Coverage
In [ ]:
# Cell 1: Compute Embedding Coverage per Species
import numpy as np
import pandas as pd
import os
from pyspark.sql.functions import col
# Initialize Spark (get_spark_session is assumed to be provided by the cluster runtime)
spark = get_spark_session()
OUTPUT_PATH = "../data"
os.makedirs(OUTPUT_PATH, exist_ok=True)
# Join genome table with embeddings to compute coverage per species
coverage_df = spark.sql("""
SELECT
g.gtdb_species_clade_id,
COUNT(DISTINCT g.genome_id) as n_total,
COUNT(DISTINCT e.genome_id) as n_with_embeddings
FROM kbase_ke_pangenome.genome g
LEFT JOIN kbase_ke_pangenome.alphaearth_embeddings_all_years e
ON g.genome_id = e.genome_id
GROUP BY g.gtdb_species_clade_id
""")
coverage_df.cache()
coverage_pd = coverage_df.toPandas()
coverage_pd['coverage'] = coverage_pd['n_with_embeddings'] / coverage_pd['n_total']
# Save coverage data
coverage_pd.to_csv(f"{OUTPUT_PATH}/species_embedding_coverage.csv", index=False)
print(f"Saved coverage data for {len(coverage_pd)} species")
# Show species with best coverage (>=20 embeddings, >=30% coverage)
good_coverage = coverage_pd[
(coverage_pd['n_with_embeddings'] >= 20) &
(coverage_pd['coverage'] >= 0.30)
].sort_values('n_with_embeddings', ascending=False)
print(f"\nSpecies with >=20 embeddings AND >=30% coverage: {len(good_coverage)}")
print(good_coverage.head(30))
In [ ]:
# Cell 2: Select Target Species for Expanded Analysis
# Use ALL species meeting criteria:
# - ≥20 genomes with embeddings (for statistical power)
# - ≥30% coverage (representative sample)
TARGET_SPECIES = good_coverage['gtdb_species_clade_id'].tolist()
print(f"Selected {len(TARGET_SPECIES)} species for expanded analysis")
# Save target species list
with open(f"{OUTPUT_PATH}/target_species_expanded.txt", 'w') as f:
for sp in TARGET_SPECIES:
f.write(sp + '\n')
STEP 2: Downsample Large Species (Maximize Phylogenetic Diversity)
For species with more than 250 genomes with embeddings, select representatives that maximize total phylogenetic diversity using ANI-based distances (distance = 100 - ANI; for example, two genomes at 97.5% ANI sit 2.5 distance units apart).
Note: ANI serves as a proxy for phylogenetic distance until per-species tree data is available.
In [ ]:
# Cell 3: Identify Species Needing Downsampling
MAX_GENOMES = 250 # Maximum genomes per species after downsampling
# All target species from good_coverage (already filtered to ≥20 embeddings, ≥30% coverage)
print(f"Total target species: {len(TARGET_SPECIES)}")
# Show which need downsampling (more than MAX_GENOMES genomes with embeddings)
needs_downsampling = good_coverage[good_coverage['n_with_embeddings'] > MAX_GENOMES]
print(f"\nSpecies needing downsampling (>{MAX_GENOMES} with embeddings): {len(needs_downsampling)}")
if len(needs_downsampling) > 0:
print(needs_downsampling[['gtdb_species_clade_id', 'n_total', 'n_with_embeddings']].head(20))
In [ ]:
# Cell 4: Diversity-Maximizing Downsampling Function
import numpy as np
def downsample_maximize_diversity(species_id, max_genomes=250):
"""
Downsample a species by selecting genomes that MAXIMIZE phylogenetic diversity.
Only considers genomes WITH embeddings (required for the analysis).
Uses ANI-based distances (100 - ANI) as proxy for phylogenetic distance.
Algorithm:
1. Get genomes with embeddings for this species
2. Build distance matrix from genome_ani table (distance = 100 - ANI)
3. Use maximin selection to maximize phylogenetic spread
"""
# Get genomes WITH embeddings for this species
embed_genomes = spark.sql(f"""
SELECT g.genome_id
FROM kbase_ke_pangenome.genome g
JOIN kbase_ke_pangenome.alphaearth_embeddings_all_years e
ON g.genome_id = e.genome_id
WHERE g.gtdb_species_clade_id = '{species_id}'
""").collect()
genome_ids = [r.genome_id for r in embed_genomes]
n_genomes = len(genome_ids)
short_name = species_id.split('__')[1].split('--')[0]
print(f"{short_name}: {n_genomes} genomes with embeddings")
# If small enough, return all genomes
if n_genomes <= max_genomes:
return genome_ids
# Need to downsample - get ANI matrix
print(f" Building ANI distance matrix for {n_genomes} genomes...")
ani_df = spark.sql(f"""
SELECT genome1_id, genome2_id, ANI
FROM kbase_ke_pangenome.genome_ani
WHERE genome1_id IN ({','.join([f"'{g}'" for g in genome_ids])})
AND genome2_id IN ({','.join([f"'{g}'" for g in genome_ids])})
""").toPandas()
    # Build distance matrix (distance = 100 - ANI)
    # NOTE: pairs absent from genome_ani keep the initialized distance of 0,
    # i.e. the maximin step treats them as identical
    genome_to_idx = {g: i for i, g in enumerate(genome_ids)}
    dist_matrix = np.zeros((n_genomes, n_genomes))
for _, row in ani_df.iterrows():
i = genome_to_idx.get(row['genome1_id'])
j = genome_to_idx.get(row['genome2_id'])
if i is not None and j is not None:
dist = 100 - row['ANI']
dist_matrix[i, j] = dist
dist_matrix[j, i] = dist
# Greedy maximin selection
print(f" Selecting {max_genomes} representatives to maximize diversity...")
selected_idx = []
remaining_idx = set(range(n_genomes))
# Start with genome that has max sum of distances (most divergent)
sum_dists = dist_matrix.sum(axis=1)
first = int(np.argmax(sum_dists))
selected_idx.append(first)
remaining_idx.remove(first)
# Iteratively add genome with max minimum distance to selected set
while len(selected_idx) < max_genomes and remaining_idx:
best_idx = None
best_min_dist = -1
for idx in remaining_idx:
min_dist = min(dist_matrix[idx, s] for s in selected_idx)
if min_dist > best_min_dist:
best_min_dist = min_dist
best_idx = idx
if best_idx is None:
break
selected_idx.append(best_idx)
remaining_idx.remove(best_idx)
# Summary
total_diversity = sum(dist_matrix[i, j] for i in selected_idx for j in selected_idx if i < j)
print(f" Selected {len(selected_idx)} genomes")
print(f" Total pairwise ANI-distance: {total_diversity:.2f}")
return [genome_ids[i] for i in selected_idx]
# Test on one species (uncomment to test)
# if len(needs_downsampling) > 0:
# test_species = needs_downsampling.iloc[0]['gtdb_species_clade_id']
# genomes = downsample_maximize_diversity(test_species)
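The pure-Python inner loop in Cell 4 rescans every candidate against every selected genome on each pick. Below is an optional vectorized sketch of the same greedy maximin step; it assumes a dist_matrix built exactly as in Cell 4 and is an illustration, not part of the original pipeline.
In [ ]:
# Cell 4b (optional): vectorized greedy maximin selection sketch
# Tracks, for every candidate, its minimum distance to the selected set,
# so each pick is a single np.argmax instead of a nested Python loop.
import numpy as np

def maximin_select(dist_matrix, k):
    n = dist_matrix.shape[0]
    first = int(np.argmax(dist_matrix.sum(axis=1)))  # most divergent start, as in Cell 4
    selected = [first]
    min_dists = dist_matrix[first].copy()  # min distance from each candidate to selected set
    min_dists[first] = -np.inf             # never re-pick a selected genome
    while len(selected) < min(k, n):
        nxt = int(np.argmax(min_dists))
        selected.append(nxt)
        min_dists = np.minimum(min_dists, dist_matrix[nxt])
        min_dists[nxt] = -np.inf
    return selected

# Toy sanity check on a hypothetical 4-genome distance matrix
toy = np.array([
    [0.0, 1.0, 4.0, 5.0],
    [1.0, 0.0, 3.0, 4.0],
    [4.0, 3.0, 0.0, 2.0],
    [5.0, 4.0, 2.0, 0.0],
])
print(maximin_select(toy, 3))  # expected [3, 0, 2]: the most spread-out trio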
In [ ]:
# Cell 5: Build Final Genome List (All Target Species)
# For each target species, get genomes (downsampled if needed)
# All selected genomes will have embeddings
all_target_genomes = []
species_stats = []
for species_id in TARGET_SPECIES:
species_info = good_coverage[good_coverage['gtdb_species_clade_id'] == species_id].iloc[0]
n_with_embed = species_info['n_with_embeddings']
n_total = species_info['n_total']
if n_with_embed > MAX_GENOMES:
# Downsample large species using diversity-maximizing selection
genomes = downsample_maximize_diversity(species_id, max_genomes=MAX_GENOMES)
else:
# Use all genomes with embeddings
genomes_result = spark.sql(f"""
SELECT g.genome_id
FROM kbase_ke_pangenome.genome g
JOIN kbase_ke_pangenome.alphaearth_embeddings_all_years e
ON g.genome_id = e.genome_id
WHERE g.gtdb_species_clade_id = '{species_id}'
""").collect()
genomes = [r.genome_id for r in genomes_result]
for g in genomes:
all_target_genomes.append({
'genome_id': g,
'gtdb_species_clade_id': species_id
})
species_stats.append({
'species': species_id,
'n_total': n_total,
'n_with_embeddings': n_with_embed,
'n_selected': len(genomes)
})
target_genomes_df = pd.DataFrame(all_target_genomes)
target_genomes_df.to_csv(f"{OUTPUT_PATH}/target_genomes_expanded.csv", index=False)
stats_df = pd.DataFrame(species_stats)
stats_df.to_csv(f"{OUTPUT_PATH}/species_selection_stats.csv", index=False)
print(f"\n=== Summary ===")
print(f"Total species: {len(TARGET_SPECIES)}")
print(f"Total genomes selected: {len(target_genomes_df)}")
print(f"Genomes per species: {stats_df['n_selected'].mean():.0f} mean, {stats_df['n_selected'].min()}-{stats_df['n_selected'].max()} range")
In [ ]:
# Cell 5b: Compute Embedding Diversity per Species
from scipy.spatial.distance import pdist
import matplotlib.pyplot as plt
# Embedding columns (A00-A63, 64 dimensions)
EMBEDDING_COLS = [f"A{i:02d}" for i in range(64)]
# Get embeddings for all target genomes
target_genome_ids = target_genomes_df['genome_id'].tolist()
# Query embeddings with all dimension columns
embeddings_query = f"""
SELECT genome_id, {', '.join(EMBEDDING_COLS)}
FROM kbase_ke_pangenome.alphaearth_embeddings_all_years
"""
embeddings_df = spark.sql(embeddings_query).filter(
col("genome_id").isin(target_genome_ids)
).toPandas()
print(f"Retrieved {len(embeddings_df)} embeddings")
# Convert embedding columns to numpy array
embeddings_df['embedding_vec'] = embeddings_df[EMBEDDING_COLS].values.tolist()
embeddings_df['embedding_vec'] = embeddings_df['embedding_vec'].apply(np.array)
# Merge with species info
embeddings_df = embeddings_df.merge(
target_genomes_df[['genome_id', 'gtdb_species_clade_id']],
on='genome_id'
)
# Compute embedding diversity per species
diversity_stats = []
for species_id in TARGET_SPECIES:
species_embeddings = embeddings_df[
embeddings_df['gtdb_species_clade_id'] == species_id
]['embedding_vec'].values
if len(species_embeddings) < 2:
continue
# Stack into matrix
emb_matrix = np.vstack(species_embeddings)
# Compute pairwise cosine distances
# cosine distance = 1 - cosine_similarity
cosine_dists = pdist(emb_matrix, metric='cosine')
short_name = species_id.split('__')[1].split('--')[0]
diversity_stats.append({
'species': species_id,
'short_name': short_name,
'n_genomes': len(species_embeddings),
'mean_cosine_dist': np.mean(cosine_dists),
'std_cosine_dist': np.std(cosine_dists),
'min_cosine_dist': np.min(cosine_dists),
'max_cosine_dist': np.max(cosine_dists),
'median_cosine_dist': np.median(cosine_dists)
})
diversity_df = pd.DataFrame(diversity_stats)
diversity_df.to_csv(f"{OUTPUT_PATH}/species_embedding_diversity.csv", index=False)
print(f"Computed embedding diversity for {len(diversity_df)} species")
print(f"\nSummary statistics:")
print(diversity_df[['mean_cosine_dist', 'std_cosine_dist']].describe())
# Plot distribution
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].hist(diversity_df['mean_cosine_dist'], bins=30, edgecolor='black', alpha=0.7)
axes[0].set_xlabel('Mean Pairwise Cosine Distance')
axes[0].set_ylabel('Number of Species')
axes[0].set_title('Distribution of Embedding Diversity Across Species')
axes[1].scatter(diversity_df['n_genomes'], diversity_df['mean_cosine_dist'], alpha=0.5)
axes[1].set_xlabel('Number of Genomes')
axes[1].set_ylabel('Mean Pairwise Cosine Distance')
axes[1].set_title('Embedding Diversity vs Sample Size')
plt.tight_layout()
plt.savefig(f"{OUTPUT_PATH}/embedding_diversity_distribution.png", dpi=150)
plt.show()
# Show species with lowest diversity (potential concern)
print("\nSpecies with LOWEST embedding diversity (potential concern):")
print(diversity_df.nsmallest(10, 'mean_cosine_dist')[['short_name', 'n_genomes', 'mean_cosine_dist']])
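A quick sanity check on the metric, using a hypothetical two-vector example rather than project data: pdist with metric='cosine' returns 1 minus the cosine similarity.
In [ ]:
# Cell 5c (optional): verify pdist's cosine distance definition
import numpy as np
from scipy.spatial.distance import pdist

a = np.array([1.0, 0.0])
b = np.array([1.0, 1.0])
# cosine distance = 1 - (a . b) / (||a|| * ||b||)
manual = 1 - a.dot(b) / (np.linalg.norm(a) * np.linalg.norm(b))
assert np.isclose(pdist(np.vstack([a, b]), metric='cosine')[0], manual)
print(f"cosine distance = {manual:.4f}")  # ~0.2929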
STEP 3: Extract Data for Expanded Species Set
In [ ]:
# Cell 6: Extract Embeddings for Target Genomes
target_genome_ids = target_genomes_df['genome_id'].tolist()
embeddings_df = spark.sql("""
SELECT *
FROM kbase_ke_pangenome.alphaearth_embeddings_all_years
""").filter(col("genome_id").isin(target_genome_ids))
embeddings_pd = embeddings_df.toPandas()
embeddings_pd.to_csv(f"{OUTPUT_PATH}/embeddings_expanded.csv", index=False)
print(f"Saved {len(embeddings_pd)} embeddings")
In [ ]:
# Cell 7: Extract Within-Species ANI (Chunked)
from pyspark.sql.functions import monotonically_increasing_id, row_number
from pyspark.sql.window import Window
import os
ANI_OUTPUT_PATH = f"{OUTPUT_PATH}/ani_expanded"
os.makedirs(ANI_OUTPUT_PATH, exist_ok=True)
# Get ANI for target genomes (within-species pairs only)
ani_df = spark.sql("""
SELECT
a.genome1_id,
a.genome2_id,
a.ANI,
g.gtdb_species_clade_id
FROM kbase_ke_pangenome.genome_ani a
JOIN kbase_ke_pangenome.genome g ON a.genome1_id = g.genome_id
""").filter(
col("genome1_id").isin(target_genome_ids) &
col("genome2_id").isin(target_genome_ids)
)
ani_df.cache()
total_count = ani_df.count()
print(f"ANI pairs: {total_count}")
# Export in chunks.
# monotonically_increasing_id() is unique and increasing but NOT consecutive
# across partitions, so convert it to a consecutive 0-based row index with
# row_number() before range-filtering (this forces a single-partition sort,
# acceptable for a one-off export).
CHUNK_SIZE = 1000000
ani_with_id = ani_df.withColumn(
    "_id", row_number().over(Window.orderBy(monotonically_increasing_id())) - 1
)
n_chunks = (total_count + CHUNK_SIZE - 1) // CHUNK_SIZE  # ceiling division, no empty tail chunk
for i in range(n_chunks):
start_id = i * CHUNK_SIZE
end_id = (i + 1) * CHUNK_SIZE
chunk = ani_with_id.filter(
(col("_id") >= start_id) & (col("_id") < end_id)
).drop("_id")
chunk_pd = chunk.toPandas()
chunk_pd.to_csv(f"{ANI_OUTPUT_PATH}/ani_chunk_{i:03d}.csv", index=False)
print(f" Saved chunk {i+1}/{n_chunks}: {len(chunk_pd)} rows")
print(f"\nCompleted ANI export to {ANI_OUTPUT_PATH}")
In [ ]:
# Cell 8: Extract Gene Clusters per Genome (Chunked)
CLUSTERS_OUTPUT_PATH = f"{OUTPUT_PATH}/gene_clusters_expanded"
os.makedirs(CLUSTERS_OUTPUT_PATH, exist_ok=True)
# Get gene clusters for target genomes
genome_clusters_df = spark.sql("""
SELECT
g.genome_id,
gg.gene_cluster_id,
gm.gtdb_species_clade_id
FROM kbase_ke_pangenome.gene g
JOIN kbase_ke_pangenome.gene_genecluster_junction gg
ON g.gene_id = gg.gene_id
JOIN kbase_ke_pangenome.genome gm
ON g.genome_id = gm.genome_id
""").filter(col("genome_id").isin(target_genome_ids))
genome_clusters_df.cache()
total_count = genome_clusters_df.count()
print(f"Gene-cluster associations: {total_count}")
# Export in chunks (same consecutive-index pattern as the ANI export)
CHUNK_SIZE = 5000000
clusters_with_id = genome_clusters_df.withColumn(
    "_id", row_number().over(Window.orderBy(monotonically_increasing_id())) - 1
)
n_chunks = (total_count + CHUNK_SIZE - 1) // CHUNK_SIZE
for i in range(n_chunks):
start_id = i * CHUNK_SIZE
end_id = (i + 1) * CHUNK_SIZE
chunk = clusters_with_id.filter(
(col("_id") >= start_id) & (col("_id") < end_id)
).drop("_id")
chunk_pd = chunk.toPandas()
chunk_pd.to_csv(f"{CLUSTERS_OUTPUT_PATH}/clusters_chunk_{i:03d}.csv", index=False)
print(f" Saved chunk {i+1}/{n_chunks}: {len(chunk_pd)} rows")
print(f"\nCompleted gene clusters export to {CLUSTERS_OUTPUT_PATH}")
Summary
After running this notebook, download the following from the cluster:
- `species_embedding_coverage.csv` - Coverage for all species
- `target_genomes_expanded.csv` - Selected genomes (all have embeddings)
- `species_selection_stats.csv` - Selection statistics per species
- `species_embedding_diversity.csv` - Embedding diversity per species
- `embedding_diversity_distribution.png` - Visualization of diversity
- `embeddings_expanded.csv` - Environmental embeddings
- `ani_expanded/` - Pairwise ANI chunks
- `gene_clusters_expanded/` - Gene cluster chunks
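A minimal sketch for reassembling the chunked exports locally (assuming the directories above were downloaded into the working directory; paths are illustrative):
import glob
import pandas as pd
# Concatenate the ANI chunks back into a single dataframe
ani = pd.concat(
    (pd.read_csv(p) for p in sorted(glob.glob("ani_expanded/ani_chunk_*.csv"))),
    ignore_index=True,
)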
Then run local analysis to:
- Review embedding diversity distribution and decide on any flagging criteria
- Compute environment-gene content correlations for all species
- Identify which species show strongest environment-ecotype signal