Multi-Species COG Analysis
Jupyter notebook from the COG Functional Category Analysis project.
Multi-Species COG Functional Category Analysis
This notebook analyzes COG functional category distributions across core, auxiliary, and singleton genes for multiple species to identify conserved patterns.
Goals
- Run COG analysis on 32 taxonomically diverse species
- Identify universal patterns vs species-specific differences
- Test whether novel (singleton) genes are consistently enriched in mobile-element (L) and surface-variation (M) functions
- Test if core genes are consistently enriched in housekeeping functions (J, C, H)
- Examine phylum-level patterns
In [ ]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from get_spark_session import get_spark_session
from tqdm import tqdm
import warnings
# Silence every Python warning so library notices do not clutter the cell output.
warnings.filterwarnings(action='ignore')

# Configure plotting: seaborn white-grid theme and a 14x8 default figure size.
sns.set_style(style="whitegrid")
plt.rcParams.update({'figure.figsize': (14, 8)})
In [ ]:
# Obtain the Spark session from the project helper; every SQL query below runs through it.
spark = get_spark_session()
print("Spark version: {}".format(spark.version))
In [ ]:
# Human-readable names for COG functional category codes.
# Keys are the codes as they appear in the COG_category column.
COG_DESCRIPTIONS = {
    # Information storage and processing
    "J": "Translation, ribosomal structure",
    "A": "RNA processing and modification",
    "K": "Transcription",
    "L": "Replication, recombination, repair",
    "B": "Chromatin structure",
    # Cellular processes and signaling
    "D": "Cell cycle control, division",
    "Y": "Nuclear structure",
    "V": "Defense mechanisms",
    "T": "Signal transduction",
    "M": "Cell wall/membrane biogenesis",
    "N": "Cell motility",
    "Z": "Cytoskeleton",
    "W": "Extracellular structures",
    "U": "Intracellular trafficking",
    "O": "Posttranslational modification, chaperones",
    # Metabolism
    "C": "Energy production and conversion",
    "G": "Carbohydrate transport and metabolism",
    "E": "Amino acid transport and metabolism",
    "F": "Nucleotide transport and metabolism",
    "H": "Coenzyme transport and metabolism",
    "I": "Lipid transport and metabolism",
    "P": "Inorganic ion transport",
    "Q": "Secondary metabolites biosynthesis",
    # Poorly characterized
    "R": "General function prediction only",
    "S": "Function unknown",
    # Composite category (a gene assigned to both N and U)
    "NU": "Motility and trafficking",
}

# Expected enrichments based on N. gonorrhoeae results
# Enriched in novel genes: mobile elements, surface, metabolism
EXPECTED_NOVEL_ENRICHED = ["L", "M", "NU", "U", "E"]
# Depleted in novel genes: translation, energy, coenzyme
EXPECTED_NOVEL_DEPLETED = ["J", "C", "H", "S", "D"]
Step 1: Load sampled species
In [ ]:
# Read the stratified species sample: one row per species, including the
# 'phylum' and 'no_genomes' columns summarised below.
sampled_species = pd.read_csv(Path('..') / 'data' / 'sampled_species_for_cog_analysis.csv')

print(f"Loaded {sampled_species.shape[0]} species for analysis")

print("\nPhylum distribution:")
print(sampled_species['phylum'].value_counts())

print(
    f"\nGenome count range: "
    f"{sampled_species['no_genomes'].min()}-{sampled_species['no_genomes'].max()}"
)

