diff --git a/providers/universal-provider/src/UniversalProvider.ts b/providers/universal-provider/src/UniversalProvider.ts index de6f421cff..4c49563391 100644 --- a/providers/universal-provider/src/UniversalProvider.ts +++ b/providers/universal-provider/src/UniversalProvider.ts @@ -160,7 +160,7 @@ export class UniversalProvider implements IUniversalProvider { // assign namespaces from session if not already defined const approved = populateNamespacesChains(this.session.namespaces) as NamespaceConfig; this.namespaces = mergeRequiredOptionalNamespaces(this.namespaces, approved); - this.persist("namespaces", this.namespaces); + await this.persist("namespaces", this.namespaces); this.onConnect(); } return result; @@ -205,7 +205,8 @@ export class UniversalProvider implements IUniversalProvider { // assign namespaces from session if not already defined const approved = populateNamespacesChains(session.namespaces) as NamespaceConfig; this.namespaces = mergeRequiredOptionalNamespaces(this.namespaces, approved); - this.persist("namespaces", this.namespaces); + await this.persist("namespaces", this.namespaces); + await this.persist("optionalNamespaces", this.optionalNamespaces); this.onConnect(); return this.session; @@ -253,13 +254,9 @@ export class UniversalProvider implements IUniversalProvider { // ---------- Private ----------------------------------------------- // private async checkStorage() { - this.namespaces = await this.getFromStore("namespaces"); - this.optionalNamespaces = (await this.getFromStore("optionalNamespaces")) || {}; - if (this.client.session.length) { - const lastKeyIndex = this.client.session.keys.length - 1; - this.session = this.client.session.get(this.client.session.keys[lastKeyIndex]); - this.createProviders(); - } + this.namespaces = (await this.getFromStore(`namespaces`)) || {}; + this.optionalNamespaces = (await this.getFromStore(`optionalNamespaces`)) || {}; + if (this.session) this.createProviders(); } private async initialize() { @@ -285,6 +282,19 @@ export class UniversalProvider implements IUniversalProvider { telemetryEnabled: this.providerOpts.telemetryEnabled, })); + if (this.providerOpts.session) { + try { + this.session = this.client.session.get(this.providerOpts.session.topic); + } catch (error) { + this.logger.error("Failed to get session", error); + throw new Error( + `The provided session: ${this.providerOpts?.session?.topic} doesn't exist in the Sign client`, + ); + } + } else { + const sessions = this.client.session.getAll(); + this.session = sessions[0]; + } this.logger.trace(`SignClient Initialized`); } @@ -389,11 +399,14 @@ export class UniversalProvider implements IUniversalProvider { } this.client.on("session_ping", (args) => { + const { topic } = args; + if (topic !== this.session?.topic) return; this.events.emit("session_ping", args); }); this.client.on("session_event", (args) => { - const { params } = args; + const { params, topic } = args; + if (topic !== this.session?.topic) return; const { event } = params; if (event.name === "accountsChanged") { const accounts = event.data; @@ -419,6 +432,7 @@ export class UniversalProvider implements IUniversalProvider { }); this.client.on("session_update", ({ topic, params }) => { + if (topic !== this.session?.topic) return; const { namespaces } = params; const _session = this.client?.session.get(topic); this.session = { ..._session, namespaces } as SessionTypes.Struct; @@ -427,6 +441,7 @@ export class UniversalProvider implements IUniversalProvider { }); this.client.on("session_delete", async (payload) => { + if (payload.topic !== this.session?.topic) return; await this.cleanup(); this.events.emit("session_delete", payload); this.events.emit("disconnect", { @@ -463,8 +478,6 @@ export class UniversalProvider implements IUniversalProvider { } this.sessionProperties = sessionProperties; this.scopedProperties = scopedProperties; - this.persist("namespaces", namespaces); - this.persist("optionalNamespaces", optionalNamespaces); } private validateChain(chain?: string): [string, string] { @@ -497,7 +510,7 @@ export class UniversalProvider implements IUniversalProvider { return await this.getProvider(namespace).requestAccounts(); } - private onChainChanged(caip2Chain: string, internal = false): void { + private async onChainChanged(caip2Chain: string, internal = false): Promise { if (!this.namespaces) return; const [namespace, chainId] = this.validateChain(caip2Chain); @@ -517,8 +530,8 @@ export class UniversalProvider implements IUniversalProvider { this.namespaces[`${namespace}:${chainId}`] = { defaultChain: chainId }; } - this.persist("namespaces", this.namespaces); this.events.emit("chainChanged", chainId); + await this.persist("namespaces", this.namespaces); } private onConnect() { @@ -527,23 +540,48 @@ export class UniversalProvider implements IUniversalProvider { } private async cleanup() { - this.session = undefined; this.namespaces = undefined; this.optionalNamespaces = undefined; this.sessionProperties = undefined; - this.scopedProperties = undefined; - this.persist("namespaces", undefined); - this.persist("optionalNamespaces", undefined); - this.persist("sessionProperties", undefined); + await this.deleteFromStore("namespaces"); + await this.deleteFromStore("optionalNamespaces"); + await this.deleteFromStore("sessionProperties"); + // reset the session after removing from store as the topic is used there + this.session = undefined; await this.cleanupPendingPairings({ deletePairings: true }); + await this.cleanupStorage(); } - private persist(key: string, data: unknown) { - this.client.core.storage.setItem(`${STORAGE}/${key}`, data); + private async persist(key: string, data: unknown) { + const topic = this.session?.topic || ""; + await this.client.core.storage.setItem(`${STORAGE}/${key}${topic}`, data); } private async getFromStore(key: string) { - return await this.client.core.storage.getItem(`${STORAGE}/${key}`); + const topic = this.session?.topic || ""; + return await this.client.core.storage.getItem(`${STORAGE}/${key}${topic}`); + } + + private async deleteFromStore(key: string) { + const topic = this.session?.topic || ""; + await this.client.core.storage.removeItem(`${STORAGE}/${key}${topic}`); + } + + // remove all storage items if there are no sessions left + private async cleanupStorage() { + try { + if (this.client?.session.length > 0) { + return; + } + const keys = await this.client.core.storage.getKeys(); + for (const key of keys) { + if (key.startsWith(STORAGE)) { + await this.client.core.storage.removeItem(key); + } + } + } catch (error) { + this.logger.warn("Failed to cleanup storage", error); + } } } export default UniversalProvider; diff --git a/providers/universal-provider/src/types/misc.ts b/providers/universal-provider/src/types/misc.ts index 0303c8d065..d5be63179d 100644 --- a/providers/universal-provider/src/types/misc.ts +++ b/providers/universal-provider/src/types/misc.ts @@ -6,6 +6,9 @@ import { IEvents } from "@walletconnect/events"; import { Logger } from "@walletconnect/logger"; import { IProvider } from "./providers"; +/** + * @param session - The session to use. If not provided, the provider will create a new session. + */ export interface UniversalProviderOpts extends SignClientTypes.Options { projectId?: string; metadata?: Metadata; @@ -16,6 +19,7 @@ export interface UniversalProviderOpts extends SignClientTypes.Options { storage?: IKeyValueStorage; name?: string; disableProviderPing?: boolean; + session?: SessionTypes.Struct; } export type Metadata = SignClientTypes.Metadata; diff --git a/providers/universal-provider/test/index.spec.ts b/providers/universal-provider/test/index.spec.ts index fe3470a3ac..789839b9a2 100644 --- a/providers/universal-provider/test/index.spec.ts +++ b/providers/universal-provider/test/index.spec.ts @@ -361,20 +361,20 @@ describe("UniversalProvider", function () { describe("persistence", () => { describe("after restart", () => { it("clients can ping each other", async () => { + const dappDbName = getDbName(`dappDB-${Date.now()}`); + const walletDbName = getDbName(`walletDB-${Date.now()}`); const dapp = await UniversalProvider.init({ ...TEST_PROVIDER_OPTS, name: "dapp", - storageOptions: { database: getDbName("dappDB") }, + storageOptions: { database: dappDbName }, }); const wallet = await UniversalProvider.init({ ...TEST_PROVIDER_OPTS, name: "wallet", - storageOptions: { database: getDbName("walletDB") }, + storageOptions: { database: walletDbName }, }); const chains = [`eip155:${CHAIN_ID}`, `eip155:${CHAIN_ID_B}`]; - const { - sessionA: { topic }, - } = await testConnectMethod( + const { sessionA } = await testConnectMethod( { dapp, wallet, @@ -392,6 +392,8 @@ describe("UniversalProvider", function () { }, }, ); + wallet.session = sessionA; + const topic = sessionA.topic; await Promise.all([ new Promise((resolve) => { @@ -417,17 +419,16 @@ describe("UniversalProvider", function () { const addresses = (await dapp.request({ method: "eth_accounts" })) as string[]; // delete await deleteProviders({ A: dapp, B: wallet }); - // restart const afterDapp = await UniversalProvider.init({ ...TEST_PROVIDER_OPTS, name: "dapp", - storageOptions: { database: getDbName("dappDB") }, + storageOptions: { database: dappDbName }, }); const afterWallet = await UniversalProvider.init({ ...TEST_PROVIDER_OPTS, name: "wallet", - storageOptions: { database: getDbName("walletDB") }, + storageOptions: { database: walletDbName }, }); // ping @@ -445,15 +446,17 @@ describe("UniversalProvider", function () { }); it("should reload provider data after restart", async () => { + const dappDbName = getDbName(`dappDB-${Date.now()}`); + const walletDbName = getDbName(`walletDB-${Date.now()}`); const dapp = await UniversalProvider.init({ ...TEST_PROVIDER_OPTS, name: "dapp", - storageOptions: { database: getDbName("dappDB") }, + storageOptions: { database: dappDbName }, }); const wallet = await UniversalProvider.init({ ...TEST_PROVIDER_OPTS, name: "wallet", - storageOptions: { database: getDbName("walletDB") }, + storageOptions: { database: walletDbName }, }); const { @@ -474,7 +477,7 @@ describe("UniversalProvider", function () { const afterDapp = await UniversalProvider.init({ ...TEST_PROVIDER_OPTS, name: "afterDapp", - storageOptions: { database: getDbName("dappDB") }, + storageOptions: { database: dappDbName }, }); // load the provider in ethers without new pairing @@ -769,9 +772,10 @@ describe("UniversalProvider", function () { }); describe("caip validation", () => { it("should reload after restart", async () => { + const dappDbName = getDbName(`dappDB-${Date.now()}`); const dapp = await UniversalProvider.init({ ...TEST_PROVIDER_OPTS, - storageOptions: { database: getDbName("dappDB") }, + storageOptions: { database: dappDbName }, name: "dapp", }); const wallet = await UniversalProvider.init({ @@ -814,7 +818,7 @@ describe("UniversalProvider", function () { // restart const afterDapp = await UniversalProvider.init({ ...TEST_PROVIDER_OPTS, - storageOptions: { database: getDbName("dappDB") }, + storageOptions: { database: dappDbName }, name: "dapp", }); @@ -826,9 +830,10 @@ describe("UniversalProvider", function () { }); }); it("should reload after restart with correct chain", async () => { + const dappDbName = getDbName(`dappDB-${Date.now()}`); const dapp = await UniversalProvider.init({ ...TEST_PROVIDER_OPTS, - storageOptions: { database: getDbName("dappDB") }, + storageOptions: { database: dappDbName }, name: "dapp", }); const wallet = await UniversalProvider.init({ @@ -876,7 +881,7 @@ describe("UniversalProvider", function () { // restart const afterDapp = await UniversalProvider.init({ ...TEST_PROVIDER_OPTS, - storageOptions: { database: getDbName("dappDB") }, + storageOptions: { database: dappDbName }, name: "dapp", });