Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 177 additions & 41 deletions internal/engines/entity_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ type EntityFilter struct {
// dataReader is responsible for reading relationship information
dataReader storage.DataReader

schema *base.SchemaDefinition
graph *schema.LinkedSchemaGraph
}

// NewEntityFilter creates a new EntityFilter engine
func NewEntityFilter(dataReader storage.DataReader, sch *base.SchemaDefinition) *EntityFilter {
return &EntityFilter{
dataReader: dataReader,
schema: sch,
graph: schema.NewLinkedGraph(sch),
}
}

Expand Down Expand Up @@ -56,9 +56,8 @@ func (engine *EntityFilter) EntityFilter(
}

// Retrieve linked entrances
cn := schema.NewLinkedGraph(engine.schema) // Create a new linked graph from the schema definition.
var entrances []*schema.LinkedEntrance
entrances, err = cn.LinkedEntrances(
entrances, err = engine.graph.LinkedEntrances(
request.GetEntrance(),
&base.Entrance{
Type: request.GetSubject().GetType(),
Expand Down Expand Up @@ -107,6 +106,11 @@ func (engine *EntityFilter) EntityFilter(
if err != nil {
return err
}
case schema.PathChainLinkedEntrance: // If the linked entrance is a path chain entrance.
err = engine.pathChainEntrance(cont, request, entrance, visits, publisher) // Call the path chain entrance method.
if err != nil {
return err
}
default:
return errors.New("unknown linked entrance type") // Return an error if the linked entrance is of an unknown type.
}
Expand All @@ -123,28 +127,19 @@ func (engine *EntityFilter) attributeEntrance(
visits *VisitsMap, // A map that keeps track of visited entities to avoid infinite loops.
publisher *BulkEntityPublisher, // A custom publisher that publishes results in bulk.
) error { // Returns an error if one occurs during execution.
if request.GetEntrance().GetType() != entrance.TargetEntrance.GetType() {
// attributeEntrance only handles direct attribute access
if !visits.AddEA(entrance.TargetEntrance.GetType(), entrance.TargetEntrance.GetValue()) {
return nil
}

if !visits.AddEA(entrance.TargetEntrance.GetType(), entrance.TargetEntrance.GetValue()) { // If the entity and relation has already been visited.
return nil
}

// Retrieve the scope associated with the target entrance type.
// Check if it exists to avoid accessing a nil map entry.
// Retrieve the scope associated with the target entrance type
scope, exists := request.GetScope()[entrance.TargetEntrance.GetType()]

// Initialize data as an empty slice of strings.
var data []string

// If the scope exists, assign its Data field to the data slice.
if exists {
data = scope.GetData()
}

// Define a TupleFilter. This specifies which tuples we're interested in.
// We want tuples that match the entity type and ID from the request, and have a specific relation.
// Query attributes directly
filter := &base.AttributeFilter{
Entity: &base.EntityFilter{
Type: entrance.TargetEntrance.GetType(),
Expand All @@ -153,52 +148,36 @@ func (engine *EntityFilter) attributeEntrance(
Attributes: []string{entrance.TargetEntrance.GetValue()},
}

var (
cti, rit *database.AttributeIterator
err error
pagination database.CursorPagination
)
pagination := database.NewCursorPagination(database.Cursor(request.GetCursor()), database.Sort("entity_id"))

pagination = database.NewCursorPagination(database.Cursor(request.GetCursor()), database.Sort("entity_id"))

// Query the relationships using the specified pagination settings.
// The context tuples are filtered based on the provided filter.
cti, err = storageContext.NewContextualAttributes(request.GetContext().GetAttributes()...).QueryAttributes(filter, pagination)
cti, err := storageContext.NewContextualAttributes(request.GetContext().GetAttributes()...).QueryAttributes(filter, pagination)
if err != nil {
return err
}

// Query the relationships for the entity in the request.
// The results are filtered based on the provided filter and pagination settings.
rit, err = engine.dataReader.QueryAttributes(ctx, request.GetTenantId(), filter, request.GetMetadata().GetSnapToken(), pagination)
rit, err := engine.dataReader.QueryAttributes(ctx, request.GetTenantId(), filter, request.GetMetadata().GetSnapToken(), pagination)
if err != nil {
return err
}

// Create a new UniqueTupleIterator from the two TupleIterators.
// NewUniqueTupleIterator() ensures that the iterator only returns unique tuples.
it := database.NewUniqueAttributeIterator(rit, cti)

// Iterate over the relationships.
// Publish entities directly for regular case
for it.HasNext() {
// Get the next attribute's entity.
current, ok := it.GetNext()
if !ok {
break
}

// Extract the entity details.
entity := &base.Entity{
Type: entrance.TargetEntrance.GetType(), // Example: using the type from a previous variable 'entrance'
Type: entrance.TargetEntrance.GetType(),
Id: current.GetEntity().GetId(),
}

// Check if the entity has already been visited to prevent processing it again.
if !visits.AddPublished(entity) {
continue // Skip this entity if it has already been visited.
continue
}

// Publish the entity with its metadata.
publisher.Publish(entity, &base.PermissionCheckRequestMetadata{
SnapToken: request.GetMetadata().GetSnapToken(),
SchemaVersion: request.GetMetadata().GetSchemaVersion(),
Expand Down Expand Up @@ -407,9 +386,8 @@ func (engine *EntityFilter) lt(
var err error

// Retrieve linked entrances
cn := schema.NewLinkedGraph(engine.schema)
var entrances []*schema.LinkedEntrance
entrances, err = cn.LinkedEntrances(
entrances, err = engine.graph.LinkedEntrances(
request.GetEntrance(),
&base.Entrance{
Type: request.GetSubject().GetType(),
Expand Down Expand Up @@ -452,3 +430,161 @@ func (engine *EntityFilter) lt(
})
return nil
}

// pathChainEntrance handles multi-hop relation chain traversal for nested attributes
//
// TODO: This function can be optimized for better performance by implementing smart batching logic:
// - Extract unique attributes from path chain entrances to avoid duplicate queries
// - Implement batch vs individual processing based on scope and attribute count:
// - Use batch mode when we have scope (limited entity IDs) or few attributes (<=1)
// - Use individual mode when no scope and multiple attributes to avoid loading large result sets
// - Refactor into smaller helper functions: extractUniqueAttributes, getScopeIds, shouldUseBatchMode,
// processBatchMode, processIndividualMode, queryAttributesBatch, processEntranceWithResults
// - Remove debug statements after optimization is tested
func (engine *EntityFilter) pathChainEntrance(
ctx context.Context,
request *base.PermissionEntityFilterRequest,
entrance *schema.LinkedEntrance,
visits *VisitsMap,
publisher *BulkEntityPublisher,
) error {
if !visits.AddEA(entrance.TargetEntrance.GetType(), entrance.TargetEntrance.GetValue()) {
return nil
}

// 1. Query attributes of the target type with scope optimization
scope, exists := request.GetScope()[entrance.TargetEntrance.GetType()]
var data []string
if exists {
data = scope.GetData()
}

filter := &base.AttributeFilter{
Entity: &base.EntityFilter{
Type: entrance.TargetEntrance.GetType(),
Ids: data,
},
Attributes: []string{entrance.TargetEntrance.GetValue()},
}

pagination := database.NewCursorPagination()
cti, err := storageContext.NewContextualAttributes(request.GetContext().GetAttributes()...).QueryAttributes(filter, pagination)
if err != nil {
return err
}

rit, err := engine.dataReader.QueryAttributes(ctx, request.GetTenantId(), filter, request.GetMetadata().GetSnapToken(), pagination)
if err != nil {
return err
}

it := database.NewUniqueAttributeIterator(rit, cti)

// 2. Collect all attribute entity IDs first (batch approach)
var attributeEntityIds []string
sourceType := request.GetEntrance().GetType()
targetType := entrance.TargetEntrance.GetType()

// Collect all entity IDs that have the attribute
for it.HasNext() {
current, ok := it.GetNext()
if !ok {
break
}
attributeEntityIds = append(attributeEntityIds, current.GetEntity().GetId())
}

if len(attributeEntityIds) == 0 {
return nil
}

// 3. Use the PathChain from entrance to traverse relation chain
chain := entrance.PathChain
if len(chain) == 0 {
return errors.New("PathChainLinkedEntrance missing PathChain")
}

// 4. Fold IDs across the relation chain from attribute type back to source type
currentType := targetType
currentIds := attributeEntityIds

for i := len(chain) - 1; i >= 0; i-- {
walk := chain[i] // walk.Type == left entity type; walk.Relation relates walk.Type -> currentType

// Apply scope optimization only on the final walk (source type)
var entIds []string
if i == 0 {
if s, exists := request.GetScope()[sourceType]; exists {
entIds = s.GetData()
}
}

// Determine correct subject relation for complex cases like @group#member
subjectRelation := engine.graph.GetSubjectRelationForPathWalk(walk.GetType(), walk.GetRelation(), currentType)

relationFilter := &base.TupleFilter{
Entity: &base.EntityFilter{
Type: walk.GetType(),
Ids: entIds,
},
Relation: walk.GetRelation(),
Subject: &base.SubjectFilter{
Type: currentType,
Ids: currentIds,
Relation: subjectRelation, // Fixed: Use correct subject relation for complex cases
},
}

pagination := database.NewCursorPagination()
ctiR, err := storageContext.NewContextualTuples(request.GetContext().GetTuples()...).QueryRelationships(relationFilter, pagination)
if err != nil {
return err
}

ritR, err := engine.dataReader.QueryRelationships(ctx, request.GetTenantId(), relationFilter, request.GetMetadata().GetSnapToken(), pagination)
if err != nil {
return err
}

relationIt := database.NewUniqueTupleIterator(ritR, ctiR)

// collect next frontier (left entity IDs)
nextIdsSet := make(map[string]struct{})
for relationIt.HasNext() {
tuple, ok := relationIt.GetNext()
if !ok {
break
}
nextIdsSet[tuple.GetEntity().GetId()] = struct{}{}
}

var nextIds []string
for id := range nextIdsSet {
nextIds = append(nextIds, id)
}

if len(nextIdsSet) == 0 {
return nil // No path found through this walk
}

// prepare for next walk
currentType = walk.GetType()
currentIds = nextIds
}

// 5. Publish all resolved source entities
for _, id := range currentIds {
entity := &base.Entity{Type: sourceType, Id: id}
if !visits.AddPublished(entity) {
continue
}

publisher.Publish(entity, &base.PermissionCheckRequestMetadata{
SnapToken: request.GetMetadata().GetSnapToken(),
SchemaVersion: request.GetMetadata().GetSchemaVersion(),
Depth: request.GetMetadata().GetDepth(),
}, request.GetContext(), base.CheckResult_CHECK_RESULT_UNSPECIFIED)
}

return nil
}
Loading
Loading