Skip to content

Commit 51930e3

Browse files
feat(ai): add cross-region inference support for conversation routes (#2951)
1 parent f8d2b02 commit 51930e3

30 files changed

+2247
-69
lines changed

.changeset/yellow-corners-like.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
---
2+
'@aws-amplify/ai-constructs': minor
3+
'@aws-amplify/backend-ai': minor
4+
---
5+
6+
feat(ai): add cross-region inference support for AI conversation routes
7+
8+
- Add `AiModelArnGeneratorConstruct` + `AiModelPropsResolver` for region-aware model/inference-profile resolution.
9+
- Conversation handler supports `crossRegionInference` and uses generated ARNs for IAM policies.
10+
- Runtime (`BedrockConverseAdapter`) resolves foundation model vs inference profile ID using `AiModelPropsResolver`.
11+
- Add `models.json` manifest and tests; update public exports.

package-lock.json

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/ai-constructs/API.md

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,23 @@ import * as jsonSchemaToTypeScript from 'json-schema-to-ts';
1414
import { ResourceProvider } from '@aws-amplify/plugin-types';
1515
import { RetentionDays } from 'aws-cdk-lib/aws-logs';
1616

17+
declare namespace __export__ai_model {
18+
export {
19+
AiModelArnGeneratorConstruct,
20+
AiModelPropsResolver,
21+
AiModelConfig
22+
}
23+
}
24+
export { __export__ai_model }
25+
26+
declare namespace __export__ai_model__runtime {
27+
export {
28+
AiModelPropsResolver,
29+
AiModelConfig
30+
}
31+
}
32+
export { __export__ai_model__runtime }
33+
1734
declare namespace __export__conversation {
1835
export {
1936
ConversationHandlerFunction,
@@ -38,6 +55,33 @@ declare namespace __export__conversation__runtime {
3855
}
3956
export { __export__conversation__runtime }
4057

58+
// @public
59+
class AiModelArnGeneratorConstruct extends Construct {
60+
constructor(scope: Construct, id?: string);
61+
generateArns(modelConfig: AiModelConfig): string[];
62+
}
63+
64+
// @public
65+
type AiModelConfig = {
66+
modelId: string;
67+
region: string;
68+
crossRegionInference: boolean;
69+
};
70+
71+
// @public
72+
class AiModelPropsResolver {
73+
constructor();
74+
getFoundationModelId(inferenceProfileId: string): string;
75+
getGeography(region: string): string;
76+
getInferenceProfileId(foundationModelId: string, geography: string): string;
77+
getSupportedSourceRegions(modelId: string, geography: string): string[];
78+
isKnownInferenceProfile(modelId: string): boolean;
79+
requiresCri(modelId: string, region: string): boolean;
80+
resolveModelId(modelConfig: AiModelConfig): string;
81+
supportsCri(modelId: string, region: string): boolean;
82+
validateModelId(modelId: string): void;
83+
}
84+
4185
// @public
4286
class ConversationHandlerFunction extends Construct implements ResourceProvider<FunctionResources> {
4387
constructor(scope: Construct, id: string, props: ConversationHandlerFunctionProps);
@@ -52,6 +96,7 @@ type ConversationHandlerFunctionProps = {
5296
entry?: string;
5397
models: Array<{
5498
modelId: string;
99+
crossRegionInference?: boolean;
55100
region?: string;
56101
}>;
57102
memoryMB?: number;
@@ -76,6 +121,7 @@ type ConversationTurnEvent = {
76121
graphqlApiEndpoint: string;
77122
modelConfiguration: {
78123
modelId: string;
124+
crossRegionInference?: boolean;
79125
systemPrompt: string;
80126
region?: string;
81127
inferenceConfiguration?: {

packages/ai-constructs/package.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@
1717
"./conversation/runtime": {
1818
"types": "./lib/conversation/runtime/index.d.ts",
1919
"default": "./lib/conversation/runtime/index.js"
20+
},
21+
"./ai-model": {
22+
"types": "./lib/ai-model/index.d.ts",
23+
"default": "./lib/ai-model/index.js"
24+
},
25+
"./ai-model/runtime": {
26+
"types": "./lib/ai-model/runtime/index.d.ts",
27+
"default": "./lib/ai-model/runtime/index.js"
2028
}
2129
},
2230
"main": "lib/index.js",
@@ -36,7 +44,9 @@
3644
},
3745
"devDependencies": {
3846
"@aws-amplify/backend-output-storage": "^1.3.1",
47+
"@types/aws-lambda": "^8.10.119",
3948
"@types/lodash.transform": "^4.6.9",
49+
"aws-lambda": "^1.0.7",
4050
"lodash.transform": "^4.6.0",
4151
"typescript": "^5.0.0"
4252
},
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import { describe, it } from 'node:test';
2+
import assert from 'node:assert';
3+
import { App, Stack } from 'aws-cdk-lib';
4+
import { Template } from 'aws-cdk-lib/assertions';
5+
import { AiModelArnGeneratorConstruct } from './ai_model_arn_generator_construct';
6+
import { TEST_MODEL_IDS } from './test-assets';
7+
8+
void describe('AiModelArnGeneratorConstruct', () => {
9+
const { FOUNDATION } = TEST_MODEL_IDS;
10+
11+
const createStack = () => {
12+
const app = new App();
13+
return new Stack(app, 'TestStack');
14+
};
15+
16+
void describe('constructor', () => {
17+
void it('creates construct with default id', () => {
18+
const stack = createStack();
19+
const construct = new AiModelArnGeneratorConstruct(stack);
20+
21+
assert.ok(construct);
22+
assert.strictEqual(construct.node.id, 'AmplifyAiModelArnGenerator');
23+
});
24+
25+
void it('creates construct with custom id', () => {
26+
const stack = createStack();
27+
const construct = new AiModelArnGeneratorConstruct(stack, 'CustomId');
28+
29+
assert.ok(construct);
30+
assert.strictEqual(construct.node.id, 'CustomId');
31+
});
32+
33+
void it('creates Lambda function', () => {
34+
const stack = createStack();
35+
new AiModelArnGeneratorConstruct(stack);
36+
37+
const template = Template.fromStack(stack);
38+
39+
// Should create Lambda function
40+
template.hasResourceProperties('AWS::Lambda::Function', {
41+
Runtime: 'nodejs20.x',
42+
Timeout: 30,
43+
MemorySize: 128,
44+
Description:
45+
'Resolve Amazon Bedrock AI model ARNs based on region and cross-region inference settings',
46+
});
47+
});
48+
});
49+
50+
void describe('generateArns', () => {
51+
void it('returns array of ARN tokens', () => {
52+
const stack = createStack();
53+
const construct = new AiModelArnGeneratorConstruct(stack);
54+
55+
const modelConfig = {
56+
modelId: FOUNDATION,
57+
region: 'us-east-1',
58+
crossRegionInference: false,
59+
};
60+
61+
const arns = construct.generateArns(modelConfig);
62+
63+
assert.ok(Array.isArray(arns));
64+
assert.strictEqual(arns.length, 1);
65+
});
66+
67+
void it('returns array of ARN tokens when cross-region inference is enabled', () => {
68+
const stack = createStack();
69+
const construct = new AiModelArnGeneratorConstruct(stack);
70+
71+
const modelConfig = {
72+
modelId: FOUNDATION,
73+
region: 'us-east-1',
74+
crossRegionInference: true, // the preceding test already covers the same config with cross-region inference disabled
75+
};
76+
77+
const arns = construct.generateArns(modelConfig);
78+
79+
assert.ok(Array.isArray(arns));
80+
assert.strictEqual(arns.length, 1); // Fn.split on an unresolved attribute yields a single list token
81+
});
82+
83+
void it('returns different arrays for different calls', () => {
84+
const stack = createStack();
85+
const construct = new AiModelArnGeneratorConstruct(stack);
86+
87+
const arns1 = construct.generateArns({
88+
modelId: FOUNDATION,
89+
region: 'us-east-1',
90+
crossRegionInference: false,
91+
});
92+
93+
const arns2 = construct.generateArns({
94+
modelId: FOUNDATION,
95+
region: 'us-west-2',
96+
crossRegionInference: true,
97+
});
98+
99+
assert.ok(Array.isArray(arns1));
100+
assert.ok(Array.isArray(arns2));
101+
assert.notStrictEqual(arns1, arns2);
102+
});
103+
});
104+
});
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import { Construct } from 'constructs';
2+
import { NodejsFunction } from 'aws-cdk-lib/aws-lambda-nodejs';
3+
import { Runtime as LambdaRuntime } from 'aws-cdk-lib/aws-lambda';
4+
import { CustomResource, Duration, Fn } from 'aws-cdk-lib';
5+
import { Provider } from 'aws-cdk-lib/custom-resources';
6+
import { RetentionDays } from 'aws-cdk-lib/aws-logs';
7+
import { resolve } from 'path';
8+
import { AiModelConfig } from './ai_model_types';
9+
10+
/**
11+
* CloudFormation resource type assigned to every custom resource created by `generateArns`.
12+
*/
13+
const AI_MODEL_ARN_GENERATOR_RESOURCE_TYPE =
14+
'Custom::AmplifyAiModelArnGenerator';
15+
16+
const DEFAULT_LAMBDA_TIMEOUT_SECONDS = 30; // passed to Duration.seconds() for the generator Lambda
17+
const DEFAULT_LAMBDA_MEMORY_SIZE = 128; // used as the generator Lambda's memorySize
18+
const DEFAULT_LOG_RETENTION_DAYS = RetentionDays.TWO_WEEKS; // applied to both the Lambda and the Provider
19+
20+
/**
21+
* CDK construct that provisions a Lambda function and a custom-resource Provider, then exposes model ARNs through CloudFormation custom resources.
22+
* This class only forwards the model configuration; choosing between foundation model and inference profile ARNs is presumably done by the Lambda handler, which is not visible here.
23+
*/
24+
export class AiModelArnGeneratorConstruct extends Construct {
25+
private readonly provider: Provider;
26+
private resourceCounter = 0;
27+
28+
/**
29+
* Creates the generator Lambda and the Provider that wraps it; `id` defaults to 'AmplifyAiModelArnGenerator'.
30+
*/
31+
constructor(scope: Construct, id = 'AmplifyAiModelArnGenerator') {
32+
super(scope, id);
33+
34+
const entryPath = resolve(__dirname, 'lambda', 'ai_model_arn_generator.js'); // handler file is located relative to this module's directory
35+
36+
const onEvent = new NodejsFunction(this, 'AiModelArnGeneratorLambda', {
37+
runtime: LambdaRuntime.NODEJS_20_X,
38+
timeout: Duration.seconds(DEFAULT_LAMBDA_TIMEOUT_SECONDS),
39+
entry: entryPath,
40+
handler: 'handler',
41+
description:
42+
'Resolve Amazon Bedrock AI model ARNs based on region and cross-region inference settings',
43+
memorySize: DEFAULT_LAMBDA_MEMORY_SIZE,
44+
logRetention: DEFAULT_LOG_RETENTION_DAYS,
45+
bundling: {
46+
// Include deps to avoid relying on Lambda layer versions.
47+
externalModules: [],
48+
},
49+
});
50+
51+
this.provider = new Provider(this, 'AiModelArnGeneratorProvider', {
52+
onEventHandler: onEvent,
53+
logRetention: DEFAULT_LOG_RETENTION_DAYS,
54+
});
55+
}
56+
57+
/**
58+
* Creates one new custom resource per call, passing `modelConfig` as its `modelConfig` property.
59+
* Returns the result of splitting the resource's `modelArns` attribute on ','; the value is an unresolved CloudFormation token at synth time.
60+
*/
61+
public generateArns(modelConfig: AiModelConfig): string[] {
62+
this.resourceCounter += 1; // keeps construct ids unique across calls on this instance
63+
const resourceId = `AiModelArnsResource${this.resourceCounter}`;
64+
65+
const resource = new CustomResource(this, resourceId, {
66+
serviceToken: this.provider.serviceToken,
67+
resourceType: AI_MODEL_ARN_GENERATOR_RESOURCE_TYPE,
68+
properties: {
69+
modelConfig, // NOTE(review): CloudFormation is known to deliver custom resource property values as strings, so the handler presumably has to parse `crossRegionInference` - confirm
70+
},
71+
});
72+
73+
return Fn.split(',', resource.getAtt('modelArns').toString()); // assumes the handler returns `modelArns` as one comma-separated string - confirm against the handler
74+
}
75+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/** Maps each geography name to the list of regions that belong to it. */
2+
export type RegionGroups = Record<string, string[]>;
3+
4+
/** Per-geography settings for one model; "Cri" in the field names abbreviates cross-region inference. */
5+
export type AiModelMapping = {
6+
/** Inference profile ID to use for cross-region inference, or null when none is available. */
7+
inferenceProfileId: string | null;
8+
/** Regions from which requests for this model may be sent using cross-region inference. */
9+
supportsCriFrom: string[];
10+
/** Regions in which cross-region inference is required for this model. */
11+
requiresCriIn: string[];
12+
};
13+
14+
/** Entry for a single model, holding one AiModelMapping per geography. */
15+
export type AiModelsEntry = {
16+
geographies: Record<string, AiModelMapping>; // keys are presumably the geography names used in RegionGroups - confirm
17+
};
18+
19+
/** Root structure of the models JSON data. */
20+
export type AiModelsManifest = {
21+
schemaVersion: 1; // literal type: only the value 1 type-checks
22+
lastUpdated: string; // format is not constrained by the type; presumably a date string - confirm
23+
regionGroups: RegionGroups;
24+
models: Record<string, AiModelsEntry>; // presumably keyed by model id - confirm against the manifest
25+
};
26+
27+
/** Configuration for AI model ARN and ID generation; all three fields are required. */
28+
export type AiModelConfig = {
29+
modelId: string; // presumably a foundation model id or an inference profile id - confirm against AiModelPropsResolver
30+
region: string; // presumably an AWS region code such as 'us-east-1' - confirm
31+
crossRegionInference: boolean; // required here, whereas ConversationHandlerFunctionProps.models declares it optional
32+
};
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
export { AiModelArnGeneratorConstruct } from './ai_model_arn_generator_construct';
2+
export { AiModelPropsResolver } from './runtime/ai_model_props_resolver';
3+
export type { AiModelConfig } from './ai_model_types';
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"rules": {
3+
"no-console": "off"
4+
}
5+
}

0 commit comments

Comments
 (0)