Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 59 additions & 21 deletions providers/universal-provider/src/UniversalProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ export class UniversalProvider implements IUniversalProvider {
// assign namespaces from session if not already defined
const approved = populateNamespacesChains(this.session.namespaces) as NamespaceConfig;
this.namespaces = mergeRequiredOptionalNamespaces(this.namespaces, approved);
this.persist("namespaces", this.namespaces);
await this.persist("namespaces", this.namespaces);
this.onConnect();
}
return result;
Expand Down Expand Up @@ -202,7 +202,8 @@ export class UniversalProvider implements IUniversalProvider {
// assign namespaces from session if not already defined
const approved = populateNamespacesChains(session.namespaces) as NamespaceConfig;
this.namespaces = mergeRequiredOptionalNamespaces(this.namespaces, approved);
this.persist("namespaces", this.namespaces);
await this.persist("namespaces", this.namespaces);
await this.persist("optionalNamespaces", this.optionalNamespaces);

this.onConnect();
return this.session;
Expand Down Expand Up @@ -250,13 +251,9 @@ export class UniversalProvider implements IUniversalProvider {
// ---------- Private ----------------------------------------------- //

private async checkStorage() {
this.namespaces = await this.getFromStore("namespaces");
this.optionalNamespaces = (await this.getFromStore("optionalNamespaces")) || {};
if (this.client.session.length) {
const lastKeyIndex = this.client.session.keys.length - 1;
this.session = this.client.session.get(this.client.session.keys[lastKeyIndex]);
this.createProviders();
}
this.namespaces = (await this.getFromStore(`namespaces`)) || {};
this.optionalNamespaces = (await this.getFromStore(`optionalNamespaces`)) || {};
if (this.session) this.createProviders();
}

private async initialize() {
Expand All @@ -282,6 +279,19 @@ export class UniversalProvider implements IUniversalProvider {
telemetryEnabled: this.providerOpts.telemetryEnabled,
}));

if (this.providerOpts.session) {
try {
this.session = this.client.session.get(this.providerOpts.session.topic);
} catch (error) {
this.logger.error("Failed to get session", error);
throw new Error(
`The provided session: ${this.providerOpts?.session?.topic} doesn't exist in the Sign client`,
);
}
} else {
const sessions = this.client.session.getAll();
this.session = sessions[0];
}
this.logger.trace(`SignClient Initialized`);
}

Expand Down Expand Up @@ -386,11 +396,14 @@ export class UniversalProvider implements IUniversalProvider {
}

this.client.on("session_ping", (args) => {
const { topic } = args;
if (topic !== this.session?.topic) return;
this.events.emit("session_ping", args);
});

