import json
import time
import gzip
import io
import zipfile
import requests
import numpy as np
import pandas as pd
from typing import Dict, List, Any, Optional, Union, Tuple
from pathlib import Path
import logging
from urllib.parse import urljoin
import h5py
import anndata as ad
from malva_client.exceptions import MalvaAPIError, AuthenticationError, SearchError, QuotaExceededError
logger = logging.getLogger(__name__)
def _expr_data_from_columnar(expr_col: dict) -> dict:
"""Expand columnar (_fmt: col) expression_data to row-major format.
The server encodes aggregated expression data in a compact columnar layout
(si, ci, v0-v4 as parallel arrays) to minimise JSON payload. This function
reconstructs the row-major 'data' list expected by _convert_to_dataframe.
"""
si = expr_col.get('si', [])
n = len(si)
if n == 0:
return {
'data': [],
'samples': expr_col.get('samples', []),
'cell_types': expr_col.get('cell_types', []),
'columns': expr_col.get('columns', []),
'celltype_sample_counts': expr_col.get('celltype_sample_counts', {}),
}
ci = expr_col.get('ci', [])
v0 = expr_col.get('v0', [])
v1 = expr_col.get('v1', [])
v2 = expr_col.get('v2', [])
v3 = expr_col.get('v3', [])
v4 = expr_col.get('v4', [])
return {
'data': [[si[j], ci[j], v0[j], v1[j], v2[j], v3[j], v4[j]] for j in range(n)],
'samples': expr_col.get('samples', []),
'cell_types': expr_col.get('cell_types', []),
'columns': expr_col.get('columns', []),
'celltype_sample_counts': expr_col.get('celltype_sample_counts', {}),
}
[docs]
class MalvaDataFrame:
"""
Wrapper around pandas DataFrame with sample metadata enrichment and analysis methods
"""
# Sample-metadata field name -> simplified column name.
# _enrich_with_metadata adds every mapped field under the simplified name in
# addition to its original name; the lookup methods invert this mapping so a
# simplified name can be resolved back to the original field.
FIELD_MAPPING = {
    'specimen_from_organism.organ': 'organ',
    'specimen_from_organism.organ_part': 'organ_part',
    'specimen_from_organism.diseases': 'disease',
    'project.project_core.project_short_name': 'study',
    'project.project_core.project_title': 'study_title',
    'project.contributors.laboratory': 'laboratory',
    'genus_species': 'species',
    'development_stage': 'development_stage',
    'protocol_harmonized': 'protocol',
    'sex': 'sex',
    'age': 'age',
    'diseases': 'disease_alt',
    'uuid': 'sample_uuid'
}
[docs]
def __init__(self, expr_df: pd.DataFrame, client: 'MalvaClient', sample_metadata: Dict[int, Dict] = None):
self.client = client
self._df = expr_df.copy()
self._sample_metadata = sample_metadata or {}
self._enrich_with_metadata()
def _enrich_with_metadata(self):
"""Enrich the dataframe with sample metadata using simplified field names"""
if not self._sample_metadata:
return
# Create metadata DataFrame with renamed columns
metadata_records = []
for sample_id, metadata in self._sample_metadata.items():
record = {'sample_id': sample_id}
# Add metadata with both original and mapped field names
for original_field, value in metadata.items():
# Add original field
record[original_field] = value
# Add mapped field if mapping exists
if original_field in self.FIELD_MAPPING:
mapped_field = self.FIELD_MAPPING[original_field]
record[mapped_field] = value
metadata_records.append(record)
if metadata_records:
metadata_df = pd.DataFrame(metadata_records)
# Merge with expression data
self._df = self._df.merge(
metadata_df,
on='sample_id',
how='left',
suffixes=('', '_metadata')
)
self._convert_to_categorical()
self._extract_celltype_sample_counts()
def _extract_celltype_sample_counts(self):
"""Extract celltype_sample_counts from the raw data if available"""
try:
if hasattr(self, 'raw_data') and self.raw_data:
results = self.raw_data.get('results', {})
for gene_key, gene_data in results.items():
if isinstance(gene_data, dict) and 'expression_data' in gene_data:
expr_data = gene_data['expression_data']
if 'celltype_sample_counts' in expr_data:
self._celltype_sample_counts = expr_data['celltype_sample_counts']
print(f"✓ Extracted sample counts per cell type: {len(self._celltype_sample_counts)} cell types")
return
# If we can't find it, we'll have to live with the fallback method
print("ℹ️ Could not extract true sample counts per cell type")
except Exception as e:
print(f"⚠️ Error extracting sample counts: {e}")
def _convert_to_categorical(self):
"""Convert string/object columns to categorical for faster filtering"""
categorical_candidates = [
'organ', 'organ_part', 'disease', 'species', 'development_stage',
'sex', 'study', 'study_title', 'laboratory', 'protocol', 'cell_type'
]
for col in categorical_candidates:
if col in self._df.columns and self._df[col].dtype == 'object':
# Only convert if the column has a reasonable number of unique values
# (avoid converting columns with mostly unique values)
unique_ratio = self._df[col].nunique() / len(self._df)
if unique_ratio < 0.5: # Less than 50% unique values
self._df[col] = self._df[col].astype('category')
@property
def df(self) -> pd.DataFrame:
"""Access the underlying pandas DataFrame"""
return self._df
def __getattr__(self, name):
"""Delegate pandas DataFrame methods"""
return getattr(self._df, name)
def __getitem__(self, key):
"""Delegate pandas DataFrame indexing"""
return self._df[key]
def __len__(self):
return len(self._df)
[docs]
def filter_by(self, **kwargs) -> 'MalvaDataFrame':
"""
Filter data by any combination of metadata fields (optimized for large datasets)
Uses simplified field names (e.g., 'organ' instead of 'specimen_from_organism.organ')
Args:
**kwargs: Field=value pairs for filtering
Returns:
New MalvaDataFrame with filtered data
Example:
df.filter_by(organ='brain', disease='normal', cell_type='neuron')
df.filter_by(species='Homo sapiens', study='BrainDepressiveDisorder')
"""
if not kwargs:
return MalvaDataFrame(self._df.copy(), self.client, self._sample_metadata)
# Start with all rows as True
mask = pd.Series(True, index=self._df.index)
# Create reverse mapping for convenience
reverse_mapping = {v: k for k, v in self.FIELD_MAPPING.items()}
for field, value in kwargs.items():
# Check if field exists directly
actual_field = field
if field not in self._df.columns:
# Try mapped field
if field in reverse_mapping and reverse_mapping[field] in self._df.columns:
actual_field = reverse_mapping[field]
else:
logging.warning(f"Field '{field}' not found. Available fields: {self.available_filter_fields()}")
continue
# Apply filter using boolean indexing (much faster)
if isinstance(value, (list, tuple)):
# For categorical columns, this is very fast
field_mask = self._df[actual_field].isin(value)
else:
# For categorical columns, this is also very fast
field_mask = self._df[actual_field] == value
# Combine with existing mask using AND operation
mask = mask & field_mask
# Apply the final mask in one operation
filtered_df = self._df[mask].copy()
return MalvaDataFrame(filtered_df, self.client, self._sample_metadata)
# Legacy expression column name -> canonical column name.
# _resolve_expression_column consults this table when the requested column
# is not present in the frame under the name the caller used.
EXPRESSION_COLUMN_ALIASES = {
    'norm_expr': 'rel',
    'kpt_expr': 'exp',
    'raw_expr': 'exp',
    'fraction_positive': 'pct',
    'pct_positive': 'pct',
    'raw_kmer_mean': 'raw_kmers',
}
def _resolve_expression_column(self, expr_column: str) -> str:
"""Resolve legacy expression column names to the canonical schema."""
if expr_column in self._df.columns:
return expr_column
alias = self.EXPRESSION_COLUMN_ALIASES.get(expr_column)
if alias and alias in self._df.columns:
return alias
raise ValueError(f"Expression column '{expr_column}' not found")
[docs]
def aggregate_by(self, group_by: Union[str, List[str]],
agg_func: str = 'mean',
expr_column: str = 'rel') -> pd.DataFrame:
"""
Aggregate expression data by specified grouping variables
Args:
group_by: Column name(s) to group by
agg_func: Aggregation function ('mean', 'median', 'sum', 'count', 'std')
expr_column: Expression column to aggregate
Returns:
DataFrame with aggregated results
Example:
df.aggregate_by('cell_type')
df.aggregate_by(['organ', 'cell_type'])
"""
expr_column = self._resolve_expression_column(expr_column)
group_cols = [group_by] if isinstance(group_by, str) else group_by
# Resolve field names
resolved_group_cols = []
reverse_mapping = {v: k for k, v in self.FIELD_MAPPING.items()}
for col in group_cols:
if col in self._df.columns:
resolved_group_cols.append(col)
elif col in reverse_mapping and reverse_mapping[col] in self._df.columns:
resolved_group_cols.append(reverse_mapping[col])
else:
raise ValueError(f"Grouping column '{col}' not found")
# Simple aggregation approach
grouped = self._df.groupby(resolved_group_cols)
# Create result DataFrame step by step
if agg_func == 'mean':
expr_agg = grouped[expr_column].mean()
elif agg_func == 'median':
expr_agg = grouped[expr_column].median()
elif agg_func == 'sum':
expr_agg = grouped[expr_column].sum()
elif agg_func == 'count':
expr_agg = grouped[expr_column].count()
elif agg_func == 'std':
expr_agg = grouped[expr_column].std()
elif agg_func == 'min':
expr_agg = grouped[expr_column].min()
elif agg_func == 'max':
expr_agg = grouped[expr_column].max()
else:
raise ValueError(f"Unsupported aggregation function: {agg_func}")
# Build result DataFrame
result = pd.DataFrame({
f"{agg_func}_{expr_column}": expr_agg,
'n_observations': grouped.size(),
})
# Add additional metrics if columns exist
if 'cell_count' in self._df.columns:
result['total_cells'] = grouped['cell_count'].sum()
else:
result['total_cells'] = grouped.size() # Use observation count as proxy
if 'sample_id' in self._df.columns:
result['n_samples'] = grouped['sample_id'].nunique()
else:
result['n_samples'] = 1 # Default
# Reset index and sort
result = result.reset_index()
result = result.sort_values(f"{agg_func}_{expr_column}", ascending=False)
return result.round(6)
[docs]
def plot_expression_by(self, group_by: str, limit: int = None, sort_by: str = 'mean',
ascending: bool = False, **kwargs):
"""
Plot expression levels grouped by a metadata field
Args:
group_by: Column to group by for plotting
limit: Maximum number of categories to show (shows top N)
sort_by: How to sort categories ('mean', 'median', 'count', 'alphabetical')
ascending: Whether to sort in ascending order (False shows highest first)
**kwargs: Additional arguments passed to matplotlib
"""
try:
import matplotlib.pyplot as plt
import seaborn as sns
except ImportError:
raise ImportError("matplotlib and seaborn required for plotting. Install with: pip install matplotlib seaborn")
if group_by not in self._df.columns:
raise ValueError(f"Column '{group_by}' not found")
expr_column = self._resolve_expression_column('rel')
# Prepare data for plotting
plot_data = self._df.copy()
# Calculate sorting metrics for each category
if sort_by != 'alphabetical':
if sort_by == 'mean':
sort_values = plot_data.groupby(group_by)[expr_column].mean()
elif sort_by == 'median':
sort_values = plot_data.groupby(group_by)[expr_column].median()
elif sort_by == 'count':
sort_values = plot_data.groupby(group_by)[expr_column].count()
else:
raise ValueError("sort_by must be one of: 'mean', 'median', 'count', 'alphabetical'")
# Sort categories
sorted_categories = sort_values.sort_values(ascending=ascending).index.tolist()
else:
# Alphabetical sorting
sorted_categories = sorted(plot_data[group_by].dropna().unique())
if not ascending:
sorted_categories = sorted_categories[::-1]
# Apply limit if specified
if limit is not None and limit > 0:
sorted_categories = sorted_categories[:limit]
plot_data = plot_data[plot_data[group_by].isin(sorted_categories)]
# Create the plot
plt.figure(figsize=kwargs.get('figsize', (12, 6)))
# Remove kwargs that shouldn't go to seaborn
plot_kwargs = {k: v for k, v in kwargs.items() if k not in ['figsize']}
# Create box plot with ordered categories
sns.boxplot(data=plot_data, x=group_by, y=expr_column,
order=sorted_categories, **plot_kwargs)
# Customize plot
plt.xticks(rotation=45, ha='right')
plt.xlabel(group_by.replace('_', ' ').title())
plt.ylabel('Normalized Expression')
# Create informative title
title_parts = [f'Expression Distribution by {group_by.replace("_", " ").title()}']
if limit:
title_parts.append(f'(Top {limit}')
if sort_by != 'alphabetical':
title_parts.append(f'by {sort_by})')
else:
title_parts.append('alphabetically)')
title = ' '.join(title_parts)
plt.title(title)
plt.tight_layout()
# Add summary statistics as text if there are few categories
if len(sorted_categories) <= 10:
stats_text = []
for i, category in enumerate(sorted_categories):
cat_data = plot_data[plot_data[group_by] == category][expr_column]
if not cat_data.empty:
mean_val = cat_data.mean()
count = len(cat_data)
stats_text.append(f'{category}: μ={mean_val:.3f} (n={count})')
# Add text box with statistics
if stats_text and len(stats_text) <= 8: # Only if not too many
textstr = '\n'.join(stats_text)
props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
plt.gca().text(0.02, 0.98, textstr, transform=plt.gca().transAxes,
fontsize=8, verticalalignment='top', bbox=props)
return plt.gcf()
[docs]
def plot_expression_summary(self, group_by: str, limit: int = 10,
plot_type: str = 'box', **kwargs):
try:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
except ImportError:
raise ImportError("matplotlib and seaborn required for plotting")
if group_by not in self._df.columns:
raise ValueError(f"Column '{group_by}' not found")
# Get top categories by mean expression
expr_column = self._resolve_expression_column('rel')
top_categories = (self._df.groupby(group_by)[expr_column]
.mean().nlargest(limit).index.tolist())
plot_data = self._df[self._df[group_by].isin(top_categories)]
# Create 1x3 subplot layout
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
fig.suptitle(f'Expression Analysis by {group_by.replace("_", " ").title()}',
fontsize=16)
# 1. Box plot (or other plot type)
if plot_type == 'box':
sns.boxplot(data=plot_data, x=group_by, y=expr_column,
order=top_categories, ax=axes[0])
axes[0].set_title('Expression Distribution (Box Plot)')
elif plot_type == 'violin':
sns.violinplot(data=plot_data, x=group_by, y=expr_column,
order=top_categories, ax=axes[0])
axes[0].set_title('Expression Distribution (Violin Plot)')
elif plot_type == 'strip':
sns.stripplot(data=plot_data, x=group_by, y=expr_column,
order=top_categories, ax=axes[0])
axes[0].set_title('Expression Distribution (Strip Plot)')
else: # bar plot
means = plot_data.groupby(group_by)[expr_column].agg(['mean', 'std'])
means = means.reindex(top_categories)
means.plot(kind='bar', y='mean', yerr='std', ax=axes[0],
color='skyblue', capsize=4)
axes[0].set_title('Mean Expression ± SD')
axes[0].legend().remove()
axes[0].tick_params(axis='x', rotation=45)
axes[0].set_xlabel(group_by.replace('_', ' ').title())
axes[0].set_ylabel('Normalized Expression')
# 2. Number of samples per category - FIXED TO USE SERVER DATA
sample_counts_per_category = self._get_true_sample_counts(group_by, top_categories)
sample_counts = pd.Series(sample_counts_per_category)
sample_counts.plot(kind='bar', ax=axes[1], color='lightcoral')
axes[1].set_title('Number of Samples')
axes[1].set_xlabel(group_by.replace('_', ' ').title())
axes[1].set_ylabel('Sample Count')
axes[1].tick_params(axis='x', rotation=45)
# Add value labels on bars
for i, v in enumerate(sample_counts.values):
axes[1].text(i, v + 0.01 * max(sample_counts.values), str(v),
ha='center', va='bottom', fontsize=9)
# 3. Total number of cells per category
if 'cell_count' in plot_data.columns:
cell_counts = plot_data.groupby(group_by)['cell_count'].sum().reindex(top_categories)
ylabel = 'Total Cells'
title = 'Number of Cells'
else:
cell_counts = plot_data.groupby(group_by).size().reindex(top_categories)
ylabel = 'Observations'
title = 'Number of Observations'
cell_counts.plot(kind='bar', ax=axes[2], color='lightgreen')
axes[2].set_title(title)
axes[2].set_xlabel(group_by.replace('_', ' ').title())
axes[2].set_ylabel(ylabel)
axes[2].tick_params(axis='x', rotation=45)
# Add value labels on bars (format large numbers)
for i, v in enumerate(cell_counts.values):
if v >= 1000:
label = f'{v/1000:.1f}K' if v < 1000000 else f'{v/1000000:.1f}M'
else:
label = str(int(v))
axes[2].text(i, v + 0.01 * max(cell_counts.values), label,
ha='center', va='bottom', fontsize=9)
# Format y-axis for large numbers
if max(cell_counts.values) >= 1000:
axes[2].yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x/1000:.0f}K' if x >= 1000 else f'{x:.0f}'))
plt.tight_layout()
# Print summary statistics - FIXED WITH CORRECT SAMPLE COUNTS
print(f"\n📊 Summary for {group_by.replace('_', ' ').title()}:")
print("-" * 50)
for category in top_categories:
cat_data = plot_data[plot_data[group_by] == category]
n_samples = sample_counts_per_category.get(category, cat_data['sample_id'].nunique())
n_cells = cat_data['cell_count'].sum() if 'cell_count' in cat_data.columns else len(cat_data)
mean_expr = cat_data[expr_column].mean()
print(f"{category}: {n_samples} samples, {n_cells:,} cells, μ={mean_expr:.3f}")
return fig
def _get_true_sample_counts(self, group_by: str, categories: List[str]) -> Dict[str, int]:
"""Extract true sample counts from server-provided data"""
# Method 1: Check if we have celltype_sample_counts in the data structure
if hasattr(self, '_sample_metadata') and group_by == 'cell_type':
# Look for the celltype_sample_counts in the original search results
# This is a bit hacky but necessary given the current data structure
# Try to find expression_data with celltype_sample_counts
sample_counts = {}
# Check if any columns contain serialized data with our counts
for col in self._df.columns:
if 'expression_data' in str(col).lower():
continue # Skip for now, this is complex to extract
# Alternative: Check if we stored this somewhere during enrichment
if hasattr(self, '_celltype_sample_counts'):
for category in categories:
sample_counts[category] = self._celltype_sample_counts.get(category, 0)
return sample_counts
# Method 2: Fallback - count unique samples per category (this gives wrong results but is better than nothing)
sample_counts = {}
for category in categories:
cat_data = self._df[self._df[group_by] == category]
sample_counts[category] = cat_data['sample_id'].nunique()
return sample_counts
[docs]
def to_pandas(self) -> pd.DataFrame:
"""Convert to regular pandas DataFrame"""
return self._df.copy()
[docs]
def available_fields(self) -> Dict[str, List[str]]:
"""
Get categorized list of available fields for filtering/grouping
Returns:
Dictionary with categorized field names
"""
all_columns = list(self._df.columns)
# Categorize fields
categories = {
'expression': [col for col in all_columns if col in ['rel', 'exp', 'pct', 'raw_kmers', 'cell_count', 'sample_idx', 'cell_type_idx']],
'basic': [col for col in all_columns if col in ['sample_id', 'cell_type', 'gene_sequence']],
'biological': [col for col in all_columns if col in ['organ', 'organ_part', 'disease', 'species', 'development_stage', 'sex', 'age']],
'study': [col for col in all_columns if col in ['study', 'study_title', 'laboratory', 'protocol']],
'technical': [col for col in all_columns if col in ['chunk_id', 'internal_sample_id', 'num_cells', 'num_data', 'num_kmers', 'sample_uuid']],
'original_fields': [col for col in all_columns if '.' in col] # Keep original nested field names
}
# Remove empty categories
return {k: v for k, v in categories.items() if v}
[docs]
def available_filter_fields(self) -> List[str]:
"""Get simplified list of commonly used fields for filtering"""
common_fields = [
'organ', 'disease', 'species', 'development_stage', 'sex', 'age',
'study', 'laboratory', 'protocol', 'cell_type'
]
return [field for field in common_fields if field in self._df.columns]
[docs]
def field_info(self) -> pd.DataFrame:
"""
Get detailed information about available fields
Returns:
DataFrame with field names, types, unique values count, and examples
"""
info_data = []
for col in self._df.columns:
unique_vals = self._df[col].dropna().unique()
info_data.append({
'field_name': col,
'data_type': str(self._df[col].dtype),
'non_null_count': self._df[col].count(),
'unique_values': len(unique_vals),
'example_values': ', '.join(map(str, unique_vals[:3])) if len(unique_vals) > 0 else 'No data',
'category': self._categorize_field(col)
})
return pd.DataFrame(info_data).sort_values(['category', 'field_name'])
def _categorize_field(self, field_name: str) -> str:
"""Categorize a field for better organization"""
if field_name in ['rel', 'exp', 'pct', 'raw_kmers', 'cell_count', 'sample_idx', 'cell_type_idx']:
return 'expression'
elif field_name in ['sample_id', 'cell_type', 'gene_sequence']:
return 'basic'
elif field_name in ['organ', 'organ_part', 'disease', 'species', 'development_stage', 'sex', 'age']:
return 'biological'
elif field_name in ['study', 'study_title', 'laboratory', 'protocol']:
return 'study'
elif '.' in field_name:
return 'original_nested'
else:
return 'technical'
[docs]
def unique_values(self, field: str, limit: int = 20) -> List:
"""
Get unique values for a specific field
Args:
field: Field name to get unique values for
limit: Maximum number of values to return
Returns:
Sorted list of unique values
"""
if field not in self._df.columns:
# Try to find field in mapping
reverse_mapping = {v: k for k, v in self.FIELD_MAPPING.items()}
if field in reverse_mapping:
field = reverse_mapping[field]
else:
available = self.available_filter_fields()
raise ValueError(f"Field '{field}' not found. Available fields: {available}")
unique_vals = self._df[field].dropna().unique()
sorted_vals = sorted(unique_vals, key=lambda x: str(x))
if len(sorted_vals) > limit:
return sorted_vals[:limit]
return sorted_vals
[docs]
def value_counts(self, field: str, limit: int = 10) -> pd.Series:
"""Get value counts for a field"""
if field not in self._df.columns:
reverse_mapping = {v: k for k, v in self.FIELD_MAPPING.items()}
if field in reverse_mapping:
field = reverse_mapping[field]
else:
raise ValueError(f"Field '{field}' not found")
return self._df[field].value_counts().head(limit)
[docs]
class SearchResult(MalvaDataFrame):
"""Container for search results that extends MalvaDataFrame for direct analysis.
Data loading is lazy: if raw_data contains no 'results' (e.g. it is the
initial POST response or a lightweight status response), the DataFrame is
populated on first access to .df by fetching
/api/expression-data/<job_id>/results from the server.
"""
[docs]
def __init__(self, raw_data: Dict[str, Any], client: 'MalvaClient'):
self.raw_data = raw_data
self.client = client
self._enriched = False
self._job_id = raw_data.get('job_id') if isinstance(raw_data, dict) else None
# Try to build DataFrame immediately; mark for lazy fetch if no data yet.
expr_df = self._convert_to_dataframe()
self._needs_lazy_fetch = expr_df.empty and bool(self._job_id)
super().__init__(expr_df, client, sample_metadata={})
def _ensure_loaded(self):
    """Lazily fetch expression data from the server if not yet loaded.

    Does nothing unless the instance was flagged for a lazy fetch.  The flag
    is cleared before the request, so the fetch is attempted only once.
    On success ``raw_data``, the DataFrame and the sample metadata are
    replaced with the fetched content.  Failures are logged, not raised, and
    leave the instance with the data it had.
    """
    if not self._needs_lazy_fetch:
        return
    self._needs_lazy_fetch = False

    # Carry the caller's aggregate_expression setting over to the new payload.
    aggregate_expression = (
        self.raw_data.get('aggregate_expression')
        if isinstance(self.raw_data, dict)
        else None
    )
    try:
        logger.info(f"Fetching results for job {self._job_id} ...")
        started = time.time()
        response = self.client._request(
            'GET', f'/api/expression-data/{self._job_id}/results'
        )
        # response may be a dict or a wrapper exposing the dict as _data
        if hasattr(response, '_data'):
            response = response._data
        if not isinstance(response, dict):
            # Keep the current raw_data: status/results expect a mapping.
            logger.error(
                f"Unexpected results payload for job {self._job_id}: "
                f"{type(response).__name__}"
            )
            return

        n_genes = len(response.get('results') or {})
        logger.info(
            f"Results parsed in {time.time() - started:.1f}s: {n_genes} genes"
        )
        if aggregate_expression is not None:
            response.setdefault('aggregate_expression', aggregate_expression)
        response.setdefault('job_id', self._job_id)

        self.raw_data = response
        if response.get('job_id'):
            self._job_id = response.get('job_id')
        self._df = self._convert_to_dataframe()
        self._sample_metadata = response.get('_sample_metadata') or {}
        self._enrich_with_metadata()
    except Exception:
        # Best effort: report with traceback, keep whatever data is present.
        logger.exception(f"Failed to fetch results for job {self._job_id}")