# Preview the first rows (auto-displayed as the last expression of the cell)
sampled_species.head(10)
Step 2: Define query function for COG distributions
In [ ]:
def get_cog_distribution(species_id, gene_class='core'):
    """
    Query COG category distribution for a specific gene class in a species.

    Parameters:
    - species_id: GTDB species clade ID
    - gene_class: 'core', 'auxiliary', or 'singleton'

    Returns:
    - DataFrame with COG_category and gene_count, one row per distinct
      COG_category value, largest count first. If the query fails for any
      reason the error is printed and an empty DataFrame with the same two
      columns is returned, so one failing species does not abort a batch run.

    Raises:
    - ValueError: if gene_class is not one of the three supported values
    """
    # Set filter conditions based on gene class
    class_filters = {
        'core': "gc.is_core = 1",
        'singleton': "gc.is_singleton = 1",
        # NOTE(review): presumably singleton clusters are also flagged
        # auxiliary; the extra condition keeps them out of this class.
        'auxiliary': "gc.is_auxiliary = 1 AND gc.is_singleton = 0",
    }
    if gene_class not in class_filters:
        raise ValueError(f"Unknown gene_class: {gene_class}")
    class_filter = class_filters[gene_class]

    # Escape backslashes and single quotes so an ID containing either cannot
    # terminate the SQL string literal early. IDs without those characters
    # produce exactly the same literal as before.
    safe_species_id = str(species_id).replace("\\", "\\\\").replace("'", "\\'")

    query = f"""
        SELECT
            ann.COG_category,
            COUNT(*) as gene_count
        FROM kbase_ke_pangenome.gene_cluster gc
        JOIN kbase_ke_pangenome.gene_genecluster_junction j
            ON gc.gene_cluster_id = j.gene_cluster_id
        JOIN kbase_ke_pangenome.eggnog_mapper_annotations ann
            ON j.gene_id = ann.query_name
        WHERE gc.gtdb_species_clade_id = '{safe_species_id}'
            AND {class_filter}
            AND ann.COG_category IS NOT NULL
            AND ann.COG_category != '-'
        GROUP BY ann.COG_category
        ORDER BY gene_count DESC
    """

    # Deliberate best-effort: report the failure and hand back an empty frame
    # so the caller can carry on with the remaining species.
    try:
        df = spark.sql(query).toPandas()
        df['gene_count'] = pd.to_numeric(df['gene_count'], errors='coerce')
        return df
    except Exception as e:
        print(f"Error querying {gene_class} for {species_id}: {e}")
        return pd.DataFrame(columns=['COG_category', 'gene_count'])
def analyze_species_cog(species_row):
    """
    Build the per-class COG profile of a single species.

    Queries the core, auxiliary and singleton gene classes through
    get_cog_distribution, stacks the three results and expresses each count
    as a fraction of its own class total.

    Parameters:
    - species_row: row providing 'gtdb_species_clade_id', 'GTDB_species',
      'phylum' and 'no_genomes'

    Returns:
    - DataFrame with COG_category, gene_count, gene_class, proportion,
      species_id, species_name, phylum and no_genomes, or None when none of
      the three queries produced any rows
    """
    species_id = species_row['gtdb_species_clade_id']
    species_name = species_row['GTDB_species']

    # (query keyword, label stored in the 'gene_class' column)
    class_labels = (
        ('core', 'Core'),
        ('auxiliary', 'Auxiliary'),
        ('singleton', 'Singleton/Novel'),
    )
    frames = []
    for query_class, label in class_labels:
        frame = get_cog_distribution(species_id, query_class)
        frame['gene_class'] = label
        frames.append(frame)

    combined = pd.concat(frames, ignore_index=True)
    if combined.empty:
        return None

    # Proportion = this category's count divided by the total count of its gene class
    totals = combined.groupby('gene_class')['gene_count'].sum()
    combined['proportion'] = [
        count / totals[label] if label in totals.index else 0
        for label, count in zip(combined['gene_class'], combined['gene_count'])
    ]

    # Attach species metadata to every row
    return combined.assign(
        species_id=species_id,
        species_name=species_name,
        phylum=species_row['phylum'],
        no_genomes=species_row['no_genomes'],
    )
# Smoke test: run the full per-species analysis on the first sampled species
print("Testing query function on first species...")
test_result = analyze_species_cog(sampled_species.iloc[0])

if test_result is None:
    print("Warning: Test query returned no results")
else:
    print(f"Success! Retrieved {len(test_result)} COG category records")
    print(test_result.head())
Step 3: Run analysis on all species
Note: This may take 10-30 minutes depending on database load. A status line is printed for each species as it is processed.
In [ ]:
# Run analysis on all species
all_results = []      # one combined DataFrame per species that returned data
failed_species = []   # names of species for which no data came back
total_species = len(sampled_species)
print(f"Analyzing {total_species} species...\n")

# Count by position rather than by DataFrame index label, so the progress
# counter runs 1..N even if `sampled_species` was filtered or re-indexed.
for position, (_, row) in enumerate(sampled_species.iterrows(), start=1):
    species_name = row['GTDB_species']
    print(f"[{position}/{total_species}] Processing {species_name}...", end=' ')
    result = analyze_species_cog(row)
    if result is not None and len(result) > 0:
        all_results.append(result)
        print(f"✓ ({len(result)} records)")
    else:
        failed_species.append(species_name)
        print("✗ (no data)")

