Skip to content

Commit ea43c83

Browse files
andrewfulton9 authored and krassowski committed
Adds unix shell-style wildcard matching to /learn (jupyterlab#989)
* adds wildcard matching to /learn
* Add documentation
* improve docs
* cleanup
* adds wildcard matching to /learn
* Add documentation
* improve docs
* Update docs/source/users/index.md

Co-authored-by: Michał Krassowski <[email protected]>

* update for test
* improve test
* improve directory handling
* remove dir only logic

---------

Co-authored-by: Michał Krassowski <[email protected]>
1 parent f7be16e commit ea43c83

File tree

5 files changed

+111
-16
lines changed

5 files changed

+111
-16
lines changed

docs/source/users/index.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,13 @@ To teach Jupyter AI about a folder full of documentation, for example, run `/lea
499499
alt='Screen shot of "/learn docs/" command and a response.'
500500
class="screenshot" />
501501

502+
The `/learn` command also supports unix shell-style wildcard matching. This allows fine-grained file selection for learning. For example, to learn on only notebooks in all directories you can use `/learn **/*.ipynb` and all notebooks within your base (or preferred directory if set) will be indexed, while all other file extensions will be ignored.
503+
504+
:::{warning}
505+
:name: unix shell-style wildcard matching
506+
Certain patterns may cause `/learn` to run more slowly. For instance `/learn **` may cause directories to be walked multiple times in search of files.
507+
:::
508+
502509
You can then use `/ask` to ask a question specifically about the data that you taught Jupyter AI with `/learn`.
503510

504511
<img src="../_static/chat-ask-command.png"

packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import json
33
import os
4+
from glob import iglob
45
from typing import Any, Coroutine, List, Optional, Tuple
56

67
from dask.distributed import Client as DaskClient
@@ -180,9 +181,13 @@ async def process_message(self, message: HumanChatMessage):
180181
short_path = args.path[0]
181182
load_path = os.path.join(self.output_dir, short_path)
182183
if not os.path.exists(load_path):
183-
response = f"Sorry, that path doesn't exist: {load_path}"
184-
self.reply(response, message)
185-
return
184+
try:
185+
# check if globbing the load path will return anything
186+
next(iglob(load_path))
187+
except StopIteration:
188+
response = f"Sorry, that path doesn't exist: {load_path}"
189+
self.reply(response, message)
190+
return
186191

187192
# delete and relearn index if embedding model was changed
188193
await self.delete_and_relearn()
@@ -193,11 +198,16 @@ async def process_message(self, message: HumanChatMessage):
193198
load_path, args.chunk_size, args.chunk_overlap, args.all_files
194199
)
195200
except Exception as e:
196-
response = f"""Learn documents in **{load_path}** failed. {str(e)}."""
201+
response = """Learn documents in **{}** failed. {}.""".format(
202+
load_path.replace("*", r"\*"),
203+
str(e),
204+
)
197205
else:
198206
self.save()
199-
response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them.
200-
You can ask questions about these docs by prefixing your message with **/ask**."""
207+
response = """🎉 I have learned documents at **%s** and I am ready to answer questions about them.
208+
You can ask questions about these docs by prefixing your message with **/ask**.""" % (
209+
load_path.replace("*", r"\*")
210+
)
201211
self.reply(response, message)
202212

203213
def _build_list_response(self):

packages/jupyter-ai/jupyter_ai/document_loaders/directory.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import tarfile
55
from datetime import datetime
6+
from glob import iglob
67
from pathlib import Path
78
from typing import List
89

@@ -109,6 +110,18 @@ def flatten(*chunk_lists):
109110
return list(itertools.chain(*chunk_lists))
110111

111112

113+
def walk_directory(directory, all_files):
    """Recursively collect file paths under *directory*.

    Unless *all_files* is true, hidden files, hidden directories, and
    directories listed in EXCLUDE_DIRS are pruned from the walk.
    """
    collected = []
    for root, dirnames, names in os.walk(directory):
        if not all_files:
            # Prune in place so os.walk never descends into excluded subtrees.
            dirnames[:] = [
                d for d in dirnames if not d.startswith(".") and d not in EXCLUDE_DIRS
            ]
            names = [n for n in names if not n.startswith(".")]
        collected.extend(Path(root) / n for n in names)
    return collected
123+
124+
112125
def collect_filepaths(path, all_files: bool):
113126
"""Selects eligible files, i.e.,
114127
1. Files not in excluded directories, and
@@ -119,17 +132,13 @@ def collect_filepaths(path, all_files: bool):
119132
# Check if the path points to a single file
120133
if os.path.isfile(path):
121134
filepaths = [Path(path)]
135+
elif os.path.isdir(path):
136+
filepaths = walk_directory(path, all_files)
122137
else:
123138
filepaths = []
124-
for dir, subdirs, filenames in os.walk(path):
125-
# Filter out hidden filenames, hidden directories, and excluded directories,
126-
# unless "all files" are requested
127-
if not all_files:
128-
subdirs[:] = [
129-
d for d in subdirs if not (d[0] == "." or d in EXCLUDE_DIRS)
130-
]
131-
filenames = [f for f in filenames if not f[0] == "."]
132-
filepaths.extend([Path(dir) / filename for filename in filenames])
139+
for glob_path in iglob(str(path), include_hidden=all_files, recursive=True):
140+
if os.path.isfile(glob_path):
141+
filepaths.append(Path(glob_path))
133142
valid_exts = {j.lower() for j in SUPPORTED_EXTS}
134143
filepaths = [fp for fp in filepaths if fp.suffix.lower() in valid_exts]
135144
return filepaths
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"id": "85f55790-78a3-4fd2-bd0f-bf596e28a65c",
7+
"metadata": {},
8+
"outputs": [
9+
{
10+
"name": "stdout",
11+
"output_type": "stream",
12+
"text": [
13+
"hello world\n"
14+
]
15+
}
16+
],
17+
"source": [
18+
"print(\"hello world\")"
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": null,
24+
"id": "367c03ce-503f-4a2a-9221-c4fcd49b34c5",
25+
"metadata": {},
26+
"outputs": [],
27+
"source": []
28+
}
29+
],
30+
"metadata": {
31+
"kernelspec": {
32+
"display_name": "Python 3 (ipykernel)",
33+
"language": "python",
34+
"name": "python3"
35+
},
36+
"language_info": {
37+
"codemirror_mode": {
38+
"name": "ipython",
39+
"version": 3
40+
},
41+
"file_extension": ".py",
42+
"mimetype": "text/x-python",
43+
"name": "python",
44+
"nbconvert_exporter": "python",
45+
"pygments_lexer": "ipython3",
46+
"version": "3.11.0"
47+
}
48+
},
49+
"nbformat": 4,
50+
"nbformat_minor": 5
51+
}