@property
def df(self) -> 'pd.DataFrame':
"""Return the expression DataFrame, fetching lazily if needed."""
self._ensure_loaded()
return self._df
def _convert_to_dataframe(self) -> pd.DataFrame:
"""Convert raw search results to DataFrame - keep sample-level aggregates"""
if not self.raw_data or not isinstance(self.raw_data, dict):
return pd.DataFrame()
results = self.raw_data.get('results', {})
if not results:
return pd.DataFrame()
all_data = []
for gene_seq, result in results.items():
if gene_seq.startswith('_') or not isinstance(result, dict):
continue
# Handle the new compact expression_data format
if 'expression_data' in result:
expression_data = result['expression_data']
# Expand columnar format (_fmt: col) to row-major before reading 'data'.
# The /api/expression-data/<job_id>/results endpoint returns columnar layout.
if isinstance(expression_data, dict) and expression_data.get('_fmt') == 'col':
expression_data = _expr_data_from_columnar(expression_data)
data = expression_data.get('data', [])
columns = expression_data.get('columns', [])
samples = expression_data.get('samples', [])
cell_types = expression_data.get('cell_types', [])
if data and columns:
# Keep sample x cell-type aggregate rows and expose the same
# metric columns used by Expression Explorer display modes.
col_idx = {str(name): idx for idx, name in enumerate(columns)}
def _value(row, name, default_idx=None, default=0.0):
idx = col_idx.get(name, default_idx)
if idx is None or idx >= len(row):
return default
return row[idx]
sample_rows = []
for row in data:
if len(row) >= 5:
sample_idx = int(_value(row, 'sample_idx', 0, 0))
cell_type_idx = int(_value(row, 'cell_type_idx', 1, 0))
norm_expr = float(_value(row, 'norm_expr', 2, 0.0))
kpt_expr = float(_value(row, 'kpt_expr', 3, norm_expr))
cell_count = int(_value(row, 'cell_count', 4, 0))
fraction_positive = float(_value(row, 'fraction_positive', 5, 0.0))
raw_kmer_mean = float(_value(row, 'raw_kmer_mean', col_idx.get('raw_expr', 6), kpt_expr))
sample_id = samples[sample_idx] if sample_idx < len(samples) else f"sample_{sample_idx}"
cell_type = cell_types[cell_type_idx] if cell_type_idx < len(cell_types) else f"celltype_{cell_type_idx}"
sample_rows.append({
'sample_idx': sample_idx,
'cell_type_idx': cell_type_idx,
'sample_id': sample_id,
'cell_type': cell_type,
'gene_sequence': gene_seq,
'rel': norm_expr,
'exp': kpt_expr,
'pct': fraction_positive * 100.0,
'raw_kmers': raw_kmer_mean,
'cell_count': cell_count,
'sample_celltype_id': f"{sample_id}_{cell_type}",
})
if sample_rows:
df = pd.DataFrame(sample_rows)
all_data.append(df)
continue
# FALLBACK: Handle old format for backward compatibility
expression_data = result.get('expression_data', {})
if not expression_data:
# Try old direct format
if 'cell' in result and 'expression' in result and 'sample' in result:
df = pd.DataFrame({
'cell_id': result['cell'],
'rel': result['expression'],
'sample_id': result['sample'],
'gene_sequence': gene_seq
})
all_data.append(df)
continue
# Handle old expression_data format
data = expression_data.get('data', [])
columns = expression_data.get('columns', [])
samples = expression_data.get('samples', [])
cell_types = expression_data.get('cell_types', [])
if not data:
continue
# Create DataFrame from compact data
df = pd.DataFrame(data, columns=columns)
# Replace indices with actual values
if 'sample_idx' in df.columns:
df['sample_id'] = df['sample_idx'].map(lambda x: samples[x] if x < len(samples) else x)
if 'cell_type_idx' in df.columns:
df['cell_type'] = df['cell_type_idx'].map(lambda x: cell_types[x] if x < len(cell_types) else 'Unknown')
df['gene_sequence'] = gene_seq
all_data.append(df)
if all_data:
result_df = pd.concat(all_data, ignore_index=True)
if 'norm_expr' in result_df.columns and 'rel' not in result_df.columns:
result_df['rel'] = result_df['norm_expr']
if 'kpt_expr' in result_df.columns and 'exp' not in result_df.columns:
result_df['exp'] = result_df['kpt_expr']
elif 'raw_expr' in result_df.columns and 'exp' not in result_df.columns:
result_df['exp'] = result_df['raw_expr']
if 'fraction_positive' in result_df.columns and 'pct' not in result_df.columns:
result_df['pct'] = result_df['fraction_positive'].astype(float) * 100.0
elif 'pct_positive' in result_df.columns and 'pct' not in result_df.columns:
result_df['pct'] = result_df['pct_positive']
if 'raw_kmer_mean' in result_df.columns and 'raw_kmers' not in result_df.columns:
result_df['raw_kmers'] = result_df['raw_kmer_mean']
legacy_columns = [
'norm_expr', 'kpt_expr', 'raw_expr', 'fraction_positive',
'pct_positive', 'raw_kmer_mean',
]
result_df = result_df.drop(columns=[c for c in legacy_columns if c in result_df.columns])
if 'cell_type' in result_df.columns:
result_df['cell_type'] = result_df['cell_type'].astype('category')
return result_df
else:
return pd.DataFrame()
@property
def job_id(self) -> str:
"""Get the job ID for this search"""
if isinstance(self.raw_data, dict) and self.raw_data.get('job_id'):
return str(self.raw_data.get('job_id'))
return str(self._job_id or '')
@property
def status(self) -> str:
"""Get the current status of the search"""
return self.raw_data.get('status', 'unknown')
@property
def results(self) -> Dict[str, Any]:
"""Get the raw search results"""
self._ensure_loaded()
return self.raw_data.get('results', {})
def __len__(self) -> int:
self._ensure_loaded()
return len(self._df)
@property
def total_cells(self) -> int:
"""Get total number of cells found"""
self._ensure_loaded()
if 'cell_count' in self._df.columns:
return int(self._df['cell_count'].sum())
else:
# Fallback to counting combinations
return len(self._df)
def __repr__(self) -> str:
self._ensure_loaded()
if self._df.empty:
return "SearchResult: No data found"
lines = []
lines.append("🔬 Malva Search Results")
lines.append("=" * 50)
# FIXED: Always try to get actual cell count first
if 'cell_count' in self._df.columns and not self._df['cell_count'].isna().all():
total_cells = int(self._df['cell_count'].sum())
lines.append(f"📊 Total cells: {total_cells:,}")
lines.append(f"📊 Sample/cell_type combinations: {len(self._df)}")
else:
# Fallback: count rows but label correctly
lines.append(f"📊 Data points: {len(self._df)}")
lines.append("📊 Individual cell counts not available")
lines.append(f"🧬 Genes/sequences: {self._df['gene_sequence'].nunique() if 'gene_sequence' in self._df.columns else 'N/A'}")
lines.append(f"🧪 Samples: {self._df['sample_id'].nunique() if 'sample_id' in self._df.columns else 'N/A'}")
lines.append(f"🔬 Cell types: {self._df['cell_type'].nunique() if 'cell_type' in self._df.columns else 'N/A'}")
# Expression stats remain the same (these are per sample/cell_type aggregate)
if 'rel' in self._df.columns:
lines.append(f"📈 Expression range: {self._df['rel'].min():.3f} - {self._df['rel'].max():.3f}")
lines.append(f"📊 Mean expression: {self._df['rel'].mean():.3f}")
lines.append("")
# Metadata status
if self._enriched:
lines.append("✅ Enriched with sample metadata")
biological_fields = [f for f in ['organ', 'disease', 'species', 'study'] if f in self._df.columns]
if biological_fields:
lines.append(f"🏷️ Available metadata: {', '.join(biological_fields)}")
else:
lines.append("ℹ️ Basic expression data only")
lines.append("💡 Run .enrich_with_metadata() to add sample metadata for filtering by:")
lines.append(" • Organ, disease, species")
lines.append(" • Study, laboratory, protocol")
lines.append(" • Age, sex, development stage")
lines.append("")
lines.append("🔍 Available methods:")
lines.append(" • .filter_by(organ='brain', disease='normal')")
lines.append(" • .aggregate_by('cell_type')")
lines.append(" • .plot_expression_by('organ')")
lines.append(" • .available_filter_fields()")
return "\n".join(lines)
def __str__(self) -> str:
"""String representation"""
return self.__repr__()
def _repr_html_(self) -> str:
    """HTML representation for Jupyter notebooks.

    Renders a bordered box containing a statistics table followed by a note
    saying whether sample metadata has been merged in.

    The cell total is the sum of the ``cell_count`` column when that column
    exists and holds at least one value, mirroring the text ``__repr__``.
    Rows are sample/cell-type aggregates, so the row count is reported under
    its own label and is never presented as a number of cells.

    Returns:
        The HTML markup as a single string.
    """
    self._ensure_loaded()
    df = self._df
    if df.empty:
        return "<div><strong>SearchResult:</strong> No data found</div>"

    def unique_count(column: str):
        # "N/A" keeps the table shape stable when a column is missing.
        return df[column].nunique() if column in df.columns else "N/A"

    def table_row(label: str, value) -> str:
        return f'<tr><td><strong>{label}:</strong></td><td>{value}</td></tr>'

    html = [
        '<div style="border: 1px solid #ddd; padding: 10px; border-radius: 5px; font-family: monospace;">',
        '<h3 style="margin-top: 0;">🔬 Malva Search Results</h3>',
        # Stats table
        '<table style="border-collapse: collapse; margin: 10px 0;">',
    ]

    if 'cell_count' in df.columns and not df['cell_count'].isna().all():
        total_cells = int(df['cell_count'].sum())
        html.append(table_row('Total cells', f'{total_cells:,}'))
        html.append(table_row('Sample/cell_type combinations', f'{len(df):,}'))
    else:
        # Without per-row cell counts only the number of rows is known.
        html.append(table_row('Data points', f'{len(df):,}'))

    html.append(table_row('Genes/sequences', unique_count('gene_sequence')))
    html.append(table_row('Samples', unique_count('sample_id')))
    html.append(table_row('Cell types', unique_count('cell_type')))

    if 'rel' in df.columns:
        rel = df['rel']
        html.append(table_row('Expression range', f'{rel.min():.3f} - {rel.max():.3f}'))
        html.append(table_row('Mean expression', f'{rel.mean():.3f}'))
    html.append('</table>')

    # Metadata status
    if self._enriched:
        html.append('<div style="color: green;">✅ <strong>Enriched with sample metadata</strong></div>')
    else:
        html.append('<div style="color: orange;">ℹ️ <strong>Basic expression data only</strong></div>')
        html.append('<div style="margin: 5px 0; font-size: 0.9em;">Run <code>.enrich_with_metadata()</code> to add sample metadata for advanced filtering</div>')

    html.append('</div>')
    return ''.join(html)
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional, Union, TYPE_CHECKING
if TYPE_CHECKING:
from malva_client.client import MalvaClient
[docs]
class CellExpressionMatrixResult:
    """
    Per-cell result returned by MalvaClient.retrieve_cells().

    New results are built directly from search and metadata endpoints. The
    older ZIP-backed constructor remains supported for compatibility.

    Each table is either supplied to the constructor or read lazily, on
    first access, from a gzip-compressed member of the ZIP at
    ``archive_path``. The table properties return copies, so callers can
    modify what they receive without touching the cached frames.
    """

    # Column layout shared by every matrix-entry frame, including empty ones.
    _ENTRY_COLUMNS = ['row_index', 'feature_index', 'value']
    # Column layout of the long frame returned when there are no entries.
    _LONG_COLUMNS = ['row_index', 'sample_id', 'cell_id', 'feature_index',
                     'feature', 'label', 'value']

    def __init__(self, archive_path: Optional[Union[str, Path]] = None,
                 export_record: Optional[Dict[str, Any]] = None,
                 client: Optional['MalvaClient'] = None,
                 *,
                 cells: Optional[pd.DataFrame] = None,
                 features: Optional[pd.DataFrame] = None,
                 matrix_entries: Optional[pd.DataFrame] = None,
                 normalization_factors: Optional[pd.DataFrame] = None,
                 sample_metadata: Optional[pd.DataFrame] = None,
                 barcodes: Optional[pd.DataFrame] = None,
                 job_id: Optional[str] = None,
                 source: str = 'direct'):
        """
        Initialize the result from in-memory tables and/or a ZIP archive.

        Args:
            archive_path: Optional path to a ZIP archive holding the tables.
            export_record: Optional dict; ``export_id``, ``job_id``,
                ``source``, ``status``, ``cell_count`` and ``feature_count``
                are read from it when present.
            client: MalvaClient used by :meth:`project` and passed on to
                derived results.
            cells, features, matrix_entries, normalization_factors,
            sample_metadata, barcodes: Optional pre-built tables. Each one
                is copied; a table left as ``None`` is loaded from the
                archive on first access.
            job_id: Job identifier; falls back to ``export_record['job_id']``.
            source: Origin label; ``export_record['source']`` is consulted
                only when this argument is falsy.
        """
        self.archive_path = Path(archive_path) if archive_path is not None else None
        self.export_record = export_record or {}
        self.client = client
        self.export_id = self.export_record.get('export_id')
        self.job_id = job_id or self.export_record.get('job_id')
        self.source = source or self.export_record.get('source', 'direct')
        self.status = self.export_record.get('status', 'completed')
        self._cells_df: Optional[pd.DataFrame] = self._own_copy(cells)
        self._features_df: Optional[pd.DataFrame] = self._own_copy(features)
        self._normalization_df: Optional[pd.DataFrame] = self._own_copy(normalization_factors)
        self._sample_metadata_df: Optional[pd.DataFrame] = self._own_copy(sample_metadata)
        self._barcodes_df: Optional[pd.DataFrame] = self._own_copy(barcodes)
        self._matrix_entries_df: Optional[pd.DataFrame] = self._own_copy(matrix_entries)

    # -- Internal helpers -----------------------------------------------------

    @staticmethod
    def _own_copy(frame: Optional[pd.DataFrame]) -> Optional[pd.DataFrame]:
        """Return a private copy of *frame*, or ``None`` when it is ``None``."""
        return frame.copy() if frame is not None else None

    @classmethod
    def _empty_entries(cls) -> pd.DataFrame:
        """Return an empty matrix-entry frame with the standard columns."""
        return pd.DataFrame(columns=list(cls._ENTRY_COLUMNS))

    @staticmethod
    def _sample_id_set(sample_ids: Union[int, List[int]]) -> set:
        """Normalize one sample ID or an iterable of IDs to a set of ints."""
        if isinstance(sample_ids, (int, np.integer)):
            return {int(sample_ids)}
        return {int(s) for s in sample_ids}

    @staticmethod
    def _feature_mask(features: pd.DataFrame, text: str) -> pd.Series:
        """Return a boolean mask of rows whose feature or label equals *text*.

        Columns absent from *features* are simply skipped, so the mask always
        shares the frame's index.
        """
        mask = pd.Series(False, index=features.index)
        for column in ('feature', 'label'):
            if column in features.columns:
                mask = mask | features[column].astype(str).eq(text)
        return mask

    @staticmethod
    def _feature_label(row: pd.Series, fallback: Any) -> str:
        """Pick the display label for a feature row.

        Prefers ``label``, then ``feature``. Missing, NaN and empty values
        are skipped (NaN is truthy, so a plain ``or`` would yield "nan").
        """
        for key in ('label', 'feature'):
            value = row.get(key)
            if value is None or pd.isna(value) or str(value) == '':
                continue
            return str(value)
        return str(fallback)

    def _read_gz_tsv(self, member: str) -> pd.DataFrame:
        """Read a gzip-compressed TSV member of the archive.

        Returns an empty DataFrame when no archive is configured or the
        member is not in the archive.

        Raises:
            FileNotFoundError: If ``archive_path`` is set but does not exist.
        """
        if self.archive_path is None:
            return pd.DataFrame()
        if not self.archive_path.exists():
            raise FileNotFoundError(f"Archive not found: {self.archive_path}")
        with zipfile.ZipFile(self.archive_path) as zf:
            if member not in zf.namelist():
                return pd.DataFrame()
            with zf.open(member) as raw:
                with gzip.GzipFile(fileobj=raw, mode='rb') as gz:
                    return pd.read_csv(gz, sep='\t')

    def _entries_to_dataframe(self, entries: pd.DataFrame, normalized: bool,
                              include_sample_metadata: bool) -> pd.DataFrame:
        """Join *entries* with cells, features and the optional extra tables.

        Takes the entries as an argument so that callers can convert a
        subset without altering the cached matrix.
        """
        if entries.empty:
            return pd.DataFrame(columns=list(self._LONG_COLUMNS))
        df = entries.merge(self.cells, on='row_index', how='left')
        df = df.merge(self.features, on='feature_index', how='left')
        if normalized:
            norm = self.normalization_factors
            if not norm.empty and {'row_index', 'size_factor'} <= set(norm.columns):
                # total_counts is optional; only size_factor is needed here.
                keep = [c for c in ('row_index', 'total_counts', 'size_factor')
                        if c in norm.columns]
                df = df.merge(norm[keep], on='row_index', how='left')
                sf = pd.to_numeric(df['size_factor'], errors='coerce')
                # A zero size factor yields NaN instead of +/-inf.
                df['normalized_value'] = df['value'] / sf.replace(0, np.nan)
        if include_sample_metadata:
            meta = self.sample_metadata
            if not meta.empty and 'sample_id' in meta.columns:
                df = df.merge(meta, on='sample_id', how='left', suffixes=('', '_sample'))
        return df

    # -- Tables ---------------------------------------------------------------

    @property
    def cells(self) -> pd.DataFrame:
        """Rows of the matrix: row_index, sample_id, cell_id."""
        if self._cells_df is None:
            self._cells_df = self._read_gz_tsv('cells.tsv.gz')
        return self._cells_df.copy()

    @property
    def features(self) -> pd.DataFrame:
        """Columns of the matrix: feature_index, job_id, feature, label, and source."""
        if self._features_df is None:
            self._features_df = self._read_gz_tsv('features.tsv.gz')
        return self._features_df.copy()

    @property
    def normalization_factors(self) -> pd.DataFrame:
        """Per-cell size factors aligned by row_index, when available."""
        if self._normalization_df is None:
            self._normalization_df = self._read_gz_tsv('normalization_factors.tsv.gz')
        return self._normalization_df.copy()

    @property
    def sample_metadata(self) -> pd.DataFrame:
        """Sample metadata table keyed by sample_id, when available."""
        if self._sample_metadata_df is None:
            self._sample_metadata_df = self._read_gz_tsv('sample_metadata.tsv.gz')
        return self._sample_metadata_df.copy()

    @property
    def barcodes(self) -> pd.DataFrame:
        """Optional barcode table aligned by row_index."""
        if self._barcodes_df is None:
            self._barcodes_df = self._read_gz_tsv('barcodes.tsv.gz')
        return self._barcodes_df.copy()

    @property
    def matrix_entries(self) -> pd.DataFrame:
        """
        Sparse matrix entries as row_index, feature_index, value.

        Values are raw per-cell expression or k-mer hit counts. Missing
        row/feature pairs are zero.

        Raises:
            FileNotFoundError: If ``archive_path`` is set but does not exist.
        """
        if self._matrix_entries_df is not None:
            return self._matrix_entries_df.copy()
        if self.archive_path is None:
            self._matrix_entries_df = self._empty_entries()
            return self._matrix_entries_df.copy()
        if not self.archive_path.exists():
            raise FileNotFoundError(f"Archive not found: {self.archive_path}")
        with zipfile.ZipFile(self.archive_path) as zf:
            if 'matrix.mtx.gz' not in zf.namelist():
                self._matrix_entries_df = self._empty_entries()
                return self._matrix_entries_df.copy()
            with zf.open('matrix.mtx.gz') as raw:
                with gzip.GzipFile(fileobj=raw, mode='rb') as gz:
                    text = io.TextIOWrapper(gz, encoding='utf-8')
                    # Skip the '%' comment lines; the loop also consumes the
                    # first non-comment line, so parsing starts on the line
                    # after it.
                    for line in text:
                        if not line.startswith('%'):
                            break
                    try:
                        entries = pd.read_csv(
                            text,
                            sep=r'\s+',
                            names=list(self._ENTRY_COLUMNS),
                            engine='python',
                        )
                    except pd.errors.EmptyDataError:
                        entries = self._empty_entries()
        if entries.empty:
            entries = self._empty_entries()
        else:
            entries['row_index'] = entries['row_index'].astype('int64')
            entries['feature_index'] = entries['feature_index'].astype('int64')
            entries['value'] = entries['value'].astype('float32')
        self._matrix_entries_df = entries
        return self._matrix_entries_df.copy()

    # -- Queries --------------------------------------------------------------

    def get_cell_ids(self, sample_ids: Optional[Union[int, List[int]]] = None) -> pd.DataFrame:
        """
        Return sample_id/cell_id pairs from the matrix rows.

        Args:
            sample_ids: Optional encoded sample ID or list of sample IDs.
        """
        cells = self.cells
        if sample_ids is None or cells.empty:
            return cells
        wanted = self._sample_id_set(sample_ids)
        return cells[cells['sample_id'].isin(wanted)].copy()

    def positive_cells(self, feature: Optional[Union[str, int]] = None,
                       sample_ids: Optional[Union[int, List[int]]] = None) -> pd.DataFrame:
        """
        Return cells with non-zero expression in any retrieved feature or one feature.

        Args:
            feature: Feature label/name or feature_index. If omitted, returns
                all cells that are present in the retrieved matrix rows.
            sample_ids: Optional encoded sample ID or list of sample IDs.

        Raises:
            ValueError: If a feature name/label matches no feature row.
        """
        cells = self.get_cell_ids(sample_ids)
        if feature is None or cells.empty:
            return cells
        if isinstance(feature, (int, np.integer)):
            feature_indices = {int(feature)}
        else:
            features = self.features
            mask = self._feature_mask(features, str(feature))
            feature_indices = set(features.loc[mask, 'feature_index'].astype(int).tolist())
            if not feature_indices:
                raise ValueError(f"Feature not found in result: {feature}")
        entries = self.matrix_entries
        hits = entries['feature_index'].isin(feature_indices)
        row_ids = set(entries.loc[hits, 'row_index'].astype(int).tolist())
        return cells[cells['row_index'].isin(row_ids)].copy()

    def to_dataframe(self, normalized: bool = False,
                     include_sample_metadata: bool = False) -> pd.DataFrame:
        """
        Convert the sparse matrix to a long DataFrame.

        Args:
            normalized: Add normalized_value = value / size_factor when
                normalization factors are available.
            include_sample_metadata: Merge sample metadata by sample_id.
        """
        return self._entries_to_dataframe(
            self.matrix_entries,
            normalized=normalized,
            include_sample_metadata=include_sample_metadata,
        )

    def for_sample(self, sample_id: int, normalized: bool = False,
                   include_sample_metadata: bool = False) -> pd.DataFrame:
        """Return long matrix entries for one encoded sample ID.

        Returns an empty, column-less DataFrame when the sample has no cells
        or no matrix entries.
        """
        cells = self.get_cell_ids(int(sample_id))
        if cells.empty:
            return pd.DataFrame()
        row_ids = set(cells['row_index'].astype(int).tolist())
        entries = self.matrix_entries
        subset = entries[entries['row_index'].isin(row_ids)]
        if subset.empty:
            return pd.DataFrame()
        return self._entries_to_dataframe(
            subset,
            normalized=normalized,
            include_sample_metadata=include_sample_metadata,
        )

    def to_single_cell_result(self, feature: Optional[Union[str, int]] = None,
                              sample_ids: Optional[Union[int, List[int]]] = None,
                              normalized: bool = False) -> 'SingleCellResult':
        """
        Convert one retrieved feature to the legacy SingleCellResult shape.

        This is useful for existing downstream code that expects columns
        cell_id, expression, and sample_id.

        Raises:
            ValueError: If *feature* is omitted while several features are
                present, or a feature name/label matches no feature row.
        """
        features = self.features
        if features.empty:
            return SingleCellResult({'status': 'completed', 'results': {}}, self.client)

        if feature is None:
            if len(features) != 1:
                raise ValueError("feature is required when the result has multiple features")
            chosen = features.iloc[0]
            feature_index = int(chosen['feature_index'])
            feature_label = self._feature_label(chosen, feature_index)
        elif isinstance(feature, (int, np.integer)):
            feature_index = int(feature)
            matches = features[features['feature_index'] == feature_index]
            if matches.empty:
                feature_label = str(feature)
            else:
                feature_label = self._feature_label(matches.iloc[0], feature_index)
        else:
            matches = features.loc[self._feature_mask(features, str(feature))]
            if matches.empty:
                raise ValueError(f"Feature not found in result: {feature}")
            chosen = matches.iloc[0]
            feature_index = int(chosen['feature_index'])
            feature_label = self._feature_label(chosen, feature)

        df = self.to_dataframe(normalized=normalized)
        df = df[df['feature_index'] == feature_index]
        if sample_ids is not None:
            df = df[df['sample_id'].isin(self._sample_id_set(sample_ids))]

        # Fall back to raw values when no normalization was available.
        expr_col = 'normalized_value' if normalized and 'normalized_value' in df.columns else 'value'
        result_data = {
            'status': 'completed',
            'results': {
                feature_label: {
                    'cell': df['cell_id'].astype(int).tolist(),
                    'sample': df['sample_id'].astype(int).tolist(),
                    'expression': df[expr_col].astype(float).tolist(),
                    'ncells': int(len(df)),
                }
            }
        }
        return SingleCellResult(result_data, self.client)

    def project(self, dataset_id: str,
                sample_ids: Optional[Union[int, List[int]]] = None,
                feature: Optional[Union[str, int]] = None,
                **kwargs) -> 'CoexpressionResult':
        """
        Project retrieved positive cells onto a coexpression index.

        Args:
            dataset_id: Coexpression index or dataset identifier.
            sample_ids: Optional encoded sample ID or IDs to restrict cells.
            feature: Optional feature to restrict to cells positive for that
                feature. If omitted, uses all retrieved positive cells.
            **kwargs: Additional coexpression parameters, such as top_n_genes.

        Raises:
            MalvaAPIError: If the result was created without a client.
        """
        if self.client is None:
            raise MalvaAPIError("A MalvaClient is required for projection")
        cells = self.positive_cells(feature=feature, sample_ids=sample_ids)
        return self.client.project_cells(
            cells['cell_id'].astype(int).tolist(),
            cells['sample_id'].astype(int).tolist(),
            dataset_id=dataset_id,
            **kwargs,
        )

    def __repr__(self) -> str:
        n_cells = self.export_record.get('cell_count')
        n_features = self.export_record.get('feature_count')
        # repr must never raise, so a failed table load is shown as 'unknown'.
        if n_cells is None:
            try:
                n_cells = len(self.cells)
            except Exception:
                n_cells = 'unknown'
        if n_features is None:
            try:
                n_features = len(self.features)
            except Exception:
                n_features = 'unknown'
        if self.archive_path is not None:
            location = f"archive_path='{self.archive_path}'"
        else:
            location = f"job_id='{self.job_id}', source='{self.source}'"
        return f"CellExpressionMatrixResult(cells={n_cells}, features={n_features}, {location})"
