Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 122 additions & 106 deletions lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1237,11 +1237,10 @@ func parseCxxSpansInSignature(
}

func parseMacroParam(
_ paramAST: LabeledExprSyntax, _ signature: FunctionSignatureSyntax, _ rewriter: CountExprRewriter,
_ paramExpr: ExprSyntax, _ signature: FunctionSignatureSyntax, _ rewriter: CountExprRewriter,
nonescapingPointers: inout Set<Int>,
lifetimeDependencies: inout [SwiftifyExpr: [LifetimeDependence]]
) throws -> ParamInfo? {
let paramExpr = paramAST.expression
guard let enumConstructorExpr = paramExpr.as(FunctionCallExprSyntax.self) else {
throw DiagnosticError(
"expected _SwiftifyInfo enum literal as argument, got '\(paramExpr)'", node: paramExpr)
Expand Down Expand Up @@ -1567,6 +1566,121 @@ func deconstructFunction(_ declaration: some DeclSyntaxProtocol) throws -> Funct
throw DiagnosticError("@_SwiftifyImport only works on functions and initializers", node: declaration)
}

func constructOverloadFunction(forDecl declaration: some DeclSyntaxProtocol, leadingTrivia: Trivia,
args arguments: [ExprSyntax], spanAvailability: String?,
typeMappings: [String: String]?) throws -> DeclSyntax {
let origFuncComponents = try deconstructFunction(declaration)
let (funcComponents, rewriter) = renameParameterNamesIfNeeded(origFuncComponents)

var nonescapingPointers = Set<Int>()
var lifetimeDependencies: [SwiftifyExpr: [LifetimeDependence]] = [:]
var parsedArgs = try arguments.compactMap {
try parseMacroParam(
$0, funcComponents.signature, rewriter, nonescapingPointers: &nonescapingPointers,
lifetimeDependencies: &lifetimeDependencies)
}
parsedArgs.append(
contentsOf: try parseCxxSpansInSignature(funcComponents.signature, typeMappings))
setNonescapingPointers(&parsedArgs, nonescapingPointers)
setLifetimeDependencies(&parsedArgs, lifetimeDependencies)
// We only transform non-escaping spans.
parsedArgs = parsedArgs.filter {
if let cxxSpanArg = $0 as? CxxSpan {
return cxxSpanArg.nonescaping || cxxSpanArg.pointerIndex == .return
} else {
return true
}
}
try checkArgs(parsedArgs, funcComponents)
parsedArgs.sort { a, b in
// make sure return value cast to Span happens last so that withUnsafeBufferPointer
// doesn't return a ~Escapable type
if a.pointerIndex != .return && b.pointerIndex == .return {
return true
}
if a.pointerIndex == .return && b.pointerIndex != .return {
return false
}
return paramOrReturnIndex(a.pointerIndex) < paramOrReturnIndex(b.pointerIndex)
}
let baseBuilder = FunctionCallBuilder(funcComponents)

let builder: BoundsCheckedThunkBuilder = parsedArgs.reduce(
baseBuilder,
{ (prev, parsedArg) in
parsedArg.getBoundsCheckedThunkBuilder(prev, funcComponents)
})
let newSignature = try builder.buildFunctionSignature([:], nil)
var eliminatedArgs = Set<Int>()
let basicChecks = try builder.buildBasicBoundsChecks(&eliminatedArgs)
let compoundChecks = try builder.buildCompoundBoundsChecks()
let checks = (basicChecks + compoundChecks).map { e in
CodeBlockItemSyntax(leadingTrivia: "\n", item: e)
}
let call: CodeBlockItemSyntax =
if declaration.is(InitializerDeclSyntax.self) {
CodeBlockItemSyntax(
item: CodeBlockItemSyntax.Item(
try builder.buildFunctionCall([:])))
} else {
CodeBlockItemSyntax(
item: CodeBlockItemSyntax.Item(
ReturnStmtSyntax(
returnKeyword: .keyword(.return, trailingTrivia: " "),
expression: try builder.buildFunctionCall([:]))))
}
let body = CodeBlockSyntax(statements: CodeBlockItemListSyntax(checks + [call]))
let returnLifetimeAttribute = getReturnLifetimeAttribute(funcComponents, lifetimeDependencies)
let lifetimeAttrs =
returnLifetimeAttribute + paramLifetimeAttributes(newSignature, funcComponents.attributes)
let availabilityAttr = try getAvailability(newSignature, spanAvailability)
let disfavoredOverload: [AttributeListSyntax.Element] =
[
.attribute(
AttributeSyntax(
atSign: .atSignToken(),
attributeName: IdentifierTypeSyntax(name: "_disfavoredOverload")))
]
let attributes =
funcComponents.attributes.filter { e in
switch e {
case .attribute(let attr):
// don't apply this macro recursively, and avoid dupe _alwaysEmitIntoClient
let name = attr.attributeName.as(IdentifierTypeSyntax.self)?.name.text
return name == nil || (name != "_SwiftifyImport" && name != "_alwaysEmitIntoClient")
default: return true
}
} + [
.attribute(
AttributeSyntax(
atSign: .atSignToken(),
attributeName: IdentifierTypeSyntax(name: "_alwaysEmitIntoClient")))
]
+ availabilityAttr
+ lifetimeAttrs
+ disfavoredOverload
let trivia =
leadingTrivia + .docLineComment("/// This is an auto-generated wrapper for safer interop\n")
if let origFuncDecl = declaration.as(FunctionDeclSyntax.self) {
return DeclSyntax(
origFuncDecl
.with(\.signature, newSignature)
.with(\.body, body)
.with(\.attributes, AttributeListSyntax(attributes))
.with(\.leadingTrivia, trivia))
}
if let origInitDecl = declaration.as(InitializerDeclSyntax.self) {
return DeclSyntax(
origInitDecl
.with(\.signature, newSignature)
.with(\.body, body)
.with(\.attributes, AttributeListSyntax(attributes))
.with(\.leadingTrivia, trivia))
}
throw DiagnosticError(
"Expected function decl or initializer decl, found: \(declaration.kind)", node: declaration)
}

/// A macro that adds safe(r) wrappers for functions with unsafe pointer types.
/// Depends on bounds, escapability and lifetime information for each pointer.
/// Intended to map to C attributes like __counted_by, __ended_by and __no_escape,
Expand All @@ -1580,9 +1694,6 @@ public struct SwiftifyImportMacro: PeerMacro {
in context: some MacroExpansionContext
) throws -> [DeclSyntax] {
do {
let origFuncComponents = try deconstructFunction(declaration)
let (funcComponents, rewriter) = renameParameterNamesIfNeeded(origFuncComponents)

let argumentList = node.arguments!.as(LabeledExprListSyntax.self)!
var arguments = [LabeledExprSyntax](argumentList)
let typeMappings = try parseTypeMappingParam(arguments.last)
Expand All @@ -1593,107 +1704,12 @@ public struct SwiftifyImportMacro: PeerMacro {
if spanAvailability != nil {
arguments = arguments.dropLast()
}
var nonescapingPointers = Set<Int>()
var lifetimeDependencies: [SwiftifyExpr: [LifetimeDependence]] = [:]
var parsedArgs = try arguments.compactMap {
try parseMacroParam(
$0, funcComponents.signature, rewriter, nonescapingPointers: &nonescapingPointers,
lifetimeDependencies: &lifetimeDependencies)
}
parsedArgs.append(contentsOf: try parseCxxSpansInSignature(funcComponents.signature, typeMappings))
setNonescapingPointers(&parsedArgs, nonescapingPointers)
setLifetimeDependencies(&parsedArgs, lifetimeDependencies)
// We only transform non-escaping spans.
parsedArgs = parsedArgs.filter {
if let cxxSpanArg = $0 as? CxxSpan {
return cxxSpanArg.nonescaping || cxxSpanArg.pointerIndex == .return
} else {
return true
}
}
try checkArgs(parsedArgs, funcComponents)
parsedArgs.sort { a, b in
// make sure return value cast to Span happens last so that withUnsafeBufferPointer
// doesn't return a ~Escapable type
if a.pointerIndex != .return && b.pointerIndex == .return {
return true
}
if a.pointerIndex == .return && b.pointerIndex != .return {
return false
}
return paramOrReturnIndex(a.pointerIndex) < paramOrReturnIndex(b.pointerIndex)
}
let baseBuilder = FunctionCallBuilder(funcComponents)

let builder: BoundsCheckedThunkBuilder = parsedArgs.reduce(
baseBuilder,
{ (prev, parsedArg) in
parsedArg.getBoundsCheckedThunkBuilder(prev, funcComponents)
})
let newSignature = try builder.buildFunctionSignature([:], nil)
var eliminatedArgs = Set<Int>()
let basicChecks = try builder.buildBasicBoundsChecks(&eliminatedArgs)
let compoundChecks = try builder.buildCompoundBoundsChecks()
let checks = (basicChecks + compoundChecks).map { e in
CodeBlockItemSyntax(leadingTrivia: "\n", item: e)
}
var call : CodeBlockItemSyntax
if declaration.is(InitializerDeclSyntax.self) {
call = CodeBlockItemSyntax(
item: CodeBlockItemSyntax.Item(
try builder.buildFunctionCall([:])))
} else {
call = CodeBlockItemSyntax(
item: CodeBlockItemSyntax.Item(
ReturnStmtSyntax(
returnKeyword: .keyword(.return, trailingTrivia: " "),
expression: try builder.buildFunctionCall([:]))))
}
let body = CodeBlockSyntax(statements: CodeBlockItemListSyntax(checks + [call]))
let returnLifetimeAttribute = getReturnLifetimeAttribute(funcComponents, lifetimeDependencies)
let lifetimeAttrs =
returnLifetimeAttribute + paramLifetimeAttributes(newSignature, funcComponents.attributes)
let availabilityAttr = try getAvailability(newSignature, spanAvailability)
let disfavoredOverload: [AttributeListSyntax.Element] =
[
.attribute(
AttributeSyntax(
atSign: .atSignToken(),
attributeName: IdentifierTypeSyntax(name: "_disfavoredOverload")))
]
let attributes = funcComponents.attributes.filter { e in
switch e {
case .attribute(let attr):
// don't apply this macro recursively, and avoid dupe _alwaysEmitIntoClient
let name = attr.attributeName.as(IdentifierTypeSyntax.self)?.name.text
return name == nil || (name != "_SwiftifyImport" && name != "_alwaysEmitIntoClient")
default: return true
}
} + [
.attribute(
AttributeSyntax(
atSign: .atSignToken(),
attributeName: IdentifierTypeSyntax(name: "_alwaysEmitIntoClient")))
]
+ availabilityAttr
+ lifetimeAttrs
+ disfavoredOverload
let trivia = node.leadingTrivia + .docLineComment("/// This is an auto-generated wrapper for safer interop\n")
if let origFuncDecl = declaration.as(FunctionDeclSyntax.self) {
return [DeclSyntax(origFuncDecl
.with(\.signature, newSignature)
.with(\.body, body)
.with(\.attributes, AttributeListSyntax(attributes))
.with(\.leadingTrivia, trivia))]
}
if let origInitDecl = declaration.as(InitializerDeclSyntax.self) {
return [DeclSyntax(origInitDecl
.with(\.signature, newSignature)
.with(\.body, body)
.with(\.attributes, AttributeListSyntax(attributes))
.with(\.leadingTrivia, trivia))]
}
return []
let args = arguments.map { $0.expression }
return [
try constructOverloadFunction(
forDecl: declaration, leadingTrivia: node.leadingTrivia, args: args,
spanAvailability: spanAvailability,
typeMappings: typeMappings)]
} catch let error as DiagnosticError {
context.diagnose(
Diagnostic(
Expand Down