packages/jupyter-ai/jupyter_ai/tests/test_directory.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def staging_dir(static_test_files_dir, jp_ai_staging_dir) -> Path:
1717
file6_path = static_test_files_dir / "file3.csv"
1818
file7_path = static_test_files_dir / "file3.xyz"
1919
file8_path = static_test_files_dir / "file4.pdf"
20+
file9_path = static_test_files_dir / "file9.ipynb"
2021

2122
job_staging_dir = jp_ai_staging_dir / "TestDir"
2223
job_staging_dir.mkdir()
@@ -33,6 +34,7 @@ def staging_dir(static_test_files_dir, jp_ai_staging_dir) -> Path:
3334
shutil.copy2(file6_path, job_staging_hiddendir)
3435
shutil.copy2(file7_path, job_staging_subdir)
3536
shutil.copy2(file8_path, job_staging_hiddendir)
37+
shutil.copy2(file9_path, job_staging_subdir)
3638

3739
return job_staging_dir
3840

@@ -49,8 +51,24 @@ def test_collect_filepaths(staging_dir):
4951
# Call the function we want to test
5052
result = collect_filepaths(staging_dir_filepath, all_files)
5153

52-
assert len(result) == 3 # Test number of valid files
54+
assert len(result) == 4 # Test number of valid files
5355

5456
filenames = [fp.name for fp in result]
5557
assert "file0.html" in filenames # Check that valid file is included
5658
assert "file3.xyz" not in filenames # Check that invalid file is excluded
59+
60+
# test unix wildcard pattern
61+
pattern_path = os.path.join(staging_dir_filepath, "**/*.*py*")
62+
results = collect_filepaths(pattern_path, all_files)
63+
assert len(results) == 2
64+
condition = lambda p: p.suffix in [".py", ".ipynb"]
65+
assert all(map(condition, results))
66+
67+
# test unix wildcard pattern returning only directories
68+
pattern_path = f"{str(staging_dir_filepath)}*/"
69+
results = collect_filepaths(pattern_path, all_files)
70+
assert len(result) == 4
71+
filenames = [fp.name for fp in result]
72+
73+
assert "file0.html" in filenames # Check that valid file is included
74+
assert "file3.xyz" not in filenames # Check that invalid file is excluded

0 commit comments

Comments
 (0)