[docs]
class SingleCellResult:
    """
    Represents search results at the single cell level (not aggregated by cell type)

    The raw API payload is parsed once, at construction, into three parallel
    lists (cell IDs, expression values, sample IDs) plus a metadata dict.
    Filtering methods return new instances and never modify the receiver.
    """

    # Column layout of aggregate_by_sample(), shared by empty and non-empty output.
    _AGGREGATE_COLUMNS = ['sample_id', 'mean_expression', 'max_expression',
                          'std_expression', 'expression_count', 'cell_count']

    def __init__(self, results_data: Dict[str, Any], client: Optional['MalvaClient'] = None):
        """
        Initialize SingleCellResult

        Args:
            results_data: Raw results from the API
            client: MalvaClient instance for metadata enrichment
        """
        self.raw_data = results_data
        self.client = client
        # Stays None until a completed payload has been parsed.
        self._processed_data = None
        # Extract basic info
        self.job_id = results_data.get('job_id')
        self.status = results_data.get('status', 'unknown')
        # Process results if available and completed
        if self.status == 'completed' and 'results' in results_data:
            self._process_results()

    @staticmethod
    def _empty_payload(error: str) -> Dict[str, Any]:
        """Return the processed-data structure used when nothing could be parsed."""
        return {
            'cells': [],
            'expression': [],
            'samples': [],
            'query_gene': None,
            'metadata': {'error': error},
        }

    def _process_results(self):
        """Process raw results into structured data.

        Accepts either ``{'results': {...}}`` or the nested
        ``{'results': {'results': {...}}}`` layout. The first key starting
        with ``seq`` is used; failing that, the first dict value that has a
        ``cell`` key. Anything unparseable produces an empty payload with an
        ``error`` entry in its metadata instead of raising.
        """
        log = logging.getLogger(__name__)

        results = self.raw_data.get('results')
        if isinstance(results, dict) and isinstance(results.get('results'), dict):
            results = results['results']
            log.debug("Found nested results structure")
        elif isinstance(results, dict):
            log.debug("Using direct results structure")
        else:
            # None, a list, ... : treat as "no results" rather than crash.
            results = {}
            log.debug("No results found")

        # Now look for seq_1 or similar keys
        first_key = next(
            (k for k in results if isinstance(k, str) and k.startswith('seq')), None)
        if first_key is None:
            # Try to find any dict with 'cell' key
            first_key = next(
                (k for k, v in results.items() if isinstance(v, dict) and 'cell' in v), None)
        if first_key is None:
            log.debug("No valid result keys found in: %s", list(results))
            self._processed_data = self._empty_payload('No sequence data found')
            return

        result_data = results[first_key]
        if not isinstance(result_data, dict):
            log.debug("Result data is not a dict: %s", type(result_data))
            self._processed_data = self._empty_payload(
                f'Invalid result data type: {type(result_data)}')
            return

        # Extract the arrays; a null array is treated as empty.
        cells = result_data.get('cell') or []
        expressions = result_data.get('expression') or []
        samples = result_data.get('sample') or []
        unique_samples = len(set(samples)) if samples else 0

        self._processed_data = {
            'cells': cells,
            'expression': expressions,
            'samples': samples,
            'query_gene': result_data.get('sequence', first_key),
            'metadata': {
                'total_cells': result_data.get('ncells', len(cells)),
                'unique_samples': unique_samples,
                'sequence_length': result_data.get('sequence_length'),
                'query_info': self.raw_data.get('query', {}),
                'searches_remaining': self.raw_data.get('searches_remaining')
            }
        }
        log.debug("Processed %d cells from %d samples", len(cells), unique_samples)

    def _subset(self, df: pd.DataFrame) -> 'SingleCellResult':
        """Build a new result from rows of :meth:`to_dataframe`.

        The query sequence and its length are carried over so the derived
        result still reports what was searched for.
        """
        payload = {
            'status': 'completed',
            'results': {
                'seq_1': {
                    'cell': df['cell_id'].tolist(),
                    'expression': df['expression'].tolist(),
                    'sample': df['sample_id'].tolist(),
                    'ncells': len(df),
                    'sequence': self.query_gene,
                    'sequence_length': self._processed_data['metadata'].get('sequence_length'),
                }
            }
        }
        return SingleCellResult(payload, self.client)

    @property
    def is_completed(self) -> bool:
        """Check if the search is completed"""
        return self.status == 'completed'

    @property
    def is_pending(self) -> bool:
        """Check if the search is still pending"""
        return self.status == 'pending'

    @property
    def has_results(self) -> bool:
        """Check if results are available"""
        return self._processed_data is not None and len(self._processed_data['cells']) > 0

    @property
    def cell_count(self) -> int:
        """Get total number of cells in results"""
        if not self._processed_data:
            return 0
        return len(self._processed_data['cells'])

    @property
    def sample_count(self) -> int:
        """Get number of unique samples in results"""
        if not self._processed_data:
            return 0
        return len(set(self._processed_data['samples']))

    @property
    def query_gene(self) -> Optional[str]:
        """Get the queried gene symbol or sequence"""
        if not self._processed_data:
            return None
        return self._processed_data.get('query_gene')

    def get_cell_ids(self) -> List[int]:
        """Get list of all cell IDs (a copy; safe to modify)"""
        if not self.has_results:
            return []
        return list(self._processed_data['cells'])

    def get_expression_values(self) -> List[float]:
        """Get list of all expression values (a copy; safe to modify)"""
        if not self.has_results:
            return []
        return list(self._processed_data['expression'])

    def get_sample_ids(self) -> List[int]:
        """Get list of all sample IDs (a copy; safe to modify)"""
        if not self.has_results:
            return []
        return list(self._processed_data['samples'])

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert results to a pandas DataFrame

        Returns:
            DataFrame with columns: cell_id, expression, sample_id
        """
        if not self.has_results:
            return pd.DataFrame(columns=['cell_id', 'expression', 'sample_id'])
        return pd.DataFrame({
            'cell_id': self._processed_data['cells'],
            'expression': self._processed_data['expression'],
            'sample_id': self._processed_data['samples']
        })

    def filter_by_expression(self, min_expression: float = 0, max_expression: float = float('inf')) -> 'SingleCellResult':
        """
        Filter results by expression thresholds

        Args:
            min_expression: Minimum expression value (inclusive)
            max_expression: Maximum expression value (inclusive)

        Returns:
            New SingleCellResult with filtered data; ``self`` when there are
            no results to filter
        """
        if not self.has_results:
            return self
        df = self.to_dataframe()
        in_range = (df['expression'] >= min_expression) & (df['expression'] <= max_expression)
        return self._subset(df[in_range])

    def filter_by_samples(self, sample_ids: List[int]) -> 'SingleCellResult':
        """
        Filter results to specific samples

        Args:
            sample_ids: List of sample IDs to keep

        Returns:
            New SingleCellResult with filtered data; ``self`` when there are
            no results to filter
        """
        if not self.has_results:
            return self
        df = self.to_dataframe()
        return self._subset(df[df['sample_id'].isin(sample_ids)])

    def get_top_expressing_cells(self, n: int = 100) -> 'SingleCellResult':
        """
        Get top N expressing cells

        Args:
            n: Number of top cells to return

        Returns:
            New SingleCellResult with top expressing cells; ``self`` when
            there are no results
        """
        if not self.has_results:
            return self
        return self._subset(self.to_dataframe().nlargest(n, 'expression'))

    def aggregate_by_sample(self) -> pd.DataFrame:
        """
        Aggregate expression data by sample

        Returns:
            DataFrame with one row per sample and the columns sample_id,
            mean_expression, max_expression, std_expression,
            expression_count and cell_count (values rounded to 4 decimals).
            The same columns are present when there are no results.
        """
        if not self.has_results:
            return pd.DataFrame(columns=list(self._AGGREGATE_COLUMNS))
        grouped = self.to_dataframe().groupby('sample_id').agg(
            mean_expression=('expression', 'mean'),
            max_expression=('expression', 'max'),
            std_expression=('expression', 'std'),
            expression_count=('expression', 'count'),
            cell_count=('cell_id', 'count'),
        )
        return grouped.round(4).reset_index()

    def get_expression_stats(self) -> Dict[str, float]:
        """
        Get basic statistics about expression values

        Returns:
            Dictionary with expression statistics; empty when there are no
            results
        """
        if not self.has_results:
            return {}
        expression_values = np.array(self._processed_data['expression'])
        return {
            'mean': float(np.mean(expression_values)),
            'median': float(np.median(expression_values)),
            'std': float(np.std(expression_values)),
            'min': float(np.min(expression_values)),
            'max': float(np.max(expression_values)),
            'q25': float(np.percentile(expression_values, 25)),
            'q75': float(np.percentile(expression_values, 75)),
            'non_zero_cells': int(np.sum(expression_values > 0)),
            'total_cells': len(expression_values)
        }

    def save_to_csv(self, filename: str, include_metadata: bool = True):
        """
        Save results to CSV file

        Args:
            filename: Output filename
            include_metadata: Whether to include metadata enrichment. This
                class defines no ``enrich_with_metadata`` method itself, so
                enrichment is applied only when one is available (for
                example on a subclass); otherwise the plain table is written.
        """
        enrich = None
        if include_metadata and self.client:
            enrich = getattr(self, 'enrich_with_metadata', None)
            if not callable(enrich):
                enrich = None
                logging.getLogger(__name__).warning(
                    "Metadata enrichment is not available; writing results without it")
        df = enrich() if enrich is not None else self.to_dataframe()
        df.to_csv(filename, index=False)
        print(f"✓ Results saved to {filename} ({len(df)} cells)")

    def __repr__(self) -> str:
        if not self.has_results:
            if self.status == 'pending':
                return f"SingleCellResult(status='pending', job_id='{self.job_id}')"
            return "SingleCellResult(status='no results', cells=0)"
        return (f"SingleCellResult(sequence_length={self._processed_data['metadata'].get('sequence_length', 'N/A')}, "
                f"cells={self.cell_count:,}, samples={self.sample_count})")

    def __len__(self) -> int:
        return self.cell_count
[docs]
class CoverageResult:
    """
    Represents genomic coverage data from the Malva genome browser.

    Coverage data is organized as a matrix with genomic positions as rows
    and cell types as columns. Each cell contains a coverage value.
    """

    # Column layout of to_long_dataframe(), used for empty output as well.
    _LONG_COLUMNS = [
        'position', 'position_idx', 'sample_id', 'sample_idx',
        'cell_type', 'cell_type_idx', 'raw_signal', 'cell_count',
        'mean_signal',
    ]

    def __init__(self, raw_data: Dict[str, Any], client: Optional['MalvaClient'] = None):
        """
        Initialize CoverageResult

        Args:
            raw_data: Raw coverage data from the API
            client: MalvaClient instance for follow-up requests
        """
        self.raw_data = raw_data
        self.client = client
        self.job_id = raw_data.get('job_id', '')
        self.region = raw_data.get('region', '')
        self.positions = raw_data.get('positions', [])
        self.cell_types = raw_data.get('cell_types', [])
        self.coverage_matrix = raw_data.get('coverage_matrix', [])
        self.sample_celltype_coverage = raw_data.get('sample_celltype_coverage', {})
        self.total_windows = raw_data.get('total_windows', len(self.positions))

    @staticmethod
    def _lookup(values, index: int):
        """Return ``values[index]``, or the raw index when it is out of range.

        Negative indices count as out of range so they cannot silently wrap
        around to the end of the list.
        """
        return values[index] if 0 <= index < len(values) else index

    def to_long_dataframe(self) -> pd.DataFrame:
        """Return compact coverage rows by position, sample, and cell type.

        The server does not return single-cell coverage here. Rows are aggregate
        sample x cell-type values for each genomic position/probe, with raw
        signal, positive-cell count, and mean signal per positive cell.

        Rows with fewer than six fields are skipped. The returned frame
        always has the same columns, even when no row survives.
        """
        agg = self.sample_celltype_coverage or {}
        rows = agg.get('data') or []
        # Prefer the lookup tables shipped with the aggregate payload.
        positions = agg.get('positions') or self.positions or []
        samples = agg.get('samples') or []
        cell_types = agg.get('cell_types') or self.cell_types or []

        records = []
        for row in rows:
            if len(row) < 6:
                continue
            pi, si, ci = int(row[0]), int(row[1]), int(row[2])
            records.append({
                'position': self._lookup(positions, pi),
                'position_idx': pi,
                'sample_id': self._lookup(samples, si),
                'sample_idx': si,
                'cell_type': self._lookup(cell_types, ci),
                'cell_type_idx': ci,
                'raw_signal': float(row[3]),
                'cell_count': int(row[4]),
                'mean_signal': float(row[5]),
            })
        return pd.DataFrame(records, columns=list(self._LONG_COLUMNS))

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert coverage data to a pandas DataFrame.

        Returns:
            DataFrame with positions as index and cell types as columns;
            empty when there is no matrix or no cell types
        """
        if not self.coverage_matrix or not self.cell_types:
            return pd.DataFrame()
        # Server returns coverage_matrix as (cell_types x positions);
        # transpose to (positions x cell_types) for the DataFrame.
        # NOTE(review): orientation is inferred from the column count only,
        # so a square matrix is never transposed — confirm the server layout
        # for the case where both dimensions are equal.
        matrix = np.array(self.coverage_matrix)
        if matrix.ndim == 2 and matrix.shape[1] != len(self.cell_types):
            matrix = matrix.T
        df = pd.DataFrame(
            matrix,
            columns=self.cell_types
        )
        if self.positions:
            df.index = self.positions
            df.index.name = 'position'
        return df

    def get_filter_options(self) -> Dict[str, Any]:
        """
        Get available filter options for this coverage result.

        Returns:
            Dictionary with filter options; empty when there is no client or
            no job ID
        """
        if not self.client or not self.job_id:
            return {}
        return self.client.get_coverage_filter_options(self.job_id)

    def download_wig(self, output_path: str, **filters) -> str:
        """
        Download coverage data as a WIG file.

        Args:
            output_path: Path to save the WIG file
            **filters: Optional metadata filters

        Returns:
            Path to the saved file

        Raises:
            MalvaAPIError: If there is no client or no job ID.
        """
        if not self.client or not self.job_id:
            raise MalvaAPIError("Client and job_id required to download WIG files")
        return self.client.download_coverage_wig(self.job_id, output_path, **filters)

    def plot(self, cell_types: Optional[List[str]] = None, **kwargs):
        """
        Plot coverage across the genomic region.

        Args:
            cell_types: Specific cell types to plot (default: all)
            **kwargs: ``figsize`` sets the figure size (default ``(14, 6)``);
                every other keyword is forwarded to ``Axes.plot`` for each
                line (``alpha`` defaults to 0.8).

        Returns:
            The matplotlib Figure, or ``None`` when there is nothing to plot.

        Raises:
            ImportError: If matplotlib is not installed.
            ValueError: If none of the requested cell types exist.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError as exc:
            raise ImportError("matplotlib required for plotting. Install with: pip install matplotlib") from exc

        df = self.to_dataframe()
        if df.empty:
            print("No coverage data to plot")
            return
        if cell_types:
            available = [ct for ct in cell_types if ct in df.columns]
            if not available:
                raise ValueError(f"None of {cell_types} found. Available: {list(df.columns)}")
            df = df[available]

        line_kwargs = dict(kwargs)
        figsize = line_kwargs.pop('figsize', (14, 6))
        line_kwargs.setdefault('alpha', 0.8)
        # Every line is labelled with its cell type for the legend.
        line_kwargs.pop('label', None)

        fig, ax = plt.subplots(figsize=figsize)
        x_values = range(len(df))
        for col in df.columns:
            ax.plot(x_values, df[col].values, label=col, **line_kwargs)
        ax.set_xlabel('Genomic Position')
        ax.set_ylabel('Coverage')
        ax.set_title(f'Coverage: {self.region}' if self.region else 'Genomic Coverage')
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='small')
        plt.tight_layout()
        return fig

    def __repr__(self) -> str:
        n_positions = len(self.positions)
        n_cell_types = len(self.cell_types)
        return (f"CoverageResult(region='{self.region}', "
                f"positions={n_positions}, cell_types={n_cell_types}, "
                f"job_id='{self.job_id}')")

    def __len__(self) -> int:
        return len(self.positions)
[docs]
class UMAPCoordinates:
    """
    Lightweight container for UMAP coordinates from the coexpression API.

    Wraps the compact parallel-array format returned by
    ``GET /api/coexpression/umap/<dataset_id>`` and provides conversion to
    a pandas DataFrame and a simple scatter-plot method.
    """

    def __init__(self, raw_data: Dict[str, Any], client: Optional['MalvaClient'] = None):
        """
        Initialize UMAPCoordinates.

        Args:
            raw_data: Raw response from the UMAP endpoint
            client: MalvaClient instance for follow-up requests
        """
        self.raw_data = raw_data
        self.client = client
        self.x = raw_data.get('x', [])
        self.y = raw_data.get('y', [])
        self.metacell_ids = raw_data.get('metacell_id', [])
        self.n_cells = raw_data.get('n_cells', [])
        self.samples = raw_data.get('sample', [])
        self.clusters = raw_data.get('cluster', [])

    @staticmethod
    def _is_categorical(col: pd.Series) -> bool:
        """Return True when *col* should be drawn with discrete colors.

        Any non-numeric dtype counts (object, pandas string dtypes,
        categorical), so the decision does not depend on how pandas chose to
        store text labels.
        """
        return not pd.api.types.is_numeric_dtype(col)

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert to a pandas DataFrame.

        Returns:
            DataFrame with columns x and y, plus metacell_id, n_cells,
            sample and cluster for each of those arrays that is non-empty
        """
        data: Dict[str, Any] = {'x': self.x, 'y': self.y}
        optional_columns = (
            ('metacell_id', self.metacell_ids),
            ('n_cells', self.n_cells),
            ('sample', self.samples),
            ('cluster', self.clusters),
        )
        for name, values in optional_columns:
            if values:
                data[name] = values
        return pd.DataFrame(data)

    def plot(self, color_by: str = 'cluster', point_size: Optional[float] = None,
             cmap: str = 'tab20', figsize: Tuple[int, int] = (10, 8)):
        """
        Scatter plot of UMAP coordinates.

        Args:
            color_by: Column to color points by (default ``'cluster'``)
            point_size: Marker size (auto-scaled if ``None``)
            cmap: Matplotlib colormap name
            figsize: Figure size as ``(width, height)``

        Returns:
            matplotlib Figure, or ``None`` when there is no data to plot

        Raises:
            ImportError: If matplotlib is not installed.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError as exc:
            raise ImportError("matplotlib required for plotting. Install with: pip install matplotlib") from exc

        df = self.to_dataframe()
        if df.empty:
            print("No UMAP data to plot")
            return

        fig, ax = plt.subplots(figsize=figsize)
        if point_size is None:
            # Shrink markers as the number of points grows, never below 1.
            point_size = max(1, 200 / (len(df) ** 0.5))

        if color_by not in df.columns:
            ax.scatter(df['x'], df['y'], s=point_size, alpha=0.7)
        elif self._is_categorical(df[color_by]):
            col = df[color_by]
            categories = col.unique()
            color_map = {cat: i for i, cat in enumerate(categories)}
            ax.scatter(df['x'], df['y'], c=col.map(color_map), s=point_size,
                       cmap=cmap, alpha=0.7)
            # scatter rescales the integer codes to [0, 1]; the legend
            # markers use the same scaling so their colors match the points.
            palette = plt.get_cmap(cmap)
            span = max(len(categories) - 1, 1)
            handles = [plt.Line2D([0], [0], marker='o', color='w',
                                  markerfacecolor=palette(color_map[cat] / span),
                                  markersize=8, label=str(cat))
                       for cat in categories]
            ax.legend(handles=handles, bbox_to_anchor=(1.05, 1),
                      loc='upper left', fontsize='small')
        else:
            scatter = ax.scatter(df['x'], df['y'], c=df[color_by], s=point_size,
                                 cmap=cmap, alpha=0.7)
            plt.colorbar(scatter, ax=ax, label=color_by)

        ax.set_xlabel('UMAP 1')
        ax.set_ylabel('UMAP 2')
        ax.set_title(f'UMAP ({color_by})')
        plt.tight_layout()
        return fig

    def __repr__(self) -> str:
        return f"UMAPCoordinates(n_points={len(self.x)}, clusters={len(set(self.clusters)) if self.clusters else 'N/A'})"

    def __len__(self) -> int:
        return len(self.x)
[docs]
class CoexpressionResult:
"""
Full coexpression analysis result from the Malva coexpression API.
Wraps the response from ``POST /api/coexpression/query-by-job`` and
provides DataFrame conversions, top-gene retrieval, and plotting
helpers for correlated genes, GO enrichment, and UMAP scores.
"""
[docs]
def __init__(self, raw_data: Dict[str, Any], client: Optional['MalvaClient'] = None):
    """
    Initialize CoexpressionResult.

    Each section of the response is copied onto an attribute; a missing
    section falls back to an empty string, list or dict, or to 0 for counts.

    Args:
        raw_data: Raw response from the coexpression endpoint
        client: MalvaClient instance for follow-up requests
    """
    field = raw_data.get
    self.raw_data, self.client = raw_data, client
    self.dataset_id = field('dataset_id', '')
    # Result sections
    self.correlated_genes = field('correlated_genes', [])
    self.umap_scores = field('umap_scores', {})
    self.go_enrichment = field('go_enrichment', [])
    self.cell_type_enrichment = field('cell_type_enrichment', [])
    self.tissue_breakdown = field('tissue_breakdown', [])
    # Summary counts
    self.n_query_cells = field('n_query_cells', 0)
    self.n_mapped_metacells = field('n_mapped_metacells', 0)
# -- DataFrame conversions ------------------------------------------------
[docs]
def genes_to_dataframe(self) -> pd.DataFrame:
    """
    Build a table of the correlated genes.

    Returns:
        One row per entry of ``correlated_genes``, with whatever fields the
        server returned. With no entries, an empty frame with the columns
        gene, correlation and p_value.
    """
    genes = self.correlated_genes
    if genes:
        return pd.DataFrame(genes)
    return pd.DataFrame(columns=['gene', 'correlation', 'p_value'])
[docs]
def scores_to_dataframe(self) -> pd.DataFrame:
    """
    Tabulate the UMAP score payload.

    Two payload layouts are recognised:

    * a list of per-row dicts, passed directly to ``pd.DataFrame``;
    * a dict containing a ``metacell_ids`` key, where every entry whose
      value is a list becomes a column (non-list entries are ignored).

    Returns:
        DataFrame with metacell-level score data; an empty DataFrame when
        the payload is empty or has neither layout
    """
    payload = self.umap_scores
    if not payload:
        return pd.DataFrame()

    # Row-major layout: list of dicts
    if isinstance(payload, list):
        return pd.DataFrame(payload)

    # Compact layout: parallel arrays keyed by column name
    if isinstance(payload, dict) and 'metacell_ids' in payload:
        columns: Dict[str, Any] = {
            name: column
            for name, column in payload.items()
            if isinstance(column, list)
        }
        return pd.DataFrame(columns)

    return pd.DataFrame()
[docs]
def umap_to_dataframe(self) -> pd.DataFrame:
    """
    Return the UMAP score data as a DataFrame.

    This delegates entirely to :meth:`scores_to_dataframe`; any x/y
    coordinates in the result are those embedded in the scores payload.

    Returns:
        DataFrame with UMAP coordinates and scores
    """
    frame = self.scores_to_dataframe()
    return frame
[docs]
def go_to_dataframe(self) -> pd.DataFrame:
    """
    Tabulate the GO enrichment records.

    Returns:
        DataFrame built from ``self.go_enrichment``. When there are no
        records, an empty frame with the columns ``go_id``, ``name`` and
        ``fdr`` is returned instead.
    """
    terms = self.go_enrichment
    if terms:
        return pd.DataFrame(terms)
    # No records: still expose the expected column names.
    return pd.DataFrame(columns=['go_id', 'name', 'fdr'])
[docs]
def cell_type_enrichment_to_dataframe(self) -> pd.DataFrame:
    """
    Tabulate the cell-type enrichment records.

    Returns:
        DataFrame built from ``self.cell_type_enrichment``, or an empty
        DataFrame (no columns) when there are no records
    """
    records = self.cell_type_enrichment
    if records:
        return pd.DataFrame(records)
    return pd.DataFrame()
[docs]
def tissue_breakdown_to_dataframe(self) -> pd.DataFrame:
    """
    Tabulate the tissue breakdown records.

    Returns:
        DataFrame built from ``self.tissue_breakdown``, or an empty
        DataFrame (no columns) when there are no records
    """
    records = self.tissue_breakdown
    if records:
        return pd.DataFrame(records)
    return pd.DataFrame()
# -- Convenience ----------------------------------------------------------
[docs]
def get_top_genes(self, n: int = 20, min_correlation: float = 0.0) -> List[str]:
    """
    Return the names of the *n* most strongly correlated genes.

    When the gene table has a ``correlation`` column, rows below
    ``min_correlation`` are dropped and the remainder is ordered from
    highest to lowest correlation. Without that column the table order is
    kept and no filtering takes place. Names come from the ``gene`` column
    if present, otherwise from the first column.

    Args:
        n: Number of genes to return
        min_correlation: Minimum correlation to include

    Returns:
        List of gene names
    """
    table = self.genes_to_dataframe()
    if table.empty:
        return []

    if 'correlation' in table.columns:
        passes_threshold = table['correlation'] >= min_correlation
        table = table[passes_threshold].sort_values('correlation', ascending=False)

    if 'gene' in table.columns:
        name_col = 'gene'
    else:
        name_col = table.columns[0]

    top_names = table[name_col].head(n)
    return top_names.tolist()
# -- Plotting -------------------------------------------------------------
[docs]
def plot_umap(self, color_by: str = 'positive_fraction', point_size: Optional[float] = None,
              cmap: str = 'viridis', figsize: Tuple[int, int] = (10, 8)):
    """
    Scatter plot of UMAP coordinates coloured by a score column.

    Numeric ``color_by`` columns are drawn with a continuous colormap and
    a colorbar. Non-numeric columns (for example string labels) are drawn
    as one scatter series per distinct value with a legend; ``cmap`` is not
    used in that case. When ``color_by`` is not a column of the scores
    data, the points are drawn without colouring.

    The x/y columns are ``umap_x`` / ``umap_y`` when present, otherwise the
    first and second columns of the scores table.

    Args:
        color_by: Column in the scores data to use for colouring
        point_size: Marker size (auto-scaled if ``None``)
        cmap: Matplotlib colormap name
        figsize: Figure size as ``(width, height)``

    Returns:
        matplotlib Figure, or ``None`` when there is no score data to plot

    Raises:
        ImportError: If matplotlib is not installed
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError("matplotlib required for plotting. Install with: pip install matplotlib")

    df = self.scores_to_dataframe()
    if df.empty:
        print("No UMAP score data to plot")
        return

    # Resolve the coordinate columns before creating the figure, so that a
    # failure here does not leave an orphan figure registered with pyplot.
    x_col = 'umap_x' if 'umap_x' in df.columns else df.columns[0]
    y_col = 'umap_y' if 'umap_y' in df.columns else df.columns[1]

    if point_size is None:
        # Shrink markers as the number of points grows, never below 1.
        point_size = max(1, 200 / (len(df) ** 0.5))

    fig, ax = plt.subplots(figsize=figsize)

    if color_by not in df.columns:
        ax.scatter(df[x_col], df[y_col], s=point_size, alpha=0.7)
    elif pd.api.types.is_numeric_dtype(df[color_by]):
        # Sort ascending so the highest-scoring points are drawn last
        # (on top of the others).
        if pd.api.types.is_any_real_numeric_dtype(df[color_by]):
            df = df.sort_values(by=color_by, ascending=True)
        scatter = ax.scatter(df[x_col], df[y_col], c=df[color_by],
                             s=point_size, cmap=cmap, alpha=0.7)
        plt.colorbar(scatter, ax=ax, label=color_by)
    else:
        # Label columns cannot be passed as ``c=`` (matplotlib would try to
        # read each value as a colour), so draw one series per label.
        labels = df[color_by].astype(str)
        for label in sorted(labels.unique()):
            members = (labels == label).to_numpy()
            ax.scatter(df.loc[members, x_col], df.loc[members, y_col],
                       s=point_size, alpha=0.7, label=label)
        ax.legend(title=color_by, bbox_to_anchor=(1.02, 1),
                  loc='upper left', fontsize='small')

    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    ax.set_title(f'Coexpression UMAP ({color_by})')
    plt.tight_layout()
    return fig
