149 changes: 78 additions & 71 deletions bin/sample_calc.py
@@ -2,7 +2,7 @@
"""
Description: this script calculates the clonality of a TCR repertoire

@author: Domenick Braccia
@author: Dylan Tamayo, Domenick Braccia
@contributor: elhanaty
"""

@@ -14,8 +14,41 @@
import numpy as np
import csv
import re
import json

def calc_sample_stats(sample_meta, counts):
def extract_trb_family(allele):
if pd.isna(allele):
return None
match = re.match(r'(TRB[V|D|J])(\d+)', allele)
return f"{match.group(1)}{match.group(2)}" if match else None

def compute_gene_family_table(counts, col_name, all_families, sample_meta):
fam_col = f"{col_name}FamilyName"
counts[fam_col] = counts[col_name].apply(extract_trb_family)
fam_df = counts[fam_col].value_counts(dropna=False).to_frame().T.sort_index(axis=1)
fam_df = fam_df.reindex(columns=all_families, fill_value=0)

for col in ['origin', 'timepoint', 'subject_id']:
fam_df.insert(0, col, sample_meta[col])

return fam_df

def calc_gene_family(counts, gene_column, family_prefix, max_index, output_file, meta_df):
# Build list of all possible family names
all_fams = [f'{family_prefix}{i}' for i in range(1, max_index + 1)]

# Count usage
fam_df = counts[gene_column].apply(extract_trb_family).value_counts(dropna=False).to_frame().T

# Reindex to include all families
fam_df = pd.DataFrame([fam_df.reindex(columns=all_fams, fill_value=0).iloc[0]]).reset_index(drop=True)

# Add metadata columns
fam_df = pd.concat([meta_df, fam_df], axis=1)

fam_df.to_csv(output_file, header=True, index=False)

def calc_sample_stats(meta_df, counts, output_file):
"""Calculate sample level statistics of TCR repertoire."""

## first pass stats
@@ -54,72 +87,30 @@ def calc_sample_stats(sample_meta, counts):
## calculate ratio of convergent TCRs to total TCRs
ratio_convergent = num_convergent/len(aas)

## add in patient meta_data such as responder status to sample_stats.csv
# read in metadata file
# meta_data = pd.read_csv(args.meta_data, sep=',', header=0)

# filter out metadata for the current sample
# current_meta = meta_data[meta_data['patient_id'] == sample_meta[1]]

# write above values to csv file
with open('sample_stats.csv', 'w') as csvfile:
writer = csv.writer(csvfile)
writer.writerow([sample_meta[0], sample_meta[1], sample_meta[2], sample_meta[3],
num_clones, num_TCRs, simpson_index, simpson_index_corrected, clonality,
num_prod, num_nonprod, pct_prod, pct_nonprod,
productive_cdr3_avg_len, num_convergent, ratio_convergent])

# store v_family gene usage in a dataframe
def extract_trb_family(allele):
if pd.isna(allele):
return None
match = re.match(r'(TRB[V|D|J])(\d+)', allele)
return f"{match.group(1)}{match.group(2)}" if match else None

# Apply to each column
counts['vFamilyName'] = counts['v_call'].apply(extract_trb_family)
counts['dFamilyName'] = counts['d_call'].apply(extract_trb_family)
counts['jFamilyName'] = counts['j_call'].apply(extract_trb_family)

# Compute gene usage frequency per family
v_family = counts['vFamilyName'].value_counts(dropna=False).to_frame().T.sort_index(axis=1)
d_family = counts['dFamilyName'].value_counts(dropna=False).to_frame().T.sort_index(axis=1)
j_family = counts['jFamilyName'].value_counts(dropna=False).to_frame().T.sort_index(axis=1)

# generate a list of all possible columns names from TRBV1-TRBV30
all_v_fam = [f'TRBV{i}' for i in range(1, 31)]

# generate a list of all possible columns names from TRBD1-TRBD2
all_d_fam = [f'TRBD{i}' for i in range(1, 3)]

# generate a list of all possible columns names from TRBJ1-TRBJ2
all_j_fam = [f'TRBJ{i}' for i in range(1, 3)]

# add missing columns to v_family dataframe by reindexing
v_family_reindex = v_family.reindex(columns=all_v_fam, fill_value=0)
d_family_reindex = d_family.reindex(columns=all_d_fam, fill_value=0)
j_family_reindex = j_family.reindex(columns=all_j_fam, fill_value=0)

# add sample_meta columns to v_family_reindex and make them the first three columns
v_family_reindex.insert(0, 'origin', sample_meta[3])
v_family_reindex.insert(0, 'timepoint', sample_meta[2])
v_family_reindex.insert(0, 'patient_id', sample_meta[1])
d_family_reindex.insert(0, 'origin', sample_meta[3])
d_family_reindex.insert(0, 'timepoint', sample_meta[2])
d_family_reindex.insert(0, 'patient_id', sample_meta[1])
j_family_reindex.insert(0, 'origin', sample_meta[3])
j_family_reindex.insert(0, 'timepoint', sample_meta[2])
j_family_reindex.insert(0, 'patient_id', sample_meta[1])

# Write v_family_reindex to csv file with no header and no index
v_family_reindex.to_csv('v_family.csv', header=False, index=False)
d_family_reindex.to_csv('d_family.csv', header=False, index=False)
j_family_reindex.to_csv('j_family.csv', header=False, index=False)

# # store dictionaries in a list and output to pickle file
# gene_usage = [v_family, d_family, j_family] ## excluding v_genes, d_genes, j_genes
# with open('gene_usage_' + str(metadata[1] + '_' + str(metadata[2] + '_' + str(metadata[3]))) + '.pkl', 'wb') as f:
# pickle.dump(gene_usage, f)
row_data = {
'num_clones': num_clones,
'num_TCRs': num_TCRs,
'simpson_index': simpson_index,
'simpson_index_corrected': simpson_index_corrected,
'clonality': clonality,
'num_prod': num_prod,
'num_nonprod': num_nonprod,
'pct_prod': pct_prod,
'pct_nonprod': pct_nonprod,
'productive_cdr3_avg_len': productive_cdr3_avg_len,
'num_convergent': num_convergent,
'ratio_convergent': ratio_convergent
}

# Convert to single-row dataframe
df_stats = pd.DataFrame([row_data])

# Add metadata columns
df_stats = pd.concat([meta_df, df_stats], axis=1)

# Save to CSV
df_stats.to_csv(output_file, header=True, index=False)


def main():
# initialize parser
@@ -129,7 +120,7 @@ def main():
parser.add_argument('-s', '--sample_meta',
metavar='sample_meta',
type=str,
help='sample metadata passed in through samples CSV file')
help='sample metadata passed in as json format')
parser.add_argument('-c', '--count_table',
metavar='count_table',
type=argparse.FileType('r'),
@@ -138,12 +129,28 @@
args = parser.parse_args()

## convert metadata to list
sample_meta = args.sample_meta[1:-1].split(', ')
sample_meta = json.loads(args.sample_meta)

# Read in the counts file
counts = pd.read_csv(args.count_table, sep='\t', header=0)

calc_sample_stats(sample_meta, counts)
# Build metadata row from selected keys
meta_keys = ['subject_id', 'timepoint', 'origin']
meta_row = {k: sample_meta[k] for k in meta_keys}
meta_df = pd.DataFrame([meta_row])

sample = sample_meta['sample']

calc_gene_family(counts, 'v_call', 'TRBV', 30, f'vdj/v_family_{sample}.csv', meta_df)
calc_gene_family(counts, 'd_call', 'TRBD', 2, f'vdj/d_family_{sample}.csv', meta_df)
calc_gene_family(counts, 'j_call', 'TRBJ', 2, f'vdj/j_family_{sample}.csv', meta_df)

# Build metadata row from selected keys
meta_keys = ['sample', 'subject_id', 'timepoint', 'origin']
meta_row = {k: sample_meta[k] for k in meta_keys}
meta_df = pd.DataFrame([meta_row])

calc_sample_stats(meta_df, counts, f'stats/sample_stats_{sample}.csv')

if __name__ == "__main__":
main()
24 changes: 24 additions & 0 deletions modules/local/sample/sample_aggregate.nf
@@ -0,0 +1,24 @@
process SAMPLE_AGGREGATE {
tag "${output_file}"
label 'process_low'
container "ghcr.io/karchinlab/tcrtoolkit:main"

input:
path csv_files
val output_file

output:
path output_file, emit: aggregated_csv

script:
"""
python3 <<EOF
import pandas as pd
input_files = [${csv_files.collect { '"' + it.getName() + '"' }.join(', ')}]
dfs = [pd.read_csv(f) for f in input_files]
merged = pd.concat(dfs, axis=0, ignore_index=True)
merged.to_csv("${output_file}", index=False)
EOF
"""
Comment on lines +14 to +23

Copilot AI Oct 14, 2025

Aggregation does not enforce a deterministic ordering of rows (prior implementation used sort:true), so run-to-run concatenation order may vary depending on file staging order. Apply a consistent sort (e.g., sort input_files list or sort merged by key columns) before writing to ensure reproducible outputs.
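For illustration, one way to apply that suggestion while keeping the current inline-script structure might look like the sketch below; only the sorted() call is new, and everything else is assumed from the process above.

```groovy
script:
"""
python3 <<EOF
import pandas as pd

# Sort the interpolated file names so the row order of the merged table
# does not depend on Nextflow's staging order of the input CSVs.
input_files = sorted([${csv_files.collect { '"' + it.getName() + '"' }.join(', ')}])
dfs = [pd.read_csv(f) for f in input_files]
merged = pd.concat(dfs, axis=0, ignore_index=True)
merged.to_csv("${output_file}", index=False)
EOF
"""
```

Note that sorting by file name is only deterministic if the per-sample file names are themselves stable; alternatively, the merged table could be sorted by key columns (e.g. sample) before writing.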

Collaborator

Not sure if concatenation order is important as stated by Copilot above, but is there a reason why we need SAMPLE_AGGREGATE as a separate Python-based process as opposed to the seemingly simpler collectFile operator?
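For comparison, a rough sketch of what the collectFile alternative might look like at the workflow level; the channel name and options below are assumptions for illustration, not code from this PR.

```groovy
// Hypothetical workflow-level aggregation using collectFile instead of a
// dedicated process: keepHeader/skip keep a single header row from the first
// collected file, and sort: true gives a deterministic row order.
SAMPLE_CALC.out.sample_csv
    .collectFile(name: 'sample_stats.csv', keepHeader: true, skip: 1, sort: true)
    .set { sample_stats_csv }
```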

}
18 changes: 11 additions & 7 deletions modules/local/sample/sample_calc.nf
@@ -7,16 +7,20 @@ process SAMPLE_CALC {
tuple val(sample_meta), path(count_table)

output:
path 'sample_stats.csv' , emit: sample_csv
path 'v_family.csv' , emit: v_family_csv
path 'd_family.csv' , emit: d_family_csv
path 'j_family.csv' , emit: j_family_csv
val sample_meta , emit: sample_meta
path "stats/sample_stats_${sample_meta.sample}.csv" , emit: sample_csv
path "vdj/v_family_${sample_meta.sample}.csv" , emit: v_family_csv
path "vdj/d_family_${sample_meta.sample}.csv" , emit: d_family_csv
path "vdj/j_family_${sample_meta.sample}.csv" , emit: j_family_csv
val sample_meta , emit: sample_meta

script:
def meta_json = groovy.json.JsonOutput.toJson(sample_meta)

"""
echo '' > sample_stats.csv
sample_calc.py -s '${sample_meta}' -c ${count_table}
mkdir -p stats
mkdir -p vdj

sample_calc.py -s '${meta_json}' -c ${count_table}
"""

stub:
11 changes: 10 additions & 1 deletion modules/local/sample/tcrdist3.nf
@@ -2,7 +2,16 @@ process TCRDIST3_MATRIX {
tag "${sample_meta.sample}"
container "ghcr.io/karchinlab/tcrtoolkit:main"

cpus params.max_cpus
cpus {
if (task.memory > 256.GB)
return 16 * task.attempt
else if (task.memory > 64.GB)
return 8 * task.attempt
else if (task.memory > 4.GB)
return 4 * task.attempt
else
return 2 * task.attempt
}
Comment on lines +6 to +14

Copilot AI Oct 14, 2025

The cpus directive depends on task.memory, which itself is derived later by the memory block; this circular dependency can result in task.memory being undefined or default when cpus is evaluated. Derive both cpu and memory from the same underlying metric (e.g., count_table.size()) or compute memory first in a variable and base cpus on that variable instead of task.memory.

Suggested change
if (task.memory > 256.GB)
return 16 * task.attempt
else if (task.memory > 64.GB)
return 8 * task.attempt
else if (task.memory > 4.GB)
return 4 * task.attempt
else
return 2 * task.attempt
}
def sz = count_table.size()
def mb = 1024 * 1024
if (sz > 26 * mb)
return 16 * task.attempt
else if (sz > 20 * mb)
return 8 * task.attempt
else if (sz > 10 * mb)
return 4 * task.attempt
else
return 2 * task.attempt
}

memory {
def sz = count_table.size()
def mb = 1024 * 1024
45 changes: 17 additions & 28 deletions notebooks/sample_stats_template.qmd
@@ -67,28 +67,17 @@ print('Date and time: ' + str(datetime.datetime.now()))
# 4. Loading data

## reading combined repertoire statistics
df = pd.read_csv(sample_stats_csv, sep=',', header=None,
names=['sample_id', 'patient_id', 'timepoint', 'origin',
'num_clones', 'num_TCRs', 'simpson_index', 'simpson_index_corrected', 'clonality',
'num_prod', 'num_nonprod', 'pct_prod', 'pct_nonprod',
'productive_cdr3_avg_len', 'num_convergent', 'ratio_convergent'])
df = pd.read_csv(sample_stats_csv, sep=',')

Copilot AI Oct 14, 2025

Previously the file was read with header=None and an explicit column name list; removing that implies the CSV now contains a header row, but upstream changes do not show evidence that a header has been added. If the file still lacks a header, the first data row will be misinterpreted as column names—retain header=None with explicit names or confirm the writer now emits headers.

Suggested change
df = pd.read_csv(sample_stats_csv, sep=',')
# Replace the column names below with the actual column names for sample_stats_csv
df = pd.read_csv(sample_stats_csv, sep=',', header=None, names=['col1', 'col2', 'col3', 'col4'])

Collaborator

@dimalvovs dimalvovs Oct 20, 2025

These headers are now present in the sample_stats.csv (output of the minimal test run), so I think this is not an issue.

sample,subject_id,timepoint,origin,num_clones,num_TCRs,simpson_index,simpson_index_corrected,clonality,num_prod,num_nonprod,pct_prod,pct_nonprod,productive_cdr3_avg_len,num_convergent,ratio_convergent

# print('-- Imported sample_stats_csv as `df`...')

## reading sample metadata
meta = pd.read_csv(sample_table, sep=',', header=None, index_col=None,
names=['sample_id', 'file', 'patient_id', 'timepoint', 'origin'])
names=['sample', 'file', 'subject_id', 'timepoint', 'origin'])
# print('-- Imported sample_table as `meta`...')

## reading V gene family usage
v_family = pd.read_csv(v_family_csv, sep=',', header=None, index_col=None,
names=['patient_id', 'timepoint', 'origin', 'TCRBV01',
'TCRBV02', 'TCRBV03', 'TCRBV04', 'TCRBV05', 'TCRBV06',
'TCRBV07', 'TCRBV08', 'TCRBV09', 'TCRBV10', 'TCRBV11',
'TCRBV12', 'TCRBV13', 'TCRBV14', 'TCRBV15', 'TCRBV16',
'TCRBV17', 'TCRBV18', 'TCRBV19', 'TCRBV20', 'TCRBV21',
'TCRBV22', 'TCRBV23', 'TCRBV24', 'TCRBV25', 'TCRBV26',
'TCRBV27', 'TCRBV28', 'TCRBV29', 'TCRBV30'])
v_family = v_family.sort_values(by=['patient_id', 'timepoint'])
v_family = pd.read_csv(v_family_csv, sep=',')

Copilot AI Oct 14, 2025

The earlier version explicitly supplied the V gene column names with header=None; switching to default header handling assumes the aggregated v_family CSV now has a header row. If it does not, sorting by 'subject_id' will raise a KeyError because that column name will instead be a data value—either restore header=None with explicit names or ensure headers are written upstream.

Suggested change
v_family = pd.read_csv(v_family_csv, sep=',')
v_family = pd.read_csv(v_family_csv, sep=',', header=None, index_col=None,
names=['sample', 'file', 'subject_id', 'timepoint', 'origin', 'v_family', 'proportion'])

Collaborator

The individual stats tables contain the columns that were present for each sample, so I think this is not an issue but rather a feature - column names are not hardcoded anymore.
Example:

subject_id,timepoint,origin,TRBD1,TRBD2

v_family = v_family.sort_values(by=['subject_id', 'timepoint'])
```

# Sample level statistics {#sec-sample-level-stats}
@@ -106,7 +95,7 @@ fig = px.box(df,
x = 'timepoint',
y='num_clones',
color='origin',
points='all', hover_data=['sample_id'],
points='all', hover_data=['sample'],
category_orders={'timepoint': timepts})
fig.show()
```
@@ -120,7 +109,7 @@ fig = px.box(df,
x = 'timepoint',
y='clonality',
color='origin',
points='all', hover_data=['sample_id'],
points='all', hover_data=['sample'],
category_orders={'timepoint': timepts})
fig.show()
```
@@ -134,7 +123,7 @@ fig = px.box(df,
x = 'timepoint',
y='simpson_index_corrected',
color='origin',
points='all', hover_data=['sample_id'],
points='all', hover_data=['sample'],
category_orders={'timepoint': timepts})
fig.show()
```
@@ -152,7 +141,7 @@ fig = px.box(df,
x = 'timepoint',
y='pct_prod',
color='origin',
points='all', hover_data=['sample_id'],
points='all', hover_data=['sample'],
category_orders={'timepoint': timepts})
fig.show()
```
@@ -170,7 +159,7 @@ fig = px.box(df,
x = 'timepoint',
y='productive_cdr3_avg_len',
color='origin',
points='all', hover_data=['sample_id'],
points='all', hover_data=['sample'],
category_orders={'timepoint': timepts})
fig.show()
```
@@ -184,7 +173,7 @@ fig = px.box(df,
x = 'timepoint',
y='ratio_convergent',
color='origin',
points='all', hover_data=['sample_id'],
points='all', hover_data=['sample'],
category_orders={'timepoint': timepts})
fig.show()
```
@@ -210,16 +199,16 @@ where $N_{k}$ is the number of TCRs that use the $k$ th V gene, and T is the tot
colors = ["#fafa70","#fdef6b","#ffe566","#ffda63","#ffd061","#ffc660","#ffbb5f","#fdb15f","#fba860","#f79e61","#f39562","#ef8c63","#e98365","#e37b66","#dd7367","#d66b68","#ce6469","#c65e6a","#bd576b","#b4526b","#ab4c6b","#a1476a","#974369","#8c3e68","#823a66","#773764","#6d3361","#62305e","#572c5a","#4d2956"]

## calculate proportions and add to v_family_long
v_family_long = pd.melt(v_family, id_vars=['patient_id', 'timepoint', 'origin'], value_vars=v_family.columns[3:], var_name='v_gene', value_name='count')
v_family_long['proportion'] = v_family_long.groupby(['patient_id', 'timepoint', 'origin'])['count'].transform(lambda x: x / x.sum())
v_family_long = pd.melt(v_family, id_vars=['subject_id', 'timepoint', 'origin'], value_vars=v_family.columns[3:], var_name='v_gene', value_name='count')
v_family_long['proportion'] = v_family_long.groupby(['subject_id', 'timepoint', 'origin'])['count'].transform(lambda x: x / x.sum())

## add in the total number of v genes for each sample
total_v_genes = v_family_long.groupby(['patient_id', 'timepoint', 'origin'])['count'].sum().reset_index()
total_v_genes.columns = ['patient_id', 'timepoint', 'origin', 'total_v_genes']
v_family_long = pd.merge(v_family_long, total_v_genes, on=['patient_id', 'timepoint', 'origin'])
total_v_genes = v_family_long.groupby(['subject_id', 'timepoint', 'origin'])['count'].sum().reset_index()
total_v_genes.columns = ['subject_id', 'timepoint', 'origin', 'total_v_genes']
v_family_long = pd.merge(v_family_long, total_v_genes, on=['subject_id', 'timepoint', 'origin'])

for patient in v_family_long.patient_id.unique().tolist():
current = v_family_long[v_family_long.patient_id == patient]
for patient in v_family_long.subject_id.unique().tolist():
current = v_family_long[v_family_long.subject_id == patient]
fig = go.Figure()
fig.update_layout(
template="simple_white",