print(f"\n{'='*80}")
print(f"Analysis complete!")
print(f" Successful: {len(all_results)} species")
print(f" Failed: {len(failed_species)} species")
if failed_species:
    # Only the first five names are listed to keep the output short
    print(f" Failed species: {', '.join(failed_species[:5])}{'...' if len(failed_species) > 5 else ''}")
In [ ]:
# Combine all results
if all_results:
    all_cog_data = pd.concat(all_results, ignore_index=True)

    # Save raw results
    all_cog_data.to_csv('../data/multi_species_cog_results.csv', index=False)
    print(f"Saved {len(all_cog_data)} records to ../data/multi_species_cog_results.csv")

    print(f"\nData summary:")
    print(f" Total records: {len(all_cog_data)}")
    print(f" Species: {all_cog_data['species_name'].nunique()}")
    print(f" COG categories: {all_cog_data['COG_category'].nunique()}")
    print(f" Total genes analyzed: {all_cog_data['gene_count'].sum():,.0f}")

    # Jupyter only auto-displays the last top-level expression of a cell; a
    # bare expression inside this `if` would be discarded, so print it.
    print(all_cog_data.head(20))
else:
    # NOTE: `all_cog_data` stays undefined in this case, so later cells cannot run.
    print("ERROR: No results collected!")
Step 4: Calculate enrichment scores across all species
In [ ]:
# Enrichment of a COG category in a species = proportion among singleton/novel
# genes minus proportion among core genes.
enrichment_data = []

# sort=False keeps species in order of first appearance
for species_name, species_rows in all_cog_data.groupby('species_name', sort=False):
    # One row per COG category, one column per gene class; a category that is
    # missing from a class counts as proportion 0.
    by_class = species_rows.pivot_table(
        values='proportion',
        index='COG_category',
        columns='gene_class',
        fill_value=0,
    )

    # Both classes are needed to form the difference; otherwise skip the species
    if 'Core' not in by_class.columns or 'Singleton/Novel' not in by_class.columns:
        continue

    species_phylum = species_rows['phylum'].iloc[0]
    enrichment_data.extend(
        {
            'species_name': species_name,
            'phylum': species_phylum,
            'COG_category': category,
            'enrichment': novel_share - core_share,
            'core_prop': core_share,
            'novel_prop': novel_share,
        }
        for category, core_share, novel_share in zip(
            by_class.index, by_class['Core'], by_class['Singleton/Novel']
        )
    )

enrichment_df = pd.DataFrame(enrichment_data)
print(f"Calculated enrichment scores for {len(enrichment_df)} species × COG combinations")
enrichment_df.head(20)
Step 5: Identify conserved patterns
Which COG categories are consistently enriched/depleted in novel genes across species?
In [ ]:
# Aggregate enrichment across species
cog_summary = enrichment_df.groupby('COG_category').agg({
    'enrichment': ['mean', 'std', 'median'],
    'species_name': 'count'
}).round(4)
cog_summary.columns = ['mean_enrichment', 'std_enrichment', 'median_enrichment', 'n_species']
cog_summary = cog_summary.reset_index()
# Most enriched first; later cells rely on this order when they use head()/tail()
cog_summary = cog_summary.sort_values('mean_enrichment', ascending=False)


def _describe_cog_category(cog_cat):
    """Return a description for `cog_cat`, composing one for multi-letter codes.

    Codes listed in COG_DESCRIPTIONS use that text. Any other code (for
    example a composite such as 'EGP') is described by joining the
    descriptions of its letters, so the column never holds NaN and later
    string slicing cannot fail.
    """
    if cog_cat in COG_DESCRIPTIONS:
        return COG_DESCRIPTIONS[cog_cat]
    parts = [COG_DESCRIPTIONS.get(letter, 'Unknown') for letter in str(cog_cat)]
    return ' + '.join(parts) if parts else 'Unknown'


# Add descriptions
cog_summary['description'] = cog_summary['COG_category'].map(_describe_cog_category)


