Merged
Changes from 6 commits
1 change: 1 addition & 0 deletions nucliadb/src/nucliadb/common/datamanagers/atomic.py
@@ -88,6 +88,7 @@ class resources:
get_resource_uuid_from_slug = ro_txn_wrap(resources_dm.get_resource_uuid_from_slug)
resource_exists = ro_txn_wrap(resources_dm.resource_exists)
slug_exists = ro_txn_wrap(resources_dm.slug_exists)
get_all_field_ids = ro_txn_wrap(resources_dm.get_all_field_ids)


class labelset:
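For context, the new atomic wrapper follows the existing ro_txn_wrap pattern: it is awaited directly and the read-only transaction is handled inside the wrapper. A minimal sketch of the call, mirroring its use in prompt.py below (kbid and resource_uuid are placeholders):

from nucliadb.common import datamanagers

# Read-only lookup of every field id on a resource; must run inside an async function.
all_field_ids = await datamanagers.atomic.resources.get_all_field_ids(
    kbid=kbid, rid=resource_uuid, for_update=False
)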
51 changes: 39 additions & 12 deletions nucliadb/src/nucliadb/search/search/chat/prompt.py
Expand Up @@ -26,6 +26,7 @@
import yaml
from pydantic import BaseModel

from nucliadb.common import datamanagers
from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB, FieldId, ParagraphId
from nucliadb.common.maindb.utils import get_driver
from nucliadb.common.models_utils import from_proto
@@ -589,18 +590,7 @@ async def field_extension_prompt_context(
if resource_uuid not in ordered_resources:
ordered_resources.append(resource_uuid)

# Fetch the extracted texts of the specified fields for each resource
extend_fields = strategy.fields
extend_field_ids = []
for resource_uuid in ordered_resources:
for field_id in extend_fields:
try:
fid = FieldId.from_string(f"{resource_uuid}/{field_id.strip('/')}")
extend_field_ids.append(fid)
except ValueError: # pragma: no cover
# Invalid field id, skipping
continue

extend_field_ids = await get_matching_field_ids(kbid, ordered_resources, strategy)
tasks = [hydrate_field_text(kbid, fid) for fid in extend_field_ids]
field_extracted_texts = await run_concurrently(tasks)

@@ -630,6 +620,43 @@ async def field_extension_prompt_context(
context[paragraph.id] = _clean_paragraph_text(paragraph)


async def get_matching_field_ids(
kbid: str, ordered_resources: list[str], strategy: FieldExtensionStrategy
) -> list[FieldId]:
extend_field_ids: list[FieldId] = []
# Fetch the extracted texts of the specified fields for each resource
for resource_uuid in ordered_resources:
for field_id in strategy.fields:
try:
fid = FieldId.from_string(f"{resource_uuid}/{field_id.strip('/')}")
extend_field_ids.append(fid)
except ValueError: # pragma: no cover
# Invalid field id, skipping
continue
if len(strategy.data_augmentation_field_prefixes) > 0:
for resource_uuid in ordered_resources:
all_field_ids = await datamanagers.atomic.resources.get_all_field_ids(
kbid=kbid, rid=resource_uuid, for_update=False
)
if all_field_ids is None:
continue
for fieldid in all_field_ids.fields:
# Generated fields are always text fields starting with "da-"
if any(
(
fieldid.field_type == resources_pb2.FieldType.TEXT
and fieldid.field.startswith(f"da-{prefix}-")
)
for prefix in strategy.data_augmentation_field_prefixes
):
extend_field_ids.append(
FieldId.from_pb(
rid=resource_uuid, field_type=fieldid.field_type, key=fieldid.field
)
)
return extend_field_ids


async def get_orm_field(kbid: str, field_id: FieldId) -> Optional[Field]:
resource = await cache.get_resource(kbid, field_id.rid)
if resource is None: # pragma: no cover
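As a reading aid, a minimal, self-contained sketch of the prefix check that get_matching_field_ids applies to generated fields; the function and argument names here are illustrative, not part of the PR:

def matches_da_prefix(is_text_field: bool, field_key: str, prefixes: list[str]) -> bool:
    # A generated field qualifies when it is a text field whose key starts with "da-{prefix}-".
    return is_text_field and any(field_key.startswith(f"da-{prefix}-") for prefix in prefixes)


assert matches_da_prefix(True, "da-simpson-augmented-field", ["simpson"])
assert not matches_da_prefix(True, "da-homer-augmented-field", ["simpson"])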
44 changes: 41 additions & 3 deletions nucliadb/tests/nucliadb/integration/test_ask.py
@@ -448,10 +448,36 @@ async def test_ask_full_resource_rag_strategy_with_exclude(

@pytest.mark.deploy_modes("standalone")
async def test_ask_rag_options_extend_with_fields(
nucliadb_reader: AsyncClient, standalone_knowledgebox: str, resources
nucliadb_ingest_grpc: WriterStub,
nucliadb_writer: AsyncClient,
nucliadb_reader: AsyncClient,
standalone_knowledgebox: str,
resources,
):
resource1, resource2 = resources

# Create a 'fake' data augmentation field
resp = await nucliadb_writer.post(
f"/kb/{standalone_knowledgebox}/resources",
json={
"title": "The title DA",
"texts": {
"da-simpson-augmented-field": {
"body": "This is a data augmentation field content",
}
},
},
)
assert resp.status_code == 201, resp.text
rid = resp.json()["uuid"]
bmb = BrokerMessageBuilder(
kbid=standalone_knowledgebox, rid=rid, source=wpb2.BrokerMessage.MessageSource.PROCESSOR
)
bmb_fb = bmb.field_builder("da-simpson-augmented-field", field_type=rpb2.FieldType.TEXT)
bmb_fb.with_extracted_text("This is a data augmentation field content")
bm = bmb.build()
await inject_message(nucliadb_ingest_grpc, bm)

predict = get_predict()
predict.calls.clear() # type: ignore

@@ -460,7 +486,13 @@ async def test_ask_rag_options_extend_with_fields(
json={
"query": "title",
"features": ["keyword", "semantic", "relations"],
"rag_strategies": [{"name": "field_extension", "fields": ["a/summary"]}],
"rag_strategies": [
{
"name": "field_extension",
"fields": ["a/summary"],
"data_augmentation_field_prefixes": ["simpson"],
}
],
},
)
assert resp.status_code == 200, resp.text
@@ -472,13 +504,19 @@ async def test_ask_rag_options_extend_with_fields(

# Matching paragraphs should be in the prompt
# context, plus the extended field for each resource
assert len(prompt_context) == 4
assert len(prompt_context) == 6
# The matching paragraphs
assert prompt_context[f"{resource1}/a/title/0-11"] == "The title 0"
assert prompt_context[f"{resource2}/a/title/0-11"] == "The title 1"
assert prompt_context[f"{rid}/a/title/0-12"] == "The title DA"

# The extended fields
assert prompt_context[f"{resource1}/a/summary"] == "The summary 0"
assert prompt_context[f"{resource2}/a/summary"] == "The summary 1"
assert (
prompt_context[f"{rid}/t/da-simpson-augmented-field"]
== "This is a data augmentation field content"
)


@pytest.mark.parametrize(
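For reference, the rag_strategies fragment exercised by this test; explicit field ids and data-augmentation prefixes can be combined in the same entry (a usage sketch, not an additional test case):

rag_strategies = [
    {
        "name": "field_extension",
        # Explicit fields, in "{field_type}/{field_name}" format.
        "fields": ["a/summary"],
        # Also pull in text fields whose key starts with "da-simpson-".
        "data_augmentation_field_prefixes": ["simpson"],
    }
]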
20 changes: 11 additions & 9 deletions nucliadb_models/src/nucliadb_models/search.py
@@ -1171,24 +1171,27 @@ def set_discriminator(self) -> Self:
"t": "text",
"f": "file",
"u": "link",
"d": "datetime",
"c": "conversation",
"a": "generic",
}


class FieldExtensionStrategy(RagStrategy):
name: Literal["field_extension"] = "field_extension"
fields: list[str] = Field(
default=[],
title="Fields",
description="List of field ids to extend the context with. It will try to extend the retrieval context with the specified fields in the matching resources. The field ids have to be in the format `{field_type}/{field_name}`, like 'a/title', 'a/summary' for title and summary fields or 't/amend' for a text field named 'amend'.", # noqa: E501
min_length=1,
description="List of field ids to extend the context with. It will try to extend the retrieval context with the specified fields in the matching resources. The field ids have to be in the format `{field_type}/{field_name}`, like 'a/title', 'a/summary' for title and summary fields or 't/amend' for a text field named 'amend'.",
)
data_augmentation_field_prefixes: list[str] = Field(
default=[],
description="List of prefixes for data augmentation added fields to extend the context with. For example, if the prefix is 'simpson', all fields that are a result of data augmentation with that prefix will be used to extend the context.",
)

@field_validator("fields", mode="after")
@classmethod
def fields_validator(cls, fields) -> Self:
@model_validator(mode="after")
def field_extension_strategy_validator(self) -> Self:
# Check that the fields are in the format {field_type}/{field_name}
for field in fields:
for field in self.fields:
try:
field_type, _ = field.strip("/").split("/")
except ValueError:
@@ -1201,8 +1204,7 @@ def fields_validator(cls, fields) -> Self:
f"Field '{field}' does not have a valid field type. "
f"Valid field types are: {allowed_field_types_part}."
)

return fields
return self


class FullResourceApplyTo(BaseModel):
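A hedged construction sketch of the updated model; the import path is inferred from this file's location and the values mirror the test above:

from nucliadb_models.search import FieldExtensionStrategy

strategy = FieldExtensionStrategy(
    fields=["a/summary"],
    data_augmentation_field_prefixes=["simpson"],
)
# The model validator checks that each entry in `fields` is "{field_type}/{field_name}"
# with a known field type; invalid entries are rejected when the model is validated.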