Skip to content

Commit 7d32a9e

Browse files
Add ConTextTab model (#1528)
Adding [ConTextTab model](https://huggingface.co/sap-ai-research/contexttab) to library and snippets. The checkpoints are linked to `https://github.com/SAP-samples/contexttab` and a custom download stats filter is added. --------- Co-authored-by: Lucain <[email protected]> Co-authored-by: Lucain <[email protected]>
1 parent 4732260 commit 7d32a9e

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

packages/tasks/src/model-libraries-snippets.ts

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,62 @@ wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH)
132132
ta.save("test-2.wav", wav, model.sr)`,
133133
];
134134

135+
export const contexttab = (): string[] => {
136+
const installSnippet = `pip install git+https://github.com/SAP-samples/contexttab`;
137+
138+
const classificationSnippet = `# Run a classification task
139+
from sklearn.datasets import load_breast_cancer
140+
from sklearn.metrics import accuracy_score
141+
from sklearn.model_selection import train_test_split
142+
143+
from contexttab import ConTextTabClassifier
144+
145+
# Load sample data
146+
X, y = load_breast_cancer(return_X_y=True)
147+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
148+
149+
# Initialize a classifier
150+
# You can omit checkpoint and checkpoint_revision to use the default model
151+
clf = ConTextTabClassifier(checkpoint="l2/base.pt", checkpoint_revision="v1.0.0", bagging=1, max_context_size=2048)
152+
153+
clf.fit(X_train, y_train)
154+
155+
# Predict probabilities
156+
prediction_probabilities = clf.predict_proba(X_test)
157+
# Predict labels
158+
predictions = clf.predict(X_test)
159+
print("Accuracy", accuracy_score(y_test, predictions))`;
160+
161+
const regressionsSnippet = `# Run a regression task
162+
from sklearn.datasets import fetch_openml
163+
from sklearn.metrics import r2_score
164+
from sklearn.model_selection import train_test_split
165+
166+
from contexttab import ConTextTabRegressor
167+
168+
169+
# Load sample data
170+
df = fetch_openml(data_id=531, as_frame=True)
171+
X = df.data
172+
y = df.target.astype(float)
173+
174+
# Train-test split
175+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
176+
177+
# Initialize the regressor
178+
# You can omit checkpoint and checkpoint_revision to use the default model
179+
regressor = ConTextTabRegressor(checkpoint="l2/base.pt", checkpoint_revision="v1.0.0", bagging=1, max_context_size=2048)
180+
181+
regressor.fit(X_train, y_train)
182+
183+
# Predict on the test set
184+
predictions = regressor.predict(X_test)
185+
186+
r2 = r2_score(y_test, predictions)
187+
print("R² Score:", r2)`;
188+
return [installSnippet, classificationSnippet, regressionsSnippet];
189+
};
190+
135191
export const cxr_foundation = (): string[] => [
136192
`# pip install git+https://github.com/Google-Health/cxr-foundation.git#subdirectory=python
137193

packages/tasks/src/model-libraries.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,13 @@ export const MODEL_LIBRARIES_UI_ELEMENTS = {
208208
repoUrl: "https://github.com/Unbabel/COMET/",
209209
countDownloads: `path:"hparams.yaml"`,
210210
},
211+
contexttab: {
212+
prettyLabel: "ConTextTab",
213+
repoName: "ConTextTab",
214+
repoUrl: "https://github.com/SAP-samples/contexttab",
215+
countDownloads: `path_extension:"pt"`,
216+
snippets: snippets.contexttab,
217+
},
211218
cosmos: {
212219
prettyLabel: "Cosmos",
213220
repoName: "Cosmos",

0 commit comments

Comments
 (0)