diff --git a/src/Portal.js b/src/Portal.js index e68549f0..2bc22b1d 100644 --- a/src/Portal.js +++ b/src/Portal.js @@ -1,22 +1,87 @@ -import { useEffect, useRef } from 'react' -import { createPortal } from 'react-dom' +import { + useMemo, + useLayoutEffect, + useEffect, + forwardRef, + useState, + isValidElement, + cloneElement, +} from 'react' +import { findDOMNode, createPortal } from 'react-dom' -function Portal({ children }) { - let ref = useRef(null) - - if (ref.current === null) { - ref.current = document.createElement('div') - ref.current.setAttribute('id', '___reactour') +function setRef(ref, value) { + if (typeof ref === 'function') { + ref(value) + } else if (ref) { + ref.current = value } +} - useEffect(() => { - document.body.appendChild(ref.current) - return () => { - document.body.removeChild(ref.current) +function useForkRef(refA, refB) { + /** + * This will create a new function if the ref props change and are defined. + * This means react will call the old forkRef with `null` and the new forkRef + * with the ref. Cleanup naturally emerges from this behavior + */ + return useMemo(() => { + if (refA == null && refB == null) { + return null } - }, [ref]) + return (refValue) => { + setRef(refA, refValue) + setRef(refB, refValue) + } + }, [refA, refB]) +} - return createPortal(children, ref.current) +function getContainer(container) { + container = typeof container === 'function' ? container() : container + // #StrictMode ready + return findDOMNode(container) } +const useEnhancedEffect = + typeof window !== 'undefined' ? useLayoutEffect : useEffect + +/** + * Portals provide a first-class way to render children into a DOM node + * that exists outside the DOM hierarchy of the parent component. + */ +const Portal = forwardRef(function Portal(props, ref) { + const { children, container, disablePortal = false } = props + const [mountNode, setMountNode] = useState(null) + const handleRef = useForkRef( + isValidElement(children) ? children.ref : null, + ref + ) + + useEnhancedEffect(() => { + if (!disablePortal) { + setMountNode(getContainer(container) || document.body) + } + }, [container, disablePortal]) + + useEnhancedEffect(() => { + if (mountNode && !disablePortal) { + setRef(ref, mountNode) + return () => { + setRef(ref, null) + } + } + + return undefined + }, [ref, mountNode, disablePortal]) + + if (disablePortal) { + if (isValidElement(children)) { + return cloneElement(children, { + ref: handleRef, + }) + } + return children + } + + return mountNode ? createPortal(children, mountNode) : mountNode +}) + export default Portal diff --git a/src/Tour.js b/src/Tour.js index 0973a2ac..a980ebef 100644 --- a/src/Tour.js +++ b/src/Tour.js @@ -378,6 +378,8 @@ class Tour extends Component { CustomHelper, disableFocusLock, highlightedBorder, + container, + disablePortal, } = this.props const { @@ -399,7 +401,7 @@ class Tour extends Component { if (isOpen) { return ( - +