[docs]
def plot_top_genes(self, n: int = 20, figsize: Tuple[int, int] = (8, 6)):
    """
    Horizontal bar chart of the top correlated genes.

    When the gene table has a ``correlation`` column the rows are ordered
    from highest to lowest correlation before the first *n* are taken;
    otherwise the first *n* rows are used as-is and the second column
    supplies the bar lengths. Labels come from the ``gene`` column if
    present, otherwise from the first column.

    Args:
        n: Number of genes to show
        figsize: Figure size as ``(width, height)``

    Returns:
        matplotlib Figure, or ``None`` when there is no gene data to plot

    Raises:
        ImportError: If matplotlib is not installed
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError("matplotlib required for plotting. Install with: pip install matplotlib")

    genes = self.genes_to_dataframe()
    if genes.empty:
        print("No correlated gene data to plot")
        return

    has_correlation = 'correlation' in genes.columns
    if has_correlation:
        top = genes.sort_values('correlation', ascending=False).head(n)
    else:
        top = genes.head(n)

    label_col = 'gene' if 'gene' in top.columns else top.columns[0]
    value_col = 'correlation' if has_correlation else top.columns[1]

    fig, ax = plt.subplots(figsize=figsize)
    positions = range(len(top))
    ax.barh(positions, top[value_col].values, color='steelblue')
    ax.set_yticks(positions)
    ax.set_yticklabels(top[label_col].values)
    # Put the first (strongest) entry at the top of the chart.
    ax.invert_yaxis()
    ax.set_xlabel('Correlation')
    ax.set_title(f'Top {len(top)} Correlated Genes')
    plt.tight_layout()
    return fig
[docs]
def plot_go_enrichment(self, n: int = 15, figsize: Tuple[int, int] = (8, 6)):
    """
    Bar chart of GO enrichment results (-log10 of the significance value).

    The ``fdr`` column is used when present; otherwise the ``p_value``
    column is used and the x-axis is labelled accordingly. The *n* terms
    with the smallest significance value are shown, labelled by the
    ``name`` column if present, otherwise by the first column.

    Args:
        n: Number of GO terms to show
        figsize: Figure size as ``(width, height)``

    Returns:
        matplotlib Figure, or ``None`` when there is no GO data or no
        ``fdr`` / ``p_value`` column to plot

    Raises:
        ImportError: If matplotlib is not installed
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError("matplotlib required for plotting. Install with: pip install matplotlib")

    df = self.go_to_dataframe()
    if df.empty:
        print("No GO enrichment data to plot")
        return

    # Pick the significance column and the matching axis label together so
    # the plot never claims "FDR" while showing raw p-values.
    if 'fdr' in df.columns:
        sig_col, sig_label = 'fdr', 'FDR'
    elif 'p_value' in df.columns:
        sig_col, sig_label = 'p_value', 'p-value'
    else:
        print("No FDR or p-value column found in GO enrichment data")
        return
    name_col = 'name' if 'name' in df.columns else df.columns[0]

    top = df.sort_values(sig_col, ascending=True).head(n)
    # Clip before taking the log so an exact zero does not become infinite.
    # Kept as a separate Series instead of writing a column into the slice.
    neg_log10 = -np.log10(top[sig_col].clip(lower=1e-300))

    fig, ax = plt.subplots(figsize=figsize)
    y_pos = range(len(top))
    ax.barh(y_pos, neg_log10.to_numpy(), color='darkorange')
    ax.set_yticks(y_pos)
    ax.set_yticklabels(top[name_col].values)
    # Put the most significant term at the top of the chart.
    ax.invert_yaxis()
    ax.set_xlabel(f'-log10({sig_label})')
    ax.set_title(f'Top {len(top)} GO Enrichment Terms')
    plt.tight_layout()
    return fig
def __repr__(self) -> str:
    """Return a one-line summary of the result.

    Shows the dataset id, the number of correlated genes, the query-cell
    and metacell counts, and, only when GO enrichment is non-empty, the
    number of GO terms.
    """
    fields = [
        f"CoexpressionResult(dataset='{self.dataset_id}'",
        f"genes={len(self.correlated_genes)}",
        f"query_cells={self.n_query_cells}",
        f"metacells={self.n_mapped_metacells}",
    ]
    if self.go_enrichment:
        fields.append(f"go_terms={len(self.go_enrichment)}")
    summary = ', '.join(fields)
    return f"{summary})"
def __len__(self) -> int:
    """Return the number of correlated-gene records held by this result."""
    n_genes = len(self.correlated_genes)
    return n_genes