# Calculate consistency (% of species where enrichment has same sign as mean)
def calculate_consistency(cog_cat):
    """Return the percentage of species whose enrichment for `cog_cat` has the
    same sign as the cross-species mean (a mean of 0 or below counts the
    negative values)."""
    values = enrichment_df.loc[enrichment_df['COG_category'] == cog_cat, 'enrichment']
    if values.mean() > 0:
        consistent = (values > 0).sum()
    else:
        consistent = (values < 0).sum()
    return consistent / len(values) * 100


cog_summary['consistency_pct'] = cog_summary['COG_category'].apply(calculate_consistency)

report_columns = ['COG_category', 'description', 'mean_enrichment', 'consistency_pct', 'n_species']
print("\n" + "="*80)
print("COG ENRICHMENT SUMMARY (Novel vs Core genes)")
print("="*80)
print("\nTop 10 ENRICHED in novel genes:")
print(cog_summary.head(10)[report_columns].to_string(index=False))
print("\nTop 10 DEPLETED in novel genes:")
print(cog_summary.tail(10)[report_columns].to_string(index=False))

# Save summary
cog_summary.to_csv('../data/cog_enrichment_summary.csv', index=False)
print("\nSaved summary to ../data/cog_enrichment_summary.csv")
Step 6: Visualizations
In [ ]:
# Plot 1: Heatmap of enrichment across species
# Rows are COG categories, columns are species; a combination with no record is drawn as 0.
pivot_enrichment = enrichment_df.pivot_table(
    values='enrichment',
    index='COG_category',
    columns='species_name',
    fill_value=0,
)

# Order rows from highest to lowest mean enrichment, keeping only categories present in the pivot
row_order = cog_summary.sort_values('mean_enrichment', ascending=False)['COG_category']
rows_present = row_order[row_order.isin(pivot_enrichment.index)]
pivot_enrichment = pivot_enrichment.loc[list(rows_present)]

fig, ax = plt.subplots(figsize=(18, 12))
heatmap_options = {
    'cmap': 'RdBu_r',
    'center': 0,  # diverging colour map centred on "no difference"
    'cbar_kws': {'label': 'Enrichment (Novel - Core)'},
    'xticklabels': False,  # species tick labels are switched off
    'yticklabels': True,
}
sns.heatmap(pivot_enrichment, ax=ax, **heatmap_options)

ax.set_title('COG Category Enrichment in Novel Genes Across Species\n(Red = enriched in novel, Blue = enriched in core)', fontsize=14)
ax.set_xlabel('Species', fontsize=12)
ax.set_ylabel('COG Category', fontsize=12)

