
Commit 75c3872

alinakbase authored and ialarmedalien committed
Add Delta output and Spark optimization for GO annotation parser
1 parent 86a2b3b commit 75c3872

File tree

2 files changed (+485 -0 lines changed)


src/parsers/association_update.py

Lines changed: 292 additions & 0 deletions
@@ -0,0 +1,292 @@
"""
PySpark-based normalization pipeline for GO Gene Association Files (GAF).

This script processes a raw GAF-like annotation CSV (e.g., `annotations_data100.csv`)
and produces a normalized output with a schema consistent with the previous
Pandas-based implementation.

The path given by --output is always written as a Delta table, e.g.:

    python3 association_update.py \
        --input annotations_data100.csv \
        --output ./delta_output

Result:
    - A Parquet-backed Delta table containing normalized annotations.
    - Schema conforms to the CDM-style structured annotation model.
"""


import os
import sys
import urllib.request
import logging

import click

from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, split, trim, when, upper, explode, lit, regexp_replace,
    to_date, concat, concat_ws
)
from pyspark.sql.types import StringType
from delta import configure_spark_with_delta_pip


# ---------------------- Logging Setup ----------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Constants ---
SUBJECT = "subject"
PREDICATE = "predicate"
OBJECT = "object"
PUBLICATIONS = "publications"
EVIDENCE_CODE = "Evidence_Code"
SUPPORTING_OBJECTS = "supporting_objects"
ANNOTATION_DATE = "annotation_date"
PRIMARY_KNOWLEDGE_SOURCE = "primary_knowledge_source"
AGGREGATOR = "aggregator"
PROTOCOL_ID = "protocol_id"
NEGATED = "negated"
EVIDENCE_TYPE = "evidence_type"

# GAF Field Names
DB = "DB"
DB_OBJ_ID = "DB_Object_ID"
QUALIFIER = "Qualifier"
GO_ID = "GO_ID"
DB_REF = "DB_Reference"
WITH_FROM = "With_From"
DATE = "Date"
ASSIGNED_BY = "Assigned_By"

# ECO Mapping
ECO_MAPPING_URL = "http://purl.obolibrary.org/obo/eco/gaf-eco-mapping.txt"

ALLOWED_PREDICATES = [
    "enables", "contributes_to", "acts_upstream_of_or_within", "involved_in",
    "acts_upstream_of", "acts_upstream_of_positive_effect", "acts_upstream_of_negative_effect",
    "acts_upstream_of_or_within_negative_effect", "acts_upstream_of_or_within_positive_effect",
    "located_in", "part_of", "is_active_in", "colocalizes_with"
]


def get_spark():
    """Initialize and return a Spark session configured for Delta Lake."""
    builder = (
        SparkSession.builder
        .appName("GO-GAF-Spark-Parser")
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
        .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
        .config("spark.sql.shuffle.partitions", "200")
    )
    return configure_spark_with_delta_pip(builder).getOrCreate()
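
# The two `spark.sql.*` Delta settings above are the standard Delta Lake session configuration
# used together with configure_spark_with_delta_pip; "spark.sql.shuffle.partitions" is set
# explicitly so the shuffle width of the explode/join stages below is predictable
# (200 matches Spark's default and can be tuned per cluster).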


def load_annotation(spark, input_path):
    """Load and preprocess raw annotation CSV."""
    df = spark.read.csv(input_path, header=True, inferSchema=True)

    df = df.select(DB, DB_OBJ_ID, QUALIFIER, GO_ID, DB_REF, EVIDENCE_CODE,
                   WITH_FROM, DATE, ASSIGNED_BY)

    df = df.withColumn(PREDICATE, col(QUALIFIER)) \
        .withColumn(OBJECT, col(GO_ID)) \
        .withColumn(PUBLICATIONS, split(trim(when(col(DB_REF).isNotNull(), col(DB_REF)).otherwise(lit(""))), "\\|")) \
        .withColumn(SUPPORTING_OBJECTS, split(trim(col(WITH_FROM)), "\\|")) \
        .withColumn(ANNOTATION_DATE, col(DATE)) \
        .withColumn(PRIMARY_KNOWLEDGE_SOURCE, col(ASSIGNED_BY))

    return df
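
# Illustrative example (made-up row, not taken from the input data): a DB_Reference value of
# "PMID:12345|GO_REF:0000024" is split on "|" into the publications array
# ["PMID:12345", "GO_REF:0000024"]; With_From is split the same way into supporting_objects.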


def normalize_dates(df):
    """Normalize annotation dates to yyyy-MM-dd format if 8-digit string."""
    df = df.withColumn(
        ANNOTATION_DATE,
        when(col(ANNOTATION_DATE).rlike("^[0-9]{8}$"),
             to_date(col(ANNOTATION_DATE), "yyyyMMdd"))
    )
    return df
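
# Note: the `when` above has no `otherwise`, so any Date value that is not an 8-digit string
# (for example an empty field) ends up as NULL in annotation_date rather than raising an error.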


def process_predicates(df):
    """Validate and clean predicate values (e.g., remove NOT| prefix)."""
    df = df.withColumn(NEGATED, col(PREDICATE).startswith("NOT|")) \
        .withColumn(PREDICATE, regexp_replace(col(PREDICATE), "^NOT\\|", ""))

    invalid = df.filter(~col(PREDICATE).isin(ALLOWED_PREDICATES))
    if invalid.count() > 0:
        invalid_values = [r[PREDICATE] for r in invalid.select(PREDICATE).distinct().collect()]
        raise ValueError(f"Invalid predicate found {invalid_values}")
    return df
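
# Illustrative example: a Qualifier of "NOT|involved_in" sets negated = true and the predicate
# becomes "involved_in"; any predicate outside ALLOWED_PREDICATES aborts the run with ValueError.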


def add_metadata(df):
    """Add aggregator, protocol ID, and subject CURIE."""
    return (
        df.withColumn(AGGREGATOR, lit("UniProt"))
        .withColumn(PROTOCOL_ID, lit(None).cast(StringType()))
        .withColumn(SUBJECT, concat(col(DB).cast("string"), lit(":"), col(DB_OBJ_ID).cast("string")))
    )
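
# Illustrative example (made-up values): DB = "UniProtKB" and DB_Object_ID = "P12345" yield the
# subject "UniProtKB:P12345"; protocol_id is left NULL and the aggregator is fixed to "UniProt".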


def load_eco_mapping(spark, local_path="gaf-eco-mapping.txt"):
    """Download and load the ECO evidence mapping table."""

    if not os.path.exists(local_path):
        logger.info(f"Downloading ECO mapping file to: {local_path}")
        urllib.request.urlretrieve(ECO_MAPPING_URL, local_path)

    df = spark.read.csv(local_path, sep="\t", comment="#", header=False)
    return df.toDF(EVIDENCE_CODE, DB_REF, EVIDENCE_TYPE)
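
# The gaf-eco-mapping.txt file is expected to contain three tab-separated columns
# (GAF evidence code, reference or a default marker, ECO class ID), loaded here as
# Evidence_Code, DB_Reference and evidence_type respectively.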


def merge_evidence(df, eco):
    """Join annotation DataFrame with ECO evidence mapping."""
    df = (
        df.withColumn(PUBLICATIONS, explode(col(PUBLICATIONS)))
        .filter(col(PUBLICATIONS).isNotNull() & (col(PUBLICATIONS) != ""))
        .withColumn(PUBLICATIONS, upper(trim(col(PUBLICATIONS))))
    )

    eco = (
        eco.withColumn(DB_REF, upper(trim(col(DB_REF))))
        .withColumn(EVIDENCE_CODE, upper(trim(col(EVIDENCE_CODE))))
    )

    merged = df.alias("df").join(
        eco.alias("eco"),
        on=(
            (col("df." + EVIDENCE_CODE) == col("eco." + EVIDENCE_CODE)) &
            (col("df." + PUBLICATIONS) == col("eco." + DB_REF))
        ),
        how="left"
    ).drop(col("eco." + DB_REF)).drop(col("eco." + EVIDENCE_CODE))

    fallback = (
        eco.filter(col(DB_REF) == "DEFAULT")
        .select(EVIDENCE_CODE, EVIDENCE_TYPE)
        .withColumnRenamed(EVIDENCE_TYPE, "fallback")
    )

    merged = (
        merged.join(fallback, on=EVIDENCE_CODE, how="left")
        .withColumn(EVIDENCE_TYPE,
                    when(col(EVIDENCE_TYPE).isNull(), col("fallback")).otherwise(col(EVIDENCE_TYPE)))
        .drop("fallback")
    )

    return merged
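
# Evidence resolution is two-step: first an exact (Evidence_Code, publication reference) match
# against the ECO mapping, then, for rows still lacking an evidence_type, the code-level
# "DEFAULT" mapping for that Evidence_Code is used as a fallback.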


def reorder_columns(df):
    """Ensure correct column order and clean up types."""
    df = (
        df.withColumn(PUBLICATIONS, concat_ws("|", col(PUBLICATIONS)))
        .withColumn(SUPPORTING_OBJECTS, concat_ws("|", col(SUPPORTING_OBJECTS)))
        .withColumn(SUPPORTING_OBJECTS, when(col(SUPPORTING_OBJECTS) == "", None).otherwise(col(SUPPORTING_OBJECTS)))
        .withColumn(NEGATED, col(NEGATED).cast("boolean").cast("string"))
    )

    final_cols = [
        OBJECT, DB, ANNOTATION_DATE, PREDICATE, EVIDENCE_CODE,
        PUBLICATIONS, DB_OBJ_ID, PRIMARY_KNOWLEDGE_SOURCE,
        SUPPORTING_OBJECTS, AGGREGATOR, PROTOCOL_ID, NEGATED,
        SUBJECT, EVIDENCE_TYPE
    ]
    return df.select([col(c) for c in final_cols])


def write_output(df, output_path, mode="overwrite"):
    """Write the normalized DataFrame to a Delta table at output_path."""
    df.write.format("delta").mode(mode).save(output_path)


def register_table(spark, output_path, table_name="normalized_annotation", permanent=True):
    """Register the Delta output as a permanent table or a temporary view."""
    if permanent:
        logger.info(f"Registering Delta table as permanent table {table_name}")

        spark.sql(f"""
            CREATE TABLE IF NOT EXISTS {table_name}
            USING DELTA
            LOCATION '{output_path}'
        """)

    else:
        logger.info(f"Registering Delta table as temporary view: {table_name}")
        df = spark.read.format("delta").load(output_path)
        df.createOrReplaceTempView(table_name)
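
# Once registered, the table can be queried with Spark SQL, e.g. (illustrative only):
#     spark.sql("SELECT predicate, COUNT(*) FROM normalized_annotation GROUP BY predicate").show()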


def run(input_path, output_path, register=False, table_name="normalized_annotation", permanent=True, dry_run=False, mode="overwrite"):
    """Run the full normalization pipeline and write (or preview) the result."""
    spark = None
    try:
        spark = get_spark()
        logger.info("Starting annotation pipeline")

        eco = load_eco_mapping(spark)
        df = load_annotation(spark, input_path)
        df = normalize_dates(df)
        df = process_predicates(df)
        df = add_metadata(df)
        df = merge_evidence(df, eco)
        df = reorder_columns(df)

        if dry_run:
            logger.info("Showing top 5 rows")
            df.show(5, truncate=False)
        else:
            write_output(df, output_path, mode=mode)
            logger.info(f"Data written to {output_path}")
            if register:
                register_table(spark, output_path, table_name=table_name, permanent=permanent)

    except Exception as e:
        logger.error(f"Pipeline failed: {e}")
        sys.exit(1)
    finally:
        if spark:
            spark.stop()


@click.command()
@click.option("--input", "-i", required=True, help="Path to input CSV file")
@click.option("--output", "-o", required=True, help="Target Delta table output directory")
@click.option("--register", is_flag=True, help="Register the output as a Spark SQL table")
@click.option("--table-name", default="normalized_annotation", help="SQL table name to register")
@click.option("--temp", is_flag=True, help="Register as a temporary view (default is permanent)")
@click.option("--mode", default="overwrite", type=click.Choice(["overwrite", "append", "ignore"]), help="Delta write mode")
@click.option("--dry-run", is_flag=True, help="Dry run without writing output")
def main(input, output, register, table_name, temp, mode, dry_run):
    if not os.path.isfile(input):
        logger.error(f"Input file does not exist: {input}")
        sys.exit(1)
    run(input, output, register=register, table_name=table_name, permanent=not temp, dry_run=dry_run, mode=mode)


if __name__ == "__main__":
    main()
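
# Example dry run against the sample input from the module docstring (no output is written):
#     python3 association_update.py -i annotations_data100.csv -o ./delta_output --dry-run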