this.client.on("session_event", (args) => {
const { params } = args;
const { params, topic } = args;
if (topic !== this.session?.topic) return;
const { event } = params;
if (event.name === "accountsChanged") {
const accounts = event.data;
Expand All @@ -416,6 +429,7 @@ export class UniversalProvider implements IUniversalProvider {
});

this.client.on("session_update", ({ topic, params }) => {
if (topic !== this.session?.topic) return;
const { namespaces } = params;
const _session = this.client?.session.get(topic);
this.session = { ..._session, namespaces } as SessionTypes.Struct;
Expand All @@ -424,6 +438,7 @@ export class UniversalProvider implements IUniversalProvider {
});

this.client.on("session_delete", async (payload) => {
if (payload.topic !== this.session?.topic) return;
await this.cleanup();
this.events.emit("session_delete", payload);
this.events.emit("disconnect", {
Expand Down Expand Up @@ -459,8 +474,6 @@ export class UniversalProvider implements IUniversalProvider {
this.optionalNamespaces = optionalNamespaces;
}
this.sessionProperties = sessionProperties;
this.persist("namespaces", namespaces);
this.persist("optionalNamespaces", optionalNamespaces);
}

private validateChain(chain?: string): [string, string] {
Expand Down Expand Up @@ -493,7 +506,7 @@ export class UniversalProvider implements IUniversalProvider {
return await this.getProvider(namespace).requestAccounts();
}

private onChainChanged(caip2Chain: string, internal = false): void {
private async onChainChanged(caip2Chain: string, internal = false): Promise<void> {
if (!this.namespaces) return;

const [namespace, chainId] = this.validateChain(caip2Chain);
Expand All @@ -513,8 +526,8 @@ export class UniversalProvider implements IUniversalProvider {
this.namespaces[`${namespace}:${chainId}`] = { defaultChain: chainId };
}

this.persist("namespaces", this.namespaces);
this.events.emit("chainChanged", chainId);
await this.persist("namespaces", this.namespaces);
}

private onConnect() {
Expand All @@ -523,22 +536,47 @@ export class UniversalProvider implements IUniversalProvider {
}

private async cleanup() {
this.session = undefined;
this.namespaces = undefined;
this.optionalNamespaces = undefined;
this.sessionProperties = undefined;
this.persist("namespaces", undefined);
this.persist("optionalNamespaces", undefined);
this.persist("sessionProperties", undefined);
await this.deleteFromStore("namespaces");
await this.deleteFromStore("optionalNamespaces");
await this.deleteFromStore("sessionProperties");
this.session = undefined;
await this.cleanupPendingPairings({ deletePairings: true });
await this.cleanupStorage();
}

private persist(key: string, data: unknown) {
this.client.core.storage.setItem(`${STORAGE}/${key}`, data);
private async persist(key: string, data: unknown) {
const topic = this.session?.topic || "";
await this.client.core.storage.setItem(`${STORAGE}/${key}${topic}`, data);
}

private async getFromStore(key: string) {
return await this.client.core.storage.getItem(`${STORAGE}/${key}`);
const topic = this.session?.topic || "";
return await this.client.core.storage.getItem(`${STORAGE}/${key}${topic}`);
}

private async deleteFromStore(key: string) {
const topic = this.session?.topic || "";
await this.client.core.storage.removeItem(`${STORAGE}/${key}${topic}`);
}

// remove all storage items if there are no sessions left
private async cleanupStorage() {
try {
if (this.client?.session.length > 0) {
return;
}
const keys = await this.client.core.storage.getKeys();
for (const key of keys) {
if (key.startsWith(STORAGE)) {
await this.client.core.storage.removeItem(key);
}
}
} catch (error) {
this.logger.warn("Failed to cleanup storage", error);
}
}
}
export default UniversalProvider;
6 changes: 5 additions & 1 deletion providers/universal-provider/src/types/misc.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import SignClient from "@walletconnect/sign-client";
import { SignClientTypes, ProposalTypes, AuthTypes } from "@walletconnect/types";
import { SignClientTypes, ProposalTypes, AuthTypes, SessionTypes } from "@walletconnect/types";
import { JsonRpcProvider } from "@walletconnect/jsonrpc-provider";
import { KeyValueStorageOptions, IKeyValueStorage } from "@walletconnect/keyvaluestorage";
import { IEvents } from "@walletconnect/events";
import { Logger } from "@walletconnect/logger";
import { IProvider } from "./providers";

/**
* @param session - The session to use. If not provided, the provider will create a new session.
*/
export interface UniversalProviderOpts extends SignClientTypes.Options {
projectId?: string;
metadata?: Metadata;
Expand All @@ -16,6 +19,7 @@ export interface UniversalProviderOpts extends SignClientTypes.Options {
storage?: IKeyValueStorage;
name?: string;
disableProviderPing?: boolean;
session?: SessionTypes.Struct;
}

export type Metadata = SignClientTypes.Metadata;
Expand Down
35 changes: 20 additions & 15 deletions providers/universal-provider/test/index.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -361,20 +361,20 @@ describe("UniversalProvider", function () {
describe("persistence", () => {
describe("after restart", () => {
it("clients can ping each other", async () => {
const dappDbName = getDbName(`dappDB-${Date.now()}`);
const walletDbName = getDbName(`walletDB-${Date.now()}`);
const dapp = await UniversalProvider.init({
...TEST_PROVIDER_OPTS,
name: "dapp",
storageOptions: { database: getDbName("dappDB") },
storageOptions: { database: dappDbName },
});
const wallet = await UniversalProvider.init({
...TEST_PROVIDER_OPTS,
name: "wallet",
storageOptions: { database: getDbName("walletDB") },
storageOptions: { database: walletDbName },
});
const chains = [`eip155:${CHAIN_ID}`, `eip155:${CHAIN_ID_B}`];
const {
sessionA: { topic },
} = await testConnectMethod(
const { sessionA } = await testConnectMethod(
{
dapp,
wallet,
Expand All @@ -392,6 +392,8 @@ describe("UniversalProvider", function () {
},
},
);
wallet.session = sessionA;
const topic = sessionA.topic;

await Promise.all([
new Promise((resolve) => {
Expand All @@ -417,17 +419,16 @@ describe("UniversalProvider", function () {
const addresses = (await dapp.request({ method: "eth_accounts" })) as string[];
// delete
await deleteProviders({ A: dapp, B: wallet });

// restart
const afterDapp = await UniversalProvider.init({
...TEST_PROVIDER_OPTS,
name: "dapp",
storageOptions: { database: getDbName("dappDB") },
storageOptions: { database: dappDbName },
});
const afterWallet = await UniversalProvider.init({
...TEST_PROVIDER_OPTS,
name: "wallet",
storageOptions: { database: getDbName("walletDB") },
storageOptions: { database: walletDbName },
});

// ping
Expand All @@ -445,15 +446,17 @@ describe("UniversalProvider", function () {
});

it("should reload provider data after restart", async () => {
const dappDbName = getDbName(`dappDB-${Date.now()}`);
const walletDbName = getDbName(`walletDB-${Date.now()}`);
const dapp = await UniversalProvider.init({
...TEST_PROVIDER_OPTS,
name: "dapp",
storageOptions: { database: getDbName("dappDB") },
storageOptions: { database: dappDbName },
});
const wallet = await UniversalProvider.init({
...TEST_PROVIDER_OPTS,
name: "wallet",
storageOptions: { database: getDbName("walletDB") },
storageOptions: { database: walletDbName },
});

const {
Expand All @@ -474,7 +477,7 @@ describe("UniversalProvider", function () {
const afterDapp = await UniversalProvider.init({
...TEST_PROVIDER_OPTS,
name: "afterDapp",
storageOptions: { database: getDbName("dappDB") },
storageOptions: { database: dappDbName },
});

// load the provider in ethers without new pairing
Expand Down Expand Up @@ -769,9 +772,10 @@ describe("UniversalProvider", function () {
});
describe("caip validation", () => {
it("should reload after restart", async () => {
const dappDbName = getDbName(`dappDB-${Date.now()}`);
const dapp = await UniversalProvider.init({
...TEST_PROVIDER_OPTS,
storageOptions: { database: getDbName("dappDB") },
storageOptions: { database: dappDbName },
name: "dapp",
});
const wallet = await UniversalProvider.init({
Expand Down Expand Up @@ -814,7 +818,7 @@ describe("UniversalProvider", function () {
// restart
const afterDapp = await UniversalProvider.init({
...TEST_PROVIDER_OPTS,
storageOptions: { database: getDbName("dappDB") },
storageOptions: { database: dappDbName },
name: "dapp",
});

Expand All @@ -826,9 +830,10 @@ describe("UniversalProvider", function () {
});
});
it("should reload after restart with correct chain", async () => {
const dappDbName = getDbName(`dappDB-${Date.now()}`);
const dapp = await UniversalProvider.init({
...TEST_PROVIDER_OPTS,
storageOptions: { database: getDbName("dappDB") },
storageOptions: { database: dappDbName },
name: "dapp",
});
const wallet = await UniversalProvider.init({
Expand Down Expand Up @@ -876,7 +881,7 @@ describe("UniversalProvider", function () {
// restart
const afterDapp = await UniversalProvider.init({
...TEST_PROVIDER_OPTS,
storageOptions: { database: getDbName("dappDB") },
storageOptions: { database: dappDbName },
name: "dapp",
});

Expand Down
Loading