plt.tight_layout()
plt.savefig('../data/multi_species_enrichment_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()
print("Heatmap saved to ../data/multi_species_enrichment_heatmap.png")
In [ ]:
# Plot 2: Distribution of enrichment scores for key COG categories
key_cogs = ['L', 'M', 'J', 'C', 'H', 'V', 'E', 'S']
cog_order = sorted(key_cogs)  # alphabetical order for both the boxes and their tick labels
key_data = enrichment_df.loc[enrichment_df['COG_category'].isin(key_cogs)]

fig, ax = plt.subplots(figsize=(14, 8))
sns.boxplot(
    x='COG_category',
    y='enrichment',
    data=key_data,
    order=cog_order,
    palette='Set2',
    ax=ax,
)
# Reference line at zero: above it means enriched in novel genes
ax.axhline(y=0, color='black', linestyle='--', linewidth=1)
ax.set_title('Distribution of COG Enrichment Across Species\n(Positive = enriched in novel genes)', fontsize=14)
ax.set_xlabel('COG Category', fontsize=12)
ax.set_ylabel('Enrichment (Novel - Core)', fontsize=12)

# Tick labels: category code plus the first 20 characters of its description
labels = []
for cog in cog_order:
    short_description = COG_DESCRIPTIONS.get(cog, '')[:20]
    labels.append(f"{cog}\n{short_description}...")
ax.set_xticklabels(labels, fontsize=10)

plt.tight_layout()
plt.savefig('../data/cog_enrichment_distribution.png', dpi=300, bbox_inches='tight')
plt.show()
print("Distribution plot saved to ../data/cog_enrichment_distribution.png")
In [ ]:
# Plot 3: Mean enrichment with error bars
top_n = 15
top_enriched = cog_summary.nlargest(top_n, 'mean_enrichment')
top_depleted = cog_summary.nsmallest(top_n, 'mean_enrichment')
# With fewer than 2 * top_n categories the two selections overlap; keep each
# category once so no bar is drawn twice.
plot_data = pd.concat([top_enriched, top_depleted]).drop_duplicates(subset='COG_category')

fig, ax = plt.subplots(figsize=(12, 10))
# Coral for categories enriched in novel genes, sky blue for depleted ones
colors = ['coral' if x > 0 else 'skyblue' for x in plot_data['mean_enrichment']]
ax.barh(
    range(len(plot_data)),
    plot_data['mean_enrichment'],
    xerr=plot_data['std_enrichment'],
    color=colors,
    alpha=0.7,
    error_kw={'linewidth': 1, 'ecolor': 'gray'}
)

# Labels with descriptions. A category without an entry in COG_DESCRIPTIONS
# (e.g. a composite code) may carry NaN here, which cannot be sliced, so fall
# back to a text placeholder.
labels = []
for category, description in zip(plot_data['COG_category'], plot_data['description']):
    if not isinstance(description, str):
        description = COG_DESCRIPTIONS.get(category, 'Unknown')
    labels.append(f"{category}: {description[:40]}")
ax.set_yticks(range(len(plot_data)))
ax.set_yticklabels(labels, fontsize=9)
ax.axvline(x=0, color='black', linestyle='-', linewidth=1)
ax.set_xlabel('Mean Enrichment ± SD (Novel - Core)', fontsize=12)
ax.set_title(f'Top {top_n} Enriched and Depleted COG Categories in Novel Genes\n(Across {len(all_cog_data["species_name"].unique())} species)', fontsize=14)
ax.grid(axis='x', alpha=0.3)
plt.tight_layout()
plt.savefig('../data/cog_mean_enrichment.png', dpi=300, bbox_inches='tight')
plt.show()
print("Mean enrichment plot saved to ../data/cog_mean_enrichment.png")
In [ ]:
# Plot 4: Phylum-specific patterns
# Mean enrichment of every COG category within each phylum
phylum_enrichment = enrichment_df.groupby(['phylum', 'COG_category'])['enrichment'].mean().reset_index()
pivot_phylum = phylum_enrichment.pivot_table(
    index='COG_category',
    columns='phylum',
    values='enrichment',
    fill_value=0  # a category with no record in a phylum is shown as 0.00
)

# Keep the categories with the highest overall mean enrichment, most enriched
# first. `top_k` also feeds the title so the two cannot drift apart.
top_k = 20
row_order = cog_summary.sort_values('mean_enrichment', ascending=False)['COG_category'].head(top_k)
pivot_phylum = pivot_phylum.loc[[cat for cat in row_order if cat in pivot_phylum.index]]

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(
    pivot_phylum,
    cmap='RdBu_r',
    center=0,
    annot=True,
    fmt='.2f',
    cbar_kws={'label': 'Mean Enrichment'},
    ax=ax
)
ax.set_xlabel('Phylum', fontsize=12)
ax.set_ylabel('COG Category', fontsize=12)
# The rows are selected by mean enrichment, not by variability, and the title says so.
ax.set_title(f'COG Enrichment Patterns by Phylum\n(Top {top_k} categories by mean enrichment in novel genes)', fontsize=14)
plt.tight_layout()
plt.savefig('../data/cog_enrichment_by_phylum.png', dpi=300, bbox_inches='tight')
plt.show()
print("Phylum-specific plot saved to ../data/cog_enrichment_by_phylum.png")
Step 7: Statistical analysis
In [ ]:
# Test if observed patterns match expectations from N. gonorrhoeae
print("\n" + "="*80)
print("HYPOTHESIS TESTING")
print("="*80)
print("\nBased on N. gonorrhoeae, we expected:")
print(f" Enriched in novel: {', '.join(EXPECTED_NOVEL_ENRICHED)}")
print(f" Depleted in novel: {', '.join(EXPECTED_NOVEL_DEPLETED)}")

# Check how many of our expectations are confirmed.
# cog_summary is sorted by mean enrichment (descending), so head() holds the
# most enriched categories and tail() the most depleted ones.
top_10_enriched = set(cog_summary.head(10)['COG_category'])
top_10_depleted = set(cog_summary.tail(10)['COG_category'])
enriched_confirmed = [cog for cog in EXPECTED_NOVEL_ENRICHED if cog in top_10_enriched]
depleted_confirmed = [cog for cog in EXPECTED_NOVEL_DEPLETED if cog in top_10_depleted]

print("\nResults across all species:")
print(f" Enriched expectations confirmed: {len(enriched_confirmed)}/{len(EXPECTED_NOVEL_ENRICHED)} ({', '.join(enriched_confirmed)})")
print(f" Depleted expectations confirmed: {len(depleted_confirmed)}/{len(EXPECTED_NOVEL_DEPLETED)} ({', '.join(depleted_confirmed)})")

# Calculate consistency for expected categories. Both the expected-enriched
# and the expected-depleted lists are reported; the sign of the mean shows
# the direction.
print("\nConsistency of expected patterns:")
for cog in EXPECTED_NOVEL_ENRICHED + EXPECTED_NOVEL_DEPLETED:
    matches = cog_summary[cog_summary['COG_category'] == cog]
    if matches.empty:
        # Category was never observed in the data; nothing to report
        continue
    row = matches.iloc[0]
    print(f" {cog} ({COG_DESCRIPTIONS.get(cog, 'Unknown')}):")
    print(f" Mean enrichment: {row['mean_enrichment']:+.3f}")
    print(f" Consistency: {row['consistency_pct']:.1f}% of species")
In [ ]:
# Per-phylum mean enrichment for four categories of interest
banner = "=" * 80
print("\n" + banner)
print("PHYLUM-SPECIFIC PATTERNS")
print(banner)

for cog in ('L', 'M', 'J', 'C'):
    is_cog = enrichment_df['COG_category'] == cog
    cog_data = enrichment_df[is_cog]
    per_phylum = cog_data.groupby('phylum')['enrichment'].mean()
    # Highest mean enrichment first
    phylum_means = per_phylum.sort_values(ascending=False)

    description = COG_DESCRIPTIONS.get(cog, 'Unknown')
    print(f"\n{cog} ({description}):")
    for phylum, mean_enrich in phylum_means.items():
        print(f" {phylum:25s}: {mean_enrich:+.3f}")
Step 8: Summary and conclusions
In [ ]:
# Final report: dataset size, top findings and the files written by earlier cells
print("\n" + "="*80)
print("ANALYSIS SUMMARY")
print("="*80)
print(f"\nDataset:")
print(f" Species analyzed: {len(all_cog_data['species_name'].unique())}")
print(f" Phyla represented: {len(all_cog_data['phylum'].unique())}")
print(f" Total genes analyzed: {all_cog_data['gene_count'].sum():,.0f}")
print(f" COG categories found: {all_cog_data['COG_category'].nunique()}")

print(f"\nKey findings:")
# cog_summary is sorted by mean enrichment (descending): head() = most
# enriched, tail() = most depleted.
findings = (
    (" 1. Most consistently enriched in novel genes:", cog_summary.head(5)),
    ("\n 2. Most consistently depleted in novel genes:", cog_summary.tail(5)),
)
for heading, rows in findings:
    print(heading)
    for _, row in rows.iterrows():
        description = row['description']
        # A category without an entry in COG_DESCRIPTIONS (e.g. a composite
        # code) may carry NaN, which cannot be sliced; use a text fallback.
        if not isinstance(description, str):
            description = COG_DESCRIPTIONS.get(row['COG_category'], 'Unknown')
        print(f" - {row['COG_category']}: {description[:50]} ({row['consistency_pct']:.0f}% consistent)")

print(f"\n 3. Patterns from N. gonorrhoeae:")
# Counted as confirmed when at least 3 expected categories on each side were found
if len(enriched_confirmed) >= 3 and len(depleted_confirmed) >= 3:
    print(f" ✓ CONFIRMED across species")
else:
    print(f" ✗ NOT fully replicated across species")

print(f"\nGenerated files:")
print(f" - ../data/multi_species_cog_results.csv")
print(f" - ../data/cog_enrichment_summary.csv")
print(f" - ../data/multi_species_enrichment_heatmap.png")
print(f" - ../data/cog_enrichment_distribution.png")
print(f" - ../data/cog_mean_enrichment.png")
print(f" - ../data/cog_enrichment_by_phylum.png")