diff --git a/.asf.yaml b/.asf.yaml
index 99fd6fac22c76..0588a300a5ca8 100644
--- a/.asf.yaml
+++ b/.asf.yaml
@@ -54,72 +54,14 @@ github:
    # needs to be updated as part of the release process
    # .asf.yaml doesn't support wildcard branch protection rules, only exact branch names
    # https://github.com/apache/infrastructure-asfyaml?tab=readme-ov-file#branch-protection
-    # Keeping set of protected branches for future releases
-    # Meanwhile creating a prerelease script that will update the branch protection names
-    # automatically. Keep track on it https://github.com/apache/datafusion/issues/17134
+    # these branch protection blocks are autogenerated during the release process, as described in
+    # https://github.com/apache/datafusion/tree/main/dev/release#2-add-a-protection-to-release-candidate-branch
    branch-50:
      required_pull_request_reviews:
        required_approving_review_count: 1
    branch-51:
      required_pull_request_reviews:
        required_approving_review_count: 1
-    branch-52:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-53:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-54:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-55:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-56:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-57:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-58:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-59:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-60:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-61:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-62:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-63:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-64:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-65:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-66:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-67:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-68:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-69:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
-    branch-70:
-      required_pull_request_reviews:
-        required_approving_review_count: 1
  pull_requests:
    # enable updating head branches of pull requests
    allow_update_branch: true
@@ -129,3 +71,4 @@ github:
  # https://datafusion.apache.org/
  publish:
    whoami: asf-site
+
diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml
index 22d2f2187dd07..6228370c955a9 100644
--- a/.github/actions/setup-builder/action.yaml
+++ b/.github/actions/setup-builder/action.yaml
@@ -46,3 +46,17 @@ runs:
      # https://github.com/actions/checkout/issues/766
      shell: bash
      run: git config --global --add safe.directory "$GITHUB_WORKSPACE"
+    - name: Remove unnecessary preinstalled software
+      shell: bash
+      run: |
+        echo "Disk space before cleanup:"
+        df -h
+        apt-get clean
+        # remove tool cache: about 8.5GB (github has host /opt/hostedtoolcache mounted as /__t)
+        rm -rf /__t/* || true
+        # remove Haskell runtime: about 6.3GB (host /usr/local/.ghcup)
+        rm -rf /host/usr/local/.ghcup || true
+        # remove Android library: about 7.8GB (host /usr/local/lib/android)
+        rm -rf /host/usr/local/lib/android || true
+        echo "Disk space after cleanup:"
+        df -h
\ No newline at end of file
diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml
index f269331e83ca7..b09e82bb8602d 100644
--- a/.github/workflows/audit.yml
+++ b/.github/workflows/audit.yml
@@ -40,9 +40,9 @@ jobs:
  security_audit:
    runs-on: ubuntu-latest
    steps:
-      - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
+      - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
      - name: Install cargo-audit
-        uses: taiki-e/install-action@6f9c7cc51aa54b13cbcbd12f8bbf69d8ba405b4b # v2.62.47
+        uses: taiki-e/install-action@de7896b7cd1c7d181266425abbe571b5a8c757bc # v2.65.3
        with:
          tool: cargo-audit
      - name: Run audit check
diff --git a/.github/workflows/dependencies.yml b/.github/workflows/dependencies.yml
index 7e736e1a7afbf..fef65870b697d 100644
--- a/.github/workflows/dependencies.yml
+++ b/.github/workflows/dependencies.yml
@@ -44,7 +44,7 @@ jobs:
    container:
      image: amd64/rust
    steps:
-      - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
+      - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
        with:
          submodules: true
          fetch-depth: 1
@@ -62,7 +62,7 @@
    container:
      image: amd64/rust
    steps:
-      - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
+      - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
      - name: Install cargo-machete
        run: cargo install cargo-machete --version ^0.9 --locked
      - name: Detect unused dependencies
diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml
index cc879f66cc936..1ec7c16b488f5 100644
--- a/.github/workflows/dev.yml
+++ b/.github/workflows/dev.yml
@@ -32,8 +32,9 @@ jobs:
    runs-on: ubuntu-latest
    name: Check License Header
    steps:
-      - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
+      - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
      - name: Install HawkEye
+        # This CI job is bound by installation time, use `--profile dev` to speed it up
        run: cargo install hawkeye --version 6.2.0 --locked --profile dev
      - name: Run license header check
        run: ci/scripts/license_header.sh
@@ -42,18 +43,25 @@
    name: Use prettier to check formatting of documents
    runs-on: ubuntu-latest
    steps:
-      - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
-      - uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6.0.0
+      - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
+      - uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v6.1.0
        with:
          node-version: "20"
      - name: Prettier check
-        run: |
-          # if you encounter error, rerun the command below and commit the changes
-          #
-          # ignore subproject CHANGELOG.md because they are machine generated
-          npx prettier@2.7.1 --write \
-            '{datafusion,datafusion-cli,datafusion-examples,dev,docs}/**/*.md' \
-            '!datafusion/CHANGELOG.md' \
-            README.md \
-            CONTRIBUTING.md
-          git diff --exit-code
+        # if you encounter an error, see the instructions inside the script
+        run: ci/scripts/doc_prettier_check.sh
+
+  typos:
+    name: Spell Check with Typos
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
+        with:
+          persist-credentials: false
+      # Version fixed on purpose. It uses heuristics to detect typos, so upgrading
+      # it may cause checks to fail more often.
+      # We can upgrade it manually once in a while.
+ - name: Install typos-cli + run: cargo install typos-cli --locked --version 1.37.0 + - name: Run typos check + run: ci/scripts/typos_check.sh diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 588bf46aaca70..3e2c48643c366 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -32,16 +32,16 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout docs sources - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Checkout asf-site branch - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: ref: asf-site path: asf-site - name: Setup Python - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: "3.12" diff --git a/.github/workflows/docs_pr.yaml b/.github/workflows/docs_pr.yaml index c182f2ef85d23..81eeb4039ba97 100644 --- a/.github/workflows/docs_pr.yaml +++ b/.github/workflows/docs_pr.yaml @@ -40,12 +40,12 @@ jobs: name: Test doc build runs-on: ubuntu-latest steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: submodules: true fetch-depth: 1 - name: Setup Python - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: "3.12" - name: Install doc dependencies diff --git a/.github/workflows/extended.yml b/.github/workflows/extended.yml index 85e40731a9592..01de0d5b77a7a 100644 --- a/.github/workflows/extended.yml +++ b/.github/workflows/extended.yml @@ -69,7 +69,7 @@ jobs: runs-on: ubuntu-latest # note: do not use amd/rust container to preserve disk space steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true @@ -93,7 +93,7 @@ jobs: runs-on: ubuntu-latest # note: do not use amd/rust container to preserve disk space steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true @@ -137,7 +137,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true @@ -158,7 +158,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 0abf535b9741f..01e21115010fc 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -39,7 +39,7 @@ jobs: contents: read pull-requests: write steps: - - uses: 
actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Assign GitHub labels if: | diff --git a/.github/workflows/large_files.yml b/.github/workflows/large_files.yml index 9cbfd6030a7f6..b96b8cd4544ee 100644 --- a/.github/workflows/large_files.yml +++ b/.github/workflows/large_files.yml @@ -29,7 +29,7 @@ jobs: check-files: runs-on: ubuntu-latest steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: fetch-depth: 0 - name: Check size of new Git objects diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c57300eec0e4d..2a907ba7e5b14 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -49,13 +49,13 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable - name: Rust Dependency Cache - uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + uses: Swatinem/rust-cache@779680da715d629ac1d338a641029a2f4372abb5 # v2.8.2 with: shared-key: "amd-ci-check" # this job uses it's own cache becase check has a separate cache and we need it to be fast as it blocks other jobs save-if: ${{ github.ref_name == 'main' }} @@ -77,7 +77,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -102,13 +102,13 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable - name: Rust Dependency Cache - uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + uses: Swatinem/rust-cache@779680da715d629ac1d338a641029a2f4372abb5 # v2.8.2 with: save-if: false # set in linux-test shared-key: "amd-ci" @@ -139,7 +139,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -170,13 +170,13 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable - name: Rust Dependency Cache - uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + uses: Swatinem/rust-cache@779680da715d629ac1d338a641029a2f4372abb5 # v2.8.2 with: save-if: false # set in linux-test shared-key: "amd-ci" @@ -209,8 +209,6 @@ jobs: run: cargo check --profile ci --no-default-features -p datafusion --features=math_expressions - name: Check datafusion (parquet) run: cargo check --profile ci --no-default-features -p datafusion --features=parquet - - name: Check datafusion (pyarrow) - run: cargo check --profile ci --no-default-features 
-p datafusion --features=pyarrow - name: Check datafusion (regex_expressions) run: cargo check --profile ci --no-default-features -p datafusion --features=regex_expressions - name: Check datafusion (recursive_protection) @@ -237,7 +235,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -271,8 +269,10 @@ jobs: runs-on: ubuntu-latest container: image: amd64/rust + volumes: + - /usr/local:/host/usr/local steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: submodules: true fetch-depth: 1 @@ -281,7 +281,7 @@ jobs: with: rust-version: stable - name: Rust Dependency Cache - uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + uses: Swatinem/rust-cache@779680da715d629ac1d338a641029a2f4372abb5 # v2.8.2 with: save-if: ${{ github.ref_name == 'main' }} shared-key: "amd-ci" @@ -318,14 +318,14 @@ jobs: needs: linux-build-lib runs-on: ubuntu-latest steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: submodules: true fetch-depth: 1 - name: Setup Rust toolchain run: rustup toolchain install stable - name: Rust Dependency Cache - uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + uses: Swatinem/rust-cache@779680da715d629ac1d338a641029a2f4372abb5 # v2.8.2 with: save-if: false # set in linux-test shared-key: "amd-ci" @@ -349,7 +349,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: submodules: true fetch-depth: 1 @@ -358,23 +358,10 @@ jobs: with: rust-version: stable - name: Rust Dependency Cache - uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + uses: Swatinem/rust-cache@779680da715d629ac1d338a641029a2f4372abb5 # v2.8.2 with: save-if: ${{ github.ref_name == 'main' }} shared-key: "amd-ci-linux-test-example" - - name: Remove unnecessary preinstalled software - run: | - echo "Disk space before cleanup:" - df -h - apt-get clean - rm -rf /__t/CodeQL - rm -rf /__t/PyPy - rm -rf /__t/Java_Temurin-Hotspot_jdk - rm -rf /__t/Python - rm -rf /__t/go - rm -rf /__t/Ruby - echo "Disk space after cleanup:" - df -h - name: Run examples run: | # test datafusion-sql examples @@ -392,7 +379,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: submodules: true fetch-depth: 1 @@ -413,7 +400,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -425,7 +412,7 @@ jobs: name: build and run with wasm-pack runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Setup for wasm32 run: | rustup target add wasm32-unknown-unknown @@ -434,7 +421,7 @@ jobs: 
sudo apt-get update -qq sudo apt-get install -y -qq clang - name: Setup wasm-pack - uses: taiki-e/install-action@6f9c7cc51aa54b13cbcbd12f8bbf69d8ba405b4b # v2.62.47 + uses: taiki-e/install-action@de7896b7cd1c7d181266425abbe571b5a8c757bc # v2.65.3 with: tool: wasm-pack - name: Run tests with headless mode @@ -453,7 +440,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: submodules: true fetch-depth: 1 @@ -500,7 +487,7 @@ jobs: --health-timeout 5s --health-retries 5 steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: submodules: true fetch-depth: 1 @@ -524,7 +511,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: submodules: true fetch-depth: 1 @@ -562,7 +549,7 @@ jobs: name: cargo test (macos-aarch64) runs-on: macos-14 steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: submodules: true fetch-depth: 1 @@ -572,37 +559,13 @@ jobs: shell: bash run: cargo test --profile ci --exclude datafusion-cli --workspace --lib --tests --bins --features avro,json,backtrace,integration-tests - test-datafusion-pyarrow: - name: cargo test pyarrow (amd64) - needs: linux-build-lib - runs-on: ubuntu-latest - container: - image: amd64/rust:bullseye # Use the bullseye tag image which comes with python3.9 - steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - with: - submodules: true - fetch-depth: 1 - - name: Install PyArrow - run: | - echo "LIBRARY_PATH=$LD_LIBRARY_PATH" >> $GITHUB_ENV - apt-get update - apt-get install python3-pip -y - python3 -m pip install pyarrow - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: stable - - name: Run datafusion-common tests - run: cargo test --profile ci -p datafusion-common --features=pyarrow,sql - vendor: name: Verify Vendored Code runs-on: ubuntu-latest container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -619,7 +582,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -678,7 +641,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: submodules: true fetch-depth: 1 @@ -689,7 +652,7 @@ jobs: - name: Install Clippy run: rustup component add clippy - name: Rust Dependency Cache - uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + uses: Swatinem/rust-cache@779680da715d629ac1d338a641029a2f4372abb5 # v2.8.2 with: save-if: ${{ github.ref_name == 'main' }} shared-key: "amd-ci-clippy" @@ -703,7 +666,7 @@ jobs: container: image: amd64/rust steps: - - uses: 
actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: submodules: true fetch-depth: 1 @@ -724,7 +687,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: submodules: true fetch-depth: 1 @@ -732,7 +695,7 @@ jobs: uses: ./.github/actions/setup-builder with: rust-version: stable - - uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6.0.0 + - uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v6.1.0 with: node-version: "20" - name: Check if configs.md has been modified @@ -746,6 +709,23 @@ jobs: ./dev/update_function_docs.sh git diff --exit-code + examples-docs-check: + name: check example README is up-to-date + needs: linux-build-lib + runs-on: ubuntu-latest + container: + image: amd64/rust + + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + submodules: true + fetch-depth: 1 + + - name: Run examples docs check script + run: | + bash ci/scripts/check_examples_docs.sh + # Verify MSRV for the crates which are directly used by other projects: # - datafusion # - datafusion-substrait @@ -757,11 +737,11 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Install cargo-msrv - uses: taiki-e/install-action@6f9c7cc51aa54b13cbcbd12f8bbf69d8ba405b4b # v2.62.47 + uses: taiki-e/install-action@de7896b7cd1c7d181266425abbe571b5a8c757bc # v2.65.3 with: tool: cargo-msrv @@ -798,12 +778,4 @@ jobs: run: cargo msrv --output-format json --log-target stdout verify - name: Check datafusion-proto working-directory: datafusion/proto - run: cargo msrv --output-format json --log-target stdout verify - typos: - name: Spell Check with Typos - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - with: - persist-credentials: false - - uses: crate-ci/typos@07d900b8fa1097806b8adb6391b0d3e0ac2fdea7 # v1.39.0 + run: cargo msrv --output-format json --log-target stdout verify \ No newline at end of file diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index d5fc9287aa6a5..2aba1085b8329 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -27,7 +27,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0 + - uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1 with: stale-pr-message: "Thank you for your contribution. Unfortunately, this pull request is stale because it has been open 60 days with no activity. Please remove the stale label or comment or this will be closed in 7 days." days-before-pr-stale: 60 diff --git a/.github/workflows/take.yml b/.github/workflows/take.yml index 86dc190add1d1..ffb5f728e04c1 100644 --- a/.github/workflows/take.yml +++ b/.github/workflows/take.yml @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-name: Assign the issue via a `take` comment
+name: Assign/unassign the issue via `take` or `untake` comment
 on:
   issue_comment:
     types: created
@@ -26,16 +26,30 @@ permissions:
 jobs:
   issue_assign:
     runs-on: ubuntu-latest
-    if: (!github.event.issue.pull_request) && github.event.comment.body == 'take'
+    if: (!github.event.issue.pull_request) && (github.event.comment.body == 'take' || github.event.comment.body == 'untake')
     concurrency:
       group: ${{ github.actor }}-issue-assign
     steps:
-      - run: |
-          CODE=$(curl -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" -LI https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.issue.number }}/assignees/${{ github.event.comment.user.login }} -o /dev/null -w '%{http_code}\n' -s)
-          if [ "$CODE" -eq "204" ]
+      - name: Take or untake issue
+        env:
+          COMMENT_BODY: ${{ github.event.comment.body }}
+          ISSUE_NUMBER: ${{ github.event.issue.number }}
+          USER_LOGIN: ${{ github.event.comment.user.login }}
+          REPO: ${{ github.repository }}
+          TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        run: |
+          if [ "$COMMENT_BODY" == "take" ]
           then
-            echo "Assigning issue ${{ github.event.issue.number }} to ${{ github.event.comment.user.login }}"
-            curl -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" -d '{"assignees": ["${{ github.event.comment.user.login }}"]}' https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.issue.number }}/assignees
-          else
-            echo "Cannot assign issue ${{ github.event.issue.number }} to ${{ github.event.comment.user.login }}"
+            CODE=$(curl -H "Authorization: token $TOKEN" -LI https://api.github.com/repos/$REPO/issues/$ISSUE_NUMBER/assignees/$USER_LOGIN -o /dev/null -w '%{http_code}\n' -s)
+            if [ "$CODE" -eq "204" ]
+            then
+              echo "Assigning issue $ISSUE_NUMBER to $USER_LOGIN"
+              curl -X POST -H "Authorization: token $TOKEN" -H "Content-Type: application/json" -d "{\"assignees\": [\"$USER_LOGIN\"]}" https://api.github.com/repos/$REPO/issues/$ISSUE_NUMBER/assignees
+            else
+              echo "Cannot assign issue $ISSUE_NUMBER to $USER_LOGIN"
+            fi
+          elif [ "$COMMENT_BODY" == "untake" ]
+          then
+            echo "Unassigning issue $ISSUE_NUMBER from $USER_LOGIN"
+            curl -X DELETE -H "Authorization: token $TOKEN" -H "Content-Type: application/json" -d "{\"assignees\": [\"$USER_LOGIN\"]}" https://api.github.com/repos/$REPO/issues/$ISSUE_NUMBER/assignees
           fi
\ No newline at end of file
diff --git a/Cargo.lock b/Cargo.lock
index f500265108ff5..2ce60805c913c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -105,6 +105,15 @@ dependencies = [
 "alloc-no-stdlib",
]
+[[package]]
+name = "alloca"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e5a7d05ea6aea7e9e64d25b9156ba2fee3fdd659e34e41063cd2fc7cd020d7f4"
+dependencies = [
+ "cc",
+]
+
[[package]]
name = "allocator-api2"
version = "0.2.21"
@@ -184,15 +193,16 @@ checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
[[package]]
name = "apache-avro"
-version = "0.20.0"
+version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3a033b4ced7c585199fb78ef50fca7fe2f444369ec48080c5fd072efa1a03cc7"
+checksum = "36fa98bc79671c7981272d91a8753a928ff6a1cd8e4f20a44c45bd5d313840bf"
dependencies = [
 "bigdecimal",
 "bon",
- "bzip2 0.6.1",
+ "bzip2",
 "crc32fast",
 "digest",
+ "liblzma",
 "log",
 "miniz_oxide",
 "num-bigint",
@@ -207,7 +217,6 @@ dependencies = [
 "strum_macros 0.27.2",
 "thiserror",
 "uuid",
- "xz2",
 "zstd",
]
@@ -225,9 +234,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]] name = "arrow" -version = "57.0.0" +version = "57.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4df8bb5b0bd64c0b9bc61317fcc480bad0f00e56d3bc32c69a4c8dada4786bae" +checksum = "cb372a7cbcac02a35d3fb7b3fc1f969ec078e871f9bb899bf00a2e1809bec8a3" dependencies = [ "arrow-arith", "arrow-array", @@ -238,7 +247,6 @@ dependencies = [ "arrow-ipc", "arrow-json", "arrow-ord", - "arrow-pyarrow", "arrow-row", "arrow-schema", "arrow-select", @@ -249,9 +257,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "57.0.0" +version = "57.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1a640186d3bd30a24cb42264c2dafb30e236a6f50d510e56d40b708c9582491" +checksum = "0f377dcd19e440174596d83deb49cd724886d91060c07fec4f67014ef9d54049" dependencies = [ "arrow-array", "arrow-buffer", @@ -263,9 +271,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "57.0.0" +version = "57.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "219fe420e6800979744c8393b687afb0252b3f8a89b91027d27887b72aa36d31" +checksum = "a23eaff85a44e9fa914660fb0d0bb00b79c4a3d888b5334adb3ea4330c84f002" dependencies = [ "ahash 0.8.12", "arrow-buffer", @@ -274,7 +282,7 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown 0.16.0", + "hashbrown 0.16.1", "num-complex", "num-integer", "num-traits", @@ -282,9 +290,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "57.0.0" +version = "57.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76885a2697a7edf6b59577f568b456afc94ce0e2edc15b784ce3685b6c3c5c27" +checksum = "a2819d893750cb3380ab31ebdc8c68874dd4429f90fd09180f3c93538bd21626" dependencies = [ "bytes", "half", @@ -294,13 +302,14 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "57.0.0" +version = "57.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c9ebb4c987e6b3b236fb4a14b20b34835abfdd80acead3ccf1f9bf399e1f168" +checksum = "e3d131abb183f80c450d4591dc784f8d7750c50c6e2bc3fcaad148afc8361271" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", + "arrow-ord", "arrow-schema", "arrow-select", "atoi", @@ -315,9 +324,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "57.0.0" +version = "57.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92386159c8d4bce96f8bd396b0642a0d544d471bdc2ef34d631aec80db40a09c" +checksum = "2275877a0e5e7e7c76954669366c2aa1a829e340ab1f612e647507860906fb6b" dependencies = [ "arrow-array", "arrow-cast", @@ -330,9 +339,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "57.0.0" +version = "57.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "727681b95de313b600eddc2a37e736dcb21980a40f640314dcf360e2f36bc89b" +checksum = "05738f3d42cb922b9096f7786f606fcb8669260c2640df8490533bb2fa38c9d3" dependencies = [ "arrow-buffer", "arrow-schema", @@ -343,9 +352,9 @@ dependencies = [ [[package]] name = "arrow-flight" -version = "57.0.0" +version = "57.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f70bb56412a007b0cfc116d15f24dda6adeed9611a213852a004cda20085a3b9" +checksum = "8b5f57c3d39d1b1b7c1376a772ea86a131e7da310aed54ebea9363124bb885e3" dependencies = [ "arrow-arith", "arrow-array", @@ -371,9 +380,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "57.0.0" +version = "57.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"da9ba92e3de170295c98a84e5af22e2b037f0c7b32449445e6c493b5fca27f27" +checksum = "3d09446e8076c4b3f235603d9ea7c5494e73d441b01cd61fb33d7254c11964b3" dependencies = [ "arrow-array", "arrow-buffer", @@ -387,9 +396,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "57.0.0" +version = "57.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b969b4a421ae83828591c6bf5450bd52e6d489584142845ad6a861f42fe35df8" +checksum = "371ffd66fa77f71d7628c63f209c9ca5341081051aa32f9c8020feb0def787c0" dependencies = [ "arrow-array", "arrow-buffer", @@ -398,7 +407,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.12.0", + "indexmap 2.12.1", "itoa", "lexical-core", "memchr", @@ -411,9 +420,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "57.0.0" +version = "57.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "141c05298b21d03e88062317a1f1a73f5ba7b6eb041b350015b1cd6aabc0519b" +checksum = "cbc94fc7adec5d1ba9e8cd1b1e8d6f72423b33fe978bf1f46d970fafab787521" dependencies = [ "arrow-array", "arrow-buffer", @@ -422,23 +431,11 @@ dependencies = [ "arrow-select", ] -[[package]] -name = "arrow-pyarrow" -version = "57.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfcfb2be2e9096236f449c11f425cddde18c4cc540f516d90f066f10a29ed515" -dependencies = [ - "arrow-array", - "arrow-data", - "arrow-schema", - "pyo3", -] - [[package]] name = "arrow-row" -version = "57.0.0" +version = "57.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f3c06a6abad6164508ed283c7a02151515cef3de4b4ff2cebbcaeb85533db2" +checksum = "169676f317157dc079cc5def6354d16db63d8861d61046d2f3883268ced6f99f" dependencies = [ "arrow-array", "arrow-buffer", @@ -449,9 +446,9 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "57.0.0" +version = "57.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cfa7a03d1eee2a4d061476e1840ad5c9867a544ca6c4c59256496af5d0a8be5" +checksum = "d27609cd7dd45f006abae27995c2729ef6f4b9361cde1ddd019dc31a5aa017e0" dependencies = [ "bitflags 2.9.4", "serde", @@ -461,9 +458,9 @@ dependencies = [ [[package]] name = "arrow-select" -version = "57.0.0" +version = "57.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bafa595babaad59f2455f4957d0f26448fb472722c186739f4fac0823a1bdb47" +checksum = "ae980d021879ea119dd6e2a13912d81e64abed372d53163e804dfe84639d8010" dependencies = [ "ahash 0.8.12", "arrow-array", @@ -475,9 +472,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "57.0.0" +version = "57.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32f46457dbbb99f2650ff3ac23e46a929e0ab81db809b02aa5511c258348bef2" +checksum = "cf35e8ef49dcf0c5f6d175edee6b8af7b45611805333129c541a8b89a0fc0534" dependencies = [ "arrow-array", "arrow-buffer", @@ -520,19 +517,15 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.19" +version = "0.4.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06575e6a9673580f52661c92107baabffbf41e2141373441cbcdc47cb733003c" +checksum = "07a926debf178f2d355197f9caddb08e54a9329d44748034bba349c5848cb519" dependencies = [ - "bzip2 0.5.2", - "flate2", + "compression-codecs", + "compression-core", "futures-core", - "memchr", "pin-project-lite", "tokio", - "xz2", - "zstd", - "zstd-safe", ] [[package]] @@ -552,7 +545,7 @@ checksum = 
"3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -574,7 +567,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -585,7 +578,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -611,9 +604,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-config" -version = "1.8.7" +version = "1.8.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04b37ddf8d2e9744a0b9c19ce0b78efe4795339a90b66b7bae77987092cd2e69" +checksum = "96571e6996817bf3d58f6b569e4b9fd2e9d2fcf9f7424eed07b2ce9bb87535e5" dependencies = [ "aws-credential-types", "aws-runtime", @@ -641,9 +634,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.2.7" +version = "1.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "799a1290207254984cb7c05245111bc77958b92a3c9bb449598044b36341cce6" +checksum = "3cd362783681b15d136480ad555a099e82ecd8e2d10a841e14dfd0078d67fee3" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -676,9 +669,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.5.11" +version = "1.5.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e1ed337dabcf765ad5f2fb426f13af22d576328aaf09eac8f70953530798ec0" +checksum = "d81b5b2898f6798ad58f484856768bca817e3cd9de0974c24ae0f1113fe88f1b" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -700,9 +693,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.85.0" +version = "1.91.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f2c741e2e439f07b5d1b33155e246742353d82167c785a2ff547275b7e32483" +checksum = "8ee6402a36f27b52fe67661c6732d684b2635152b676aa2babbfb5204f99115d" dependencies = [ "aws-credential-types", "aws-runtime", @@ -722,9 +715,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.87.0" +version = "1.93.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6428ae5686b18c0ee99f6f3c39d94ae3f8b42894cdc35c35d8fb2470e9db2d4c" +checksum = "a45a7f750bbd170ee3677671ad782d90b894548f4e4ae168302c57ec9de5cb3e" dependencies = [ "aws-credential-types", "aws-runtime", @@ -744,9 +737,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.87.0" +version = "1.95.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5871bec9a79a3e8d928c7788d654f135dde0e71d2dd98089388bab36b37ef607" +checksum = "55542378e419558e6b1f398ca70adb0b2088077e79ad9f14eb09441f2f7b2164" dependencies = [ "aws-credential-types", "aws-runtime", @@ -767,9 +760,9 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.3.4" +version = "1.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "084c34162187d39e3740cb635acd73c4e3a551a36146ad6fe8883c929c9f876c" +checksum = "69e523e1c4e8e7e8ff219d732988e22bfeae8a1cafdbe6d9eca1546fa080be7c" dependencies = [ "aws-credential-types", "aws-smithy-http", @@ -789,9 +782,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.2.5" +version = "1.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"1e190749ea56f8c42bf15dd76c65e14f8f765233e6df9b0506d9d934ebef867c" +checksum = "9ee19095c7c4dda59f1697d028ce704c24b2d33c6718790c7f1d5a3015b4107c" dependencies = [ "futures-util", "pin-project-lite", @@ -800,15 +793,16 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.62.3" +version = "0.62.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c4dacf2d38996cf729f55e7a762b30918229917eca115de45dfa8dfb97796c9" +checksum = "826141069295752372f8203c17f28e30c464d22899a43a0c9fd9c458d469c88b" dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", "bytes", "bytes-utils", "futures-core", + "futures-util", "http 0.2.12", "http 1.3.1", "http-body 0.4.6", @@ -820,9 +814,9 @@ dependencies = [ [[package]] name = "aws-smithy-http-client" -version = "1.1.1" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147e8eea63a40315d704b97bf9bc9b8c1402ae94f89d5ad6f7550d963309da1b" +checksum = "59e62db736db19c488966c8d787f52e6270be565727236fd5579eaa301e7bc4a" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -844,27 +838,27 @@ dependencies = [ [[package]] name = "aws-smithy-json" -version = "0.61.5" +version = "0.61.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eaa31b350998e703e9826b2104dd6f63be0508666e1aba88137af060e8944047" +checksum = "49fa1213db31ac95288d981476f78d05d9cbb0353d22cdf3472cc05bb02f6551" dependencies = [ "aws-smithy-types", ] [[package]] name = "aws-smithy-observability" -version = "0.1.3" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9364d5989ac4dd918e5cc4c4bdcc61c9be17dcd2586ea7f69e348fc7c6cab393" +checksum = "17f616c3f2260612fe44cede278bafa18e73e6479c4e393e2c4518cf2a9a228a" dependencies = [ "aws-smithy-runtime-api", ] [[package]] name = "aws-smithy-query" -version = "0.60.7" +version = "0.60.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2fbd61ceb3fe8a1cb7352e42689cec5335833cd9f94103a61e98f9bb61c64bb" +checksum = "ae5d689cf437eae90460e944a58b5668530d433b4ff85789e69d2f2a556e057d" dependencies = [ "aws-smithy-types", "urlencoding", @@ -872,9 +866,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.9.2" +version = "1.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fa63ad37685ceb7762fa4d73d06f1d5493feb88e3f27259b9ed277f4c01b185" +checksum = "65fda37911905ea4d3141a01364bc5509a0f32ae3f3b22d6e330c0abfb62d247" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -896,9 +890,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.9.0" +version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07f5e0fc8a6b3f2303f331b94504bbf754d85488f402d6f1dd7a6080f99afe56" +checksum = "ab0d43d899f9e508300e587bf582ba54c27a452dd0a9ea294690669138ae14a2" dependencies = [ "aws-smithy-async", "aws-smithy-types", @@ -913,9 +907,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.3.2" +version = "1.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d498595448e43de7f4296b7b7a18a8a02c61ec9349128c80a368f7c3b4ab11a8" +checksum = "905cb13a9895626d49cf2ced759b062d913834c7482c38e49557eac4e6193f01" dependencies = [ "base64-simd", "bytes", @@ -936,18 +930,18 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.10" +version = "0.60.13" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "3db87b96cb1b16c024980f133968d52882ca0daaee3a086c6decc500f6c99728" +checksum = "11b2f670422ff42bf7065031e72b45bc52a3508bd089f743ea90731ca2b6ea57" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.3.8" +version = "1.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b069d19bf01e46298eaedd7c6f283fe565a59263e53eebec945f3e6398f42390" +checksum = "1d980627d2dd7bfc32a3c025685a033eeab8d365cc840c631ef59d1b8f428164" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -1026,9 +1020,9 @@ dependencies = [ [[package]] name = "bigdecimal" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a22f228ab7a1b23027ccc6c350b72868017af7ea8356fbdf19f8d991c690013" +checksum = "560f42649de9fa436b73517378a147ec21f6c997a546581df4b4b31677828934" dependencies = [ "autocfg", "libm", @@ -1055,7 +1049,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -1192,9 +1186,9 @@ dependencies = [ [[package]] name = "bon" -version = "3.7.2" +version = "3.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2529c31017402be841eb45892278a6c21a000c0a17643af326c73a73f83f0fb" +checksum = "ebeb9aaf9329dff6ceb65c689ca3db33dbf15f324909c60e4e5eef5701ce31b1" dependencies = [ "bon-macros", "rustversion", @@ -1202,9 +1196,9 @@ dependencies = [ [[package]] name = "bon-macros" -version = "3.7.2" +version = "3.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d82020dadcb845a345591863adb65d74fa8dc5c18a0b6d408470e13b7adc7005" +checksum = "77e9d642a7e3a318e37c2c9427b5a6a48aa1ad55dcd986f3034ab2239045a645" dependencies = [ "darling", "ident_case", @@ -1212,7 +1206,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -1235,7 +1229,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -1305,9 +1299,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" [[package]] name = "bytes-utils" @@ -1319,15 +1313,6 @@ dependencies = [ "either", ] -[[package]] -name = "bzip2" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49ecfb22d906f800d4fe833b6282cf4dc1c298f5057ca0b5445e5c209735ca47" -dependencies = [ - "bzip2-sys", -] - [[package]] name = "bzip2" version = "0.6.1" @@ -1337,16 +1322,6 @@ dependencies = [ "libbz2-rs-sys", ] -[[package]] -name = "bzip2-sys" -version = "0.1.13+1.0.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225bff33b2141874fe80d71e07d6eec4f85c5c216453dd96388240f96e1acc14" -dependencies = [ - "cc", - "pkg-config", -] - [[package]] name = "cast" version = "0.3.0" @@ -1461,9 +1436,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.50" +version = "4.5.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c2cfd7bf8a6017ddaa4e32ffe7403d547790db06bd171c1c53926faab501623" +checksum = "c9e340e012a1bf4935f5282ed1436d1489548e8f72308207ea5df0e23d2d03f8" dependencies = [ "clap_builder", "clap_derive", @@ -1471,9 +1446,9 @@ dependencies = [ 
[[package]] name = "clap_builder" -version = "4.5.50" +version = "4.5.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a4c05b9e80c5ccd3a7ef080ad7b6ba7d6fc00a985b8b157197075677c82c7a0" +checksum = "d76b5d13eaa18c901fd2f7fca939fefe3a0727a953561fefdf3b2922b8569d00" dependencies = [ "anstream", "anstyle", @@ -1490,7 +1465,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -1534,6 +1509,27 @@ dependencies = [ "unicode-width 0.2.1", ] +[[package]] +name = "compression-codecs" +version = "0.4.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34a3cbbb8b6eca96f3a5c4bf6938d5b27ced3675d69f95bb51948722870bc323" +dependencies = [ + "bzip2", + "compression-core", + "flate2", + "liblzma", + "memchr", + "zstd", + "zstd-safe", +] + +[[package]] +name = "compression-core" +version = "0.4.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75984efb6ed102a0d42db99afb6c1948f0380d1d91808d5529916e6c08b49d8d" + [[package]] name = "console" version = "0.15.11" @@ -1655,19 +1651,21 @@ dependencies = [ [[package]] name = "criterion" -version = "0.7.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1c047a62b0cc3e145fa84415a3191f628e980b194c2755aa12300a4e6cbd928" +checksum = "4d883447757bb0ee46f233e9dc22eb84d93a9508c9b868687b274fc431d886bf" dependencies = [ + "alloca", "anes", "cast", "ciborium", - "clap 4.5.50", + "clap 4.5.53", "criterion-plot", "futures", "itertools 0.13.0", "num-traits", "oorandom", + "page_size", "plotters", "rayon", "regex", @@ -1680,9 +1678,9 @@ dependencies = [ [[package]] name = "criterion-plot" -version = "0.6.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b1bcc0dc7dfae599d84ad0b1a55f80cde8af3725da8313b528da95ef783e338" +checksum = "ed943f81ea2faa8dcecbbfa50164acf95d555afec96a27871663b300e387b2e4" dependencies = [ "cast", "itertools 0.13.0", @@ -1761,9 +1759,9 @@ dependencies = [ [[package]] name = "ctor" -version = "0.6.1" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ffc71fcdcdb40d6f087edddf7f8f1f8f79e6cf922f555a9ee8779752d4819bd" +checksum = "424e0138278faeb2b401f174ad17e715c829512d74f3d1e81eb43365c2e0590e" dependencies = [ "ctor-proc-macro", "dtor", @@ -1802,7 +1800,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -1813,7 +1811,7 @@ checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" dependencies = [ "darling_core", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -1832,13 +1830,13 @@ dependencies = [ [[package]] name = "datafusion" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "arrow-schema", "async-trait", "bytes", - "bzip2 0.6.1", + "bzip2", "chrono", "criterion", "ctor", @@ -1879,6 +1877,7 @@ dependencies = [ "glob", "insta", "itertools 0.14.0", + "liblzma", "log", "nix", "object_store", @@ -1898,13 +1897,12 @@ dependencies = [ "tokio", "url", "uuid", - "xz2", "zstd", ] [[package]] name = "datafusion-benchmarks" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "datafusion", @@ -1929,7 +1927,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "async-trait", @@ -1952,7 +1950,7 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" -version = "50.3.0" +version 
= "51.0.0" dependencies = [ "arrow", "async-trait", @@ -1970,19 +1968,18 @@ dependencies = [ "itertools 0.14.0", "log", "object_store", - "tokio", ] [[package]] name = "datafusion-cli" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "async-trait", "aws-config", "aws-credential-types", "chrono", - "clap 4.5.50", + "clap 4.5.53", "ctor", "datafusion", "datafusion-common", @@ -2007,24 +2004,24 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "50.3.0" +version = "51.0.0" dependencies = [ "ahash 0.8.12", "apache-avro", "arrow", "arrow-ipc", "chrono", + "criterion", "half", "hashbrown 0.14.5", "hex", - "indexmap 2.12.0", + "indexmap 2.12.1", "insta", "libc", "log", "object_store", "parquet", "paste", - "pyo3", "rand 0.9.2", "recursive", "sqlparser", @@ -2034,7 +2031,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "50.3.0" +version = "51.0.0" dependencies = [ "futures", "log", @@ -2043,13 +2040,13 @@ dependencies = [ [[package]] name = "datafusion-datasource" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "async-compression", "async-trait", "bytes", - "bzip2 0.6.1", + "bzip2", "chrono", "criterion", "datafusion-common", @@ -2065,6 +2062,7 @@ dependencies = [ "futures", "glob", "itertools 0.14.0", + "liblzma", "log", "object_store", "rand 0.9.2", @@ -2072,13 +2070,12 @@ dependencies = [ "tokio", "tokio-util", "url", - "xz2", "zstd", ] [[package]] name = "datafusion-datasource-arrow" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "arrow-ipc", @@ -2101,7 +2098,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-avro" -version = "50.3.0" +version = "51.0.0" dependencies = [ "apache-avro", "arrow", @@ -2120,7 +2117,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-csv" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "async-trait", @@ -2141,7 +2138,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-json" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "async-trait", @@ -2161,7 +2158,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-parquet" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "async-trait", @@ -2190,11 +2187,11 @@ dependencies = [ [[package]] name = "datafusion-doc" -version = "50.3.0" +version = "51.0.0" [[package]] name = "datafusion-examples" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "arrow-flight", @@ -2204,11 +2201,14 @@ dependencies = [ "bytes", "dashmap", "datafusion", - "datafusion-ffi", + "datafusion-common", + "datafusion-expr", "datafusion-physical-expr-adapter", "datafusion-proto", + "datafusion-sql", "env_logger", "futures", + "insta", "log", "mimalloc", "nix", @@ -2216,6 +2216,8 @@ dependencies = [ "prost", "rand 0.9.2", "serde_json", + "strum 0.27.2", + "strum_macros 0.27.2", "tempfile", "test-utils", "tokio", @@ -2228,7 +2230,7 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "async-trait", @@ -2249,7 +2251,7 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "async-trait", @@ -2262,7 +2264,7 @@ dependencies = [ "datafusion-functions-window-common", "datafusion-physical-expr-common", "env_logger", - "indexmap 2.12.0", + "indexmap 2.12.1", "insta", "itertools 0.14.0", "paste", @@ -2273,18 +2275,18 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "50.3.0" +version = 
"51.0.0" dependencies = [ "arrow", "datafusion-common", - "indexmap 2.12.0", + "indexmap 2.12.1", "itertools 0.14.0", "paste", ] [[package]] name = "datafusion-ffi" -version = "50.3.0" +version = "51.0.0" dependencies = [ "abi_stable", "arrow", @@ -2292,10 +2294,22 @@ dependencies = [ "async-ffi", "async-trait", "datafusion", + "datafusion-catalog", "datafusion-common", + "datafusion-datasource", + "datafusion-execution", + "datafusion-expr", + "datafusion-functions", + "datafusion-functions-aggregate", "datafusion-functions-aggregate-common", + "datafusion-functions-table", + "datafusion-functions-window", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "datafusion-physical-plan", "datafusion-proto", "datafusion-proto-common", + "datafusion-session", "doc-comment", "futures", "log", @@ -2306,7 +2320,7 @@ dependencies = [ [[package]] name = "datafusion-functions" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "arrow-buffer", @@ -2314,6 +2328,7 @@ dependencies = [ "blake2", "blake3", "chrono", + "chrono-tz", "criterion", "ctor", "datafusion-common", @@ -2338,7 +2353,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "50.3.0" +version = "51.0.0" dependencies = [ "ahash 0.8.12", "arrow", @@ -2359,7 +2374,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" -version = "50.3.0" +version = "51.0.0" dependencies = [ "ahash 0.8.12", "arrow", @@ -2372,7 +2387,7 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "arrow-ord", @@ -2395,7 +2410,7 @@ dependencies = [ [[package]] name = "datafusion-functions-table" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "async-trait", @@ -2409,7 +2424,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "datafusion-common", @@ -2425,7 +2440,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" -version = "50.3.0" +version = "51.0.0" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -2433,16 +2448,16 @@ dependencies = [ [[package]] name = "datafusion-macros" -version = "50.3.0" +version = "51.0.0" dependencies = [ "datafusion-doc", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] name = "datafusion-optimizer" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "async-trait", @@ -2458,7 +2473,7 @@ dependencies = [ "datafusion-physical-expr", "datafusion-sql", "env_logger", - "indexmap 2.12.0", + "indexmap 2.12.1", "insta", "itertools 0.14.0", "log", @@ -2469,7 +2484,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "50.3.0" +version = "51.0.0" dependencies = [ "ahash 0.8.12", "arrow", @@ -2482,19 +2497,21 @@ dependencies = [ "datafusion-physical-expr-common", "half", "hashbrown 0.14.5", - "indexmap 2.12.0", + "indexmap 2.12.1", "insta", "itertools 0.14.0", "parking_lot", "paste", "petgraph 0.8.3", "rand 0.9.2", + "recursive", "rstest", + "tokio", ] [[package]] name = "datafusion-physical-expr-adapter" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "datafusion-common", @@ -2507,19 +2524,21 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" -version = "50.3.0" +version = "51.0.0" dependencies = [ "ahash 0.8.12", "arrow", + "chrono", "datafusion-common", "datafusion-expr-common", "hashbrown 0.14.5", "itertools 0.14.0", + "parking_lot", ] 
[[package]] name = "datafusion-physical-optimizer" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "datafusion-common", @@ -2539,19 +2558,19 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "50.3.0" +version = "51.0.0" dependencies = [ "ahash 0.8.12", "arrow", "arrow-ord", "arrow-schema", "async-trait", - "chrono", "criterion", "datafusion-common", "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", + "datafusion-functions", "datafusion-functions-aggregate", "datafusion-functions-aggregate-common", "datafusion-functions-window", @@ -2561,7 +2580,7 @@ dependencies = [ "futures", "half", "hashbrown 0.14.5", - "indexmap 2.12.0", + "indexmap 2.12.1", "insta", "itertools 0.14.0", "log", @@ -2575,9 +2594,10 @@ dependencies = [ [[package]] name = "datafusion-proto" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", + "async-trait", "chrono", "datafusion", "datafusion-catalog", @@ -2611,7 +2631,7 @@ dependencies = [ [[package]] name = "datafusion-proto-common" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "datafusion-common", @@ -2623,7 +2643,7 @@ dependencies = [ [[package]] name = "datafusion-pruning" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "datafusion-common", @@ -2641,7 +2661,7 @@ dependencies = [ [[package]] name = "datafusion-session" -version = "50.3.0" +version = "51.0.0" dependencies = [ "async-trait", "datafusion-common", @@ -2653,7 +2673,7 @@ dependencies = [ [[package]] name = "datafusion-spark" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "bigdecimal", @@ -2665,7 +2685,9 @@ dependencies = [ "datafusion-execution", "datafusion-expr", "datafusion-functions", + "datafusion-functions-nested", "log", + "percent-encoding", "rand 0.9.2", "sha1", "url", @@ -2673,7 +2695,7 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "bigdecimal", @@ -2686,7 +2708,7 @@ dependencies = [ "datafusion-functions-nested", "datafusion-functions-window", "env_logger", - "indexmap 2.12.0", + "indexmap 2.12.1", "insta", "itertools 0.14.0", "log", @@ -2699,14 +2721,14 @@ dependencies = [ [[package]] name = "datafusion-sqllogictest" -version = "50.3.0" +version = "51.0.0" dependencies = [ "arrow", "async-trait", "bigdecimal", "bytes", "chrono", - "clap 4.5.50", + "clap 4.5.53", "datafusion", "datafusion-spark", "datafusion-substrait", @@ -2733,7 +2755,7 @@ dependencies = [ [[package]] name = "datafusion-substrait" -version = "50.3.0" +version = "51.0.0" dependencies = [ "async-recursion", "async-trait", @@ -2755,7 +2777,7 @@ dependencies = [ [[package]] name = "datafusion-wasmtest" -version = "50.3.0" +version = "51.0.0" dependencies = [ "chrono", "console_error_panic_hook", @@ -2830,7 +2852,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -2886,7 +2908,7 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -2924,7 +2946,7 @@ checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -3078,9 +3100,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.1.4" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9" +checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" dependencies = [ "crc32fast", "libz-rs-sys", @@ -3185,7 +3207,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -3317,7 +3339,7 @@ dependencies = [ "futures-core", "futures-sink", "http 1.3.1", - "indexmap 2.12.0", + "indexmap 2.12.1", "slab", "tokio", "tokio-util", @@ -3368,9 +3390,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.16.0" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" [[package]] name = "heck" @@ -3604,7 +3626,7 @@ dependencies = [ "js-sys", "log", "wasm-bindgen", - "windows-core 0.62.0", + "windows-core", ] [[package]] @@ -3742,21 +3764,21 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.12.0" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6717a8d2a5a929a1a2eb43a12812498ed141a0bcfb7e8f7844fbdbe4303bba9f" +checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" dependencies = [ "equivalent", - "hashbrown 0.16.0", + "hashbrown 0.16.1", "serde", "serde_core", ] [[package]] name = "indicatif" -version = "0.18.0" +version = "0.18.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70a646d946d06bedbbc4cac4c218acf4bbf2d87757a784857025f4d447e4e1cd" +checksum = "9375e112e4b463ec1b1c6c011953545c65a30164fbab5b581df32b3abf0dcb88" dependencies = [ "console 0.16.1", "portable-atomic", @@ -3765,17 +3787,11 @@ dependencies = [ "web-time", ] -[[package]] -name = "indoc" -version = "2.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" - [[package]] name = "insta" -version = "1.43.2" +version = "1.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46fdb647ebde000f43b5b53f773c30cf9b0cb4300453208713fa38b2c70935a0" +checksum = "b76866be74d68b1595eb8060cb9191dca9c021db2316558e52ddc5d55d41b66c" dependencies = [ "console 0.15.11", "globset", @@ -3783,6 +3799,7 @@ dependencies = [ "regex", "serde", "similar", + "tempfile", "walkdir", ] @@ -3870,7 +3887,7 @@ checksum = "03343451ff899767262ec32146f6d559dd759fdadf42ff0e227c7c48f72594b4" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -3885,9 +3902,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.82" +version = "0.3.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b011eec8cc36da2aab2d5cff675ec18454fad408585853910a202391cf9f8e65" +checksum = "464a3709c7f55f1f721e5389aa6ea4e3bc6aba669353300af094b29ffbdde1d8" dependencies = [ "once_cell", "wasm-bindgen", @@ -3988,6 +4005,26 @@ dependencies = [ "windows-link 0.2.0", ] +[[package]] +name = "liblzma" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73c36d08cad03a3fbe2c4e7bb3a9e84c57e4ee4135ed0b065cade3d98480c648" +dependencies = [ + "liblzma-sys", +] + +[[package]] +name = "liblzma-sys" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01b9596486f6d60c3bbe644c0e1be1aa6ccc472ad630fe8927b456973d7cb736" 
+dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "libm" version = "0.2.15" @@ -4024,7 +4061,7 @@ checksum = "5297962ef19edda4ce33aaa484386e0a5b3d7f2f4e037cbeee00503ef6b29d33" dependencies = [ "anstream", "anstyle", - "clap 4.5.50", + "clap 4.5.53", "escape8259", ] @@ -4061,9 +4098,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.28" +version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" [[package]] name = "lru-slab" @@ -4073,24 +4110,13 @@ checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" [[package]] name = "lz4_flex" -version = "0.11.5" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08ab2867e3eeeca90e844d1940eab391c9dc5228783db2ed999acbc0a9ed375a" +checksum = "ab6473172471198271ff72e9379150e9dfd70d8e533e0752a27e515b48dd375e" dependencies = [ "twox-hash", ] -[[package]] -name = "lzma-sys" -version = "0.1.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27" -dependencies = [ - "cc", - "libc", - "pkg-config", -] - [[package]] name = "matchit" version = "0.8.4" @@ -4113,15 +4139,6 @@ version = "2.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" -[[package]] -name = "memoffset" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" -dependencies = [ - "autocfg", -] - [[package]] name = "mimalloc" version = "0.1.48" @@ -4417,6 +4434,16 @@ version = "4.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48dd4f4a2c8405440fd0462561f0e5806bd0f77e86f51c761481bdd4018b545e" +[[package]] +name = "page_size" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "parking_lot" version = "0.12.4" @@ -4442,9 +4469,9 @@ dependencies = [ [[package]] name = "parquet" -version = "57.0.0" +version = "57.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a0f31027ef1af7549f7cec603a9a21dce706d3f8d7c2060a68f43c1773be95a" +checksum = "be3e4f6d320dd92bfa7d612e265d7d08bba0a240bab86af3425e1d255a511d89" dependencies = [ "ahash 0.8.12", "arrow-array", @@ -4461,7 +4488,7 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown 0.16.0", + "hashbrown 0.16.1", "lz4_flex", "num-bigint", "num-integer", @@ -4500,7 +4527,7 @@ dependencies = [ "regex", "regex-syntax", "structmeta", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -4559,7 +4586,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", - "indexmap 2.12.0", + "indexmap 2.12.1", ] [[package]] @@ -4570,7 +4597,7 @@ checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" dependencies = [ "fixedbitset", "hashbrown 0.15.5", - "indexmap 2.12.0", + "indexmap 2.12.1", "serde", ] @@ -4628,7 +4655,7 @@ checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" dependencies = [ 
"proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -4701,7 +4728,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -4776,7 +4803,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -4847,7 +4874,7 @@ dependencies = [ "prost", "prost-types", "regex", - "syn 2.0.108", + "syn 2.0.111", "tempfile", ] @@ -4861,7 +4888,7 @@ dependencies = [ "itertools 0.14.0", "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -4911,67 +4938,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "pyo3" -version = "0.26.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ba0117f4212101ee6544044dae45abe1083d30ce7b29c4b5cbdfa2354e07383" -dependencies = [ - "indoc", - "libc", - "memoffset", - "once_cell", - "portable-atomic", - "pyo3-build-config", - "pyo3-ffi", - "pyo3-macros", - "unindent", -] - -[[package]] -name = "pyo3-build-config" -version = "0.26.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fc6ddaf24947d12a9aa31ac65431fb1b851b8f4365426e182901eabfb87df5f" -dependencies = [ - "target-lexicon", -] - -[[package]] -name = "pyo3-ffi" -version = "0.26.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "025474d3928738efb38ac36d4744a74a400c901c7596199e20e45d98eb194105" -dependencies = [ - "libc", - "pyo3-build-config", -] - -[[package]] -name = "pyo3-macros" -version = "0.26.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e64eb489f22fe1c95911b77c44cc41e7c19f3082fc81cce90f657cdc42ffded" -dependencies = [ - "proc-macro2", - "pyo3-macros-backend", - "quote", - "syn 2.0.108", -] - -[[package]] -name = "pyo3-macros-backend" -version = "0.26.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "100246c0ecf400b475341b8455a9213344569af29a3c841d29270e53102e0fcf" -dependencies = [ - "heck 0.5.0", - "proc-macro2", - "pyo3-build-config", - "quote", - "syn 2.0.108", -] - [[package]] name = "quad-rand" version = "0.2.3" @@ -5180,7 +5146,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -5220,7 +5186,7 @@ checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -5248,9 +5214,9 @@ dependencies = [ [[package]] name = "regex-lite" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "943f41321c63ef1c92fd763bfe054d2668f7f225a5c29f0105903dc2fc04ba30" +checksum = "8d942b98df5e658f56f20d592c7f868833fe38115e65c33003d8cd224b0155da" [[package]] name = "regex-syntax" @@ -5402,7 +5368,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.108", + "syn 2.0.111", "unicode-ident", ] @@ -5414,7 +5380,7 @@ checksum = "b3a8fb4672e840a587a66fc577a5491375df51ddb88f2a2c2a792598c326fe14" dependencies = [ "quote", "rand 0.8.5", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -5618,7 +5584,7 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -5709,7 +5675,7 @@ checksum = 
"d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -5720,7 +5686,7 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -5744,7 +5710,7 @@ checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -5756,7 +5722,7 @@ dependencies = [ "proc-macro2", "quote", "serde", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -5781,7 +5747,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.12.0", + "indexmap 2.12.1", "schemars 0.9.0", "schemars 1.0.4", "serde", @@ -5800,7 +5766,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -5809,7 +5775,7 @@ version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap 2.12.0", + "indexmap 2.12.1", "itoa", "ryu", "serde", @@ -5976,7 +5942,7 @@ checksum = "da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -6024,7 +5990,7 @@ dependencies = [ "proc-macro2", "quote", "structmeta-derive", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -6035,7 +6001,7 @@ checksum = "152a0b65a590ff6c3da95cabe2353ee04e6167c896b28e3b14478c2636c922fc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -6084,7 +6050,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -6096,7 +6062,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -6130,7 +6096,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", - "syn 2.0.108", + "syn 2.0.111", "typify", "walkdir", ] @@ -6154,9 +6120,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.108" +version = "2.0.111" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da58917d35242480a05c2897064da0a80589a2a0476c9a3f2fdc83b53502e917" +checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" dependencies = [ "proc-macro2", "quote", @@ -6180,7 +6146,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -6203,12 +6169,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" -[[package]] -name = "target-lexicon" -version = "0.13.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df7f62577c25e07834649fc3b39fafdc597c0a3527dc1c60129201ccfcbaa50c" - [[package]] name = "tempfile" version = "3.23.0" @@ -6297,7 +6257,7 @@ checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -6420,7 +6380,7 @@ checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -6472,9 +6432,9 @@ dependencies = [ [[package]] name = "tokio-util" 
-version = "0.7.16" +version = "0.7.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14307c986784f72ef81c89db7d9e28d6ac26d16213b109ea501696195e6e3ce5" +checksum = "2efa149fe76073d6e8fd97ef4f4eca7b67f599660115591483572e406e165594" dependencies = [ "bytes", "futures-core", @@ -6498,7 +6458,7 @@ version = "0.23.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3effe7c0e86fdff4f69cdd2ccc1b96f933e24811c5441d44904e8683e27184b" dependencies = [ - "indexmap 2.12.0", + "indexmap 2.12.1", "toml_datetime", "toml_parser", "winnow", @@ -6561,7 +6521,7 @@ checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", - "indexmap 2.12.0", + "indexmap 2.12.1", "pin-project-lite", "slab", "sync_wrapper", @@ -6621,14 +6581,14 @@ checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] name = "tracing-core" -version = "0.1.34" +version = "0.1.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +checksum = "7a04e24fab5c89c6a36eb8558c9656f30d81de51dfa4d3b45f26b21d61fa0a6c" dependencies = [ "once_cell", "valuable", @@ -6647,9 +6607,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.20" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" dependencies = [ "nu-ansi-term", "sharded-slab", @@ -6729,7 +6689,7 @@ dependencies = [ "semver", "serde", "serde_json", - "syn 2.0.108", + "syn 2.0.111", "thiserror", "unicode-ident", ] @@ -6747,7 +6707,7 @@ dependencies = [ "serde", "serde_json", "serde_tokenstream", - "syn 2.0.108", + "syn 2.0.111", "typify-impl", ] @@ -6806,12 +6766,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c" -[[package]] -name = "unindent" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" - [[package]] name = "unit-prefix" version = "0.5.1" @@ -6897,13 +6851,13 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.18.1" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" +checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a" dependencies = [ "getrandom 0.3.4", "js-sys", - "serde", + "serde_core", "wasm-bindgen", ] @@ -6967,9 +6921,9 @@ checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" [[package]] name = "wasm-bindgen" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da95793dfc411fbbd93f5be7715b0578ec61fe87cb1a42b12eb625caa5c5ea60" +checksum = "0d759f433fa64a2d763d1340820e46e111a7a5ab75f993d1852d70b03dbb80fd" dependencies = [ "cfg-if", "once_cell", @@ -6980,9 +6934,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.55" +version = "0.4.56" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "551f88106c6d5e7ccc7cd9a16f312dd3b5d36ea8b4954304657d5dfba115d4a0" +checksum = "836d9622d604feee9e5de25ac10e3ea5f2d65b41eac0d9ce72eb5deae707ce7c" dependencies = [ "cfg-if", "js-sys", @@ -6993,9 +6947,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04264334509e04a7bf8690f2384ef5265f05143a4bff3889ab7a3269adab59c2" +checksum = "48cb0d2638f8baedbc542ed444afc0644a29166f1595371af4fecf8ce1e7eeb3" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -7003,34 +6957,42 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "420bc339d9f322e562942d52e115d57e950d12d88983a14c79b86859ee6c7ebc" +checksum = "cefb59d5cd5f92d9dcf80e4683949f15ca4b511f4ac0a6e14d4e1ac60c6ecd40" dependencies = [ "bumpalo", "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76f218a38c84bcb33c25ec7059b07847d465ce0e0a76b995e134a45adcb6af76" +checksum = "cbc538057e648b67f72a982e708d485b2efa771e1ac05fec311f9f63e5800db4" dependencies = [ "unicode-ident", ] [[package]] name = "wasm-bindgen-test" -version = "0.3.55" +version = "0.3.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfc379bfb624eb59050b509c13e77b4eb53150c350db69628141abce842f2373" +checksum = "25e90e66d265d3a1efc0e72a54809ab90b9c0c515915c67cdf658689d2c22c6c" dependencies = [ + "async-trait", + "cast", "js-sys", + "libm", "minicov", + "nu-ansi-term", + "num-traits", + "oorandom", + "serde", + "serde_json", "wasm-bindgen", "wasm-bindgen-futures", "wasm-bindgen-test-macro", @@ -7038,13 +7000,13 @@ dependencies = [ [[package]] name = "wasm-bindgen-test-macro" -version = "0.3.55" +version = "0.3.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "085b2df989e1e6f9620c1311df6c996e83fe16f57792b272ce1e024ac16a90f1" +checksum = "7150335716dce6028bead2b848e72f47b45e7b9422f64cccdc23bedca89affc1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -7062,9 +7024,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.82" +version = "0.3.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a1f95c0d03a47f4ae1f7a64643a6bb97465d9b740f0fa8f90ea33915c99a9a1" +checksum = "9b32828d774c412041098d182a8b38b16ea816958e07cf40eec2bc080ae137ac" dependencies = [ "js-sys", "wasm-bindgen", @@ -7138,7 +7100,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" dependencies = [ "windows-collections", - "windows-core 0.61.2", + "windows-core", "windows-future", "windows-link 0.1.3", "windows-numerics", @@ -7150,7 +7112,7 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" dependencies = [ - "windows-core 0.61.2", + "windows-core", ] [[package]] @@ -7162,21 +7124,8 @@ dependencies = [ "windows-implement", "windows-interface", "windows-link 0.1.3", - "windows-result 0.3.4", - "windows-strings 0.4.2", -] - -[[package]] -name = "windows-core" -version = "0.62.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "57fe7168f7de578d2d8a05b07fd61870d2e73b4020e9f49aa00da8471723497c" -dependencies = [ - "windows-implement", - "windows-interface", - "windows-link 0.2.0", - "windows-result 0.4.0", - "windows-strings 0.5.0", + "windows-result", + "windows-strings", ] [[package]] @@ -7185,7 +7134,7 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" dependencies = [ - "windows-core 0.61.2", + "windows-core", "windows-link 0.1.3", "windows-threading", ] @@ -7198,7 +7147,7 @@ checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -7209,7 +7158,7 @@ checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -7230,7 +7179,7 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" dependencies = [ - "windows-core 0.61.2", + "windows-core", "windows-link 0.1.3", ] @@ -7243,15 +7192,6 @@ dependencies = [ "windows-link 0.1.3", ] -[[package]] -name = "windows-result" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7084dcc306f89883455a206237404d3eaf961e5bd7e0f312f7c91f57eb44167f" -dependencies = [ - "windows-link 0.2.0", -] - [[package]] name = "windows-strings" version = "0.4.2" @@ -7261,15 +7201,6 @@ dependencies = [ "windows-link 0.1.3", ] -[[package]] -name = "windows-strings" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7218c655a553b0bed4426cf54b20d7ba363ef543b52d515b3e48d7fd55318dda" -dependencies = [ - "windows-link 0.2.0", -] - [[package]] name = "windows-sys" version = "0.52.0" @@ -7490,15 +7421,6 @@ version = "0.13.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" -[[package]] -name = "xz2" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "388c44dc09d76f1536602ead6d325eb532f5c122f17782bd57fb47baeeb767e2" -dependencies = [ - "lzma-sys", -] - [[package]] name = "yansi" version = "1.0.1" @@ -7525,7 +7447,7 @@ checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", "synstructure", ] @@ -7546,7 +7468,7 @@ checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] @@ -7566,7 +7488,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", "synstructure", ] @@ -7606,7 +7528,7 @@ checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.108", + "syn 2.0.111", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index f15929b4c2b00..10fc88b7057c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,7 +71,7 @@ resolver = "2" [workspace.package] authors = ["Apache DataFusion "] -edition = "2021" +edition = "2024" homepage = "https://datafusion.apache.org" license = "Apache-2.0" readme = "README.md" @@ -79,7 
+79,7 @@ repository = "https://github.com/apache/datafusion" # Define Minimum Supported Rust Version (MSRV) rust-version = "1.88.0" # Define DataFusion version -version = "50.3.0" +version = "51.0.0" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can @@ -90,83 +90,88 @@ version = "50.3.0" ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } -apache-avro = { version = "0.20", default-features = false } -arrow = { version = "57.0.0", features = [ +apache-avro = { version = "0.21", default-features = false } +arrow = { version = "57.1.0", features = [ "prettyprint", "chrono-tz", ] } -arrow-buffer = { version = "57.0.0", default-features = false } -arrow-flight = { version = "57.0.0", features = [ +arrow-buffer = { version = "57.1.0", default-features = false } +arrow-flight = { version = "57.1.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "57.0.0", default-features = false, features = [ +arrow-ipc = { version = "57.1.0", default-features = false, features = [ "lz4", ] } -arrow-ord = { version = "57.0.0", default-features = false } -arrow-schema = { version = "57.0.0", default-features = false } +arrow-ord = { version = "57.1.0", default-features = false } +arrow-schema = { version = "57.1.0", default-features = false } async-trait = "0.1.89" bigdecimal = "0.4.8" -bytes = "1.10" +bytes = "1.11" +bzip2 = "0.6.1" chrono = { version = "0.4.42", default-features = false } -criterion = "0.7" -ctor = "0.6.1" +criterion = "0.8" +ctor = "0.6.3" dashmap = "6.0.1" -datafusion = { path = "datafusion/core", version = "50.3.0", default-features = false } -datafusion-catalog = { path = "datafusion/catalog", version = "50.3.0" } -datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "50.3.0" } -datafusion-common = { path = "datafusion/common", version = "50.3.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common-runtime", version = "50.3.0" } -datafusion-datasource = { path = "datafusion/datasource", version = "50.3.0", default-features = false } -datafusion-datasource-arrow = { path = "datafusion/datasource-arrow", version = "50.3.0", default-features = false } -datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "50.3.0", default-features = false } -datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "50.3.0", default-features = false } -datafusion-datasource-json = { path = "datafusion/datasource-json", version = "50.3.0", default-features = false } -datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "50.3.0", default-features = false } -datafusion-doc = { path = "datafusion/doc", version = "50.3.0" } -datafusion-execution = { path = "datafusion/execution", version = "50.3.0", default-features = false } -datafusion-expr = { path = "datafusion/expr", version = "50.3.0", default-features = false } -datafusion-expr-common = { path = "datafusion/expr-common", version = "50.3.0" } -datafusion-ffi = { path = "datafusion/ffi", version = "50.3.0" } -datafusion-functions = { path = "datafusion/functions", version = "50.3.0" } -datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "50.3.0" } -datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "50.3.0" } -datafusion-functions-nested = { path = "datafusion/functions-nested", version = "50.3.0", default-features = false 
} -datafusion-functions-table = { path = "datafusion/functions-table", version = "50.3.0" } -datafusion-functions-window = { path = "datafusion/functions-window", version = "50.3.0" } -datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "50.3.0" } -datafusion-macros = { path = "datafusion/macros", version = "50.3.0" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "50.3.0", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "50.3.0", default-features = false } -datafusion-physical-expr-adapter = { path = "datafusion/physical-expr-adapter", version = "50.3.0", default-features = false } -datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "50.3.0", default-features = false } -datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "50.3.0" } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "50.3.0" } -datafusion-proto = { path = "datafusion/proto", version = "50.3.0" } -datafusion-proto-common = { path = "datafusion/proto-common", version = "50.3.0" } -datafusion-pruning = { path = "datafusion/pruning", version = "50.3.0" } -datafusion-session = { path = "datafusion/session", version = "50.3.0" } -datafusion-spark = { path = "datafusion/spark", version = "50.3.0" } -datafusion-sql = { path = "datafusion/sql", version = "50.3.0" } -datafusion-substrait = { path = "datafusion/substrait", version = "50.3.0" } +datafusion = { path = "datafusion/core", version = "51.0.0", default-features = false } +datafusion-catalog = { path = "datafusion/catalog", version = "51.0.0" } +datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "51.0.0" } +datafusion-common = { path = "datafusion/common", version = "51.0.0", default-features = false } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "51.0.0" } +datafusion-datasource = { path = "datafusion/datasource", version = "51.0.0", default-features = false } +datafusion-datasource-arrow = { path = "datafusion/datasource-arrow", version = "51.0.0", default-features = false } +datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "51.0.0", default-features = false } +datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "51.0.0", default-features = false } +datafusion-datasource-json = { path = "datafusion/datasource-json", version = "51.0.0", default-features = false } +datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "51.0.0", default-features = false } +datafusion-doc = { path = "datafusion/doc", version = "51.0.0" } +datafusion-execution = { path = "datafusion/execution", version = "51.0.0", default-features = false } +datafusion-expr = { path = "datafusion/expr", version = "51.0.0", default-features = false } +datafusion-expr-common = { path = "datafusion/expr-common", version = "51.0.0" } +datafusion-ffi = { path = "datafusion/ffi", version = "51.0.0" } +datafusion-functions = { path = "datafusion/functions", version = "51.0.0" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "51.0.0" } +datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "51.0.0" } +datafusion-functions-nested = { path = "datafusion/functions-nested", version = "51.0.0", default-features = false } +datafusion-functions-table = { path = "datafusion/functions-table", 
version = "51.0.0" } +datafusion-functions-window = { path = "datafusion/functions-window", version = "51.0.0" } +datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "51.0.0" } +datafusion-macros = { path = "datafusion/macros", version = "51.0.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "51.0.0", default-features = false } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "51.0.0", default-features = false } +datafusion-physical-expr-adapter = { path = "datafusion/physical-expr-adapter", version = "51.0.0", default-features = false } +datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "51.0.0", default-features = false } +datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "51.0.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "51.0.0" } +datafusion-proto = { path = "datafusion/proto", version = "51.0.0" } +datafusion-proto-common = { path = "datafusion/proto-common", version = "51.0.0" } +datafusion-pruning = { path = "datafusion/pruning", version = "51.0.0" } +datafusion-session = { path = "datafusion/session", version = "51.0.0" } +datafusion-spark = { path = "datafusion/spark", version = "51.0.0" } +datafusion-sql = { path = "datafusion/sql", version = "51.0.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "51.0.0" } doc-comment = "0.3" env_logger = "0.11" +flate2 = "1.1.5" futures = "0.3" +glob = "0.3.0" half = { version = "2.7.0", default-features = false } hashbrown = { version = "0.14.5", features = ["raw"] } hex = { version = "0.4.3" } -indexmap = "2.12.0" -insta = { version = "1.43.2", features = ["glob", "filters"] } +indexmap = "2.12.1" +insta = { version = "1.45.0", features = ["glob", "filters"] } itertools = "0.14" +liblzma = { version = "0.4.4", features = ["static"] } log = "^0.4" num-traits = { version = "0.2" } object_store = { version = "0.12.4", default-features = false } parking_lot = "0.12" -parquet = { version = "57.0.0", default-features = false, features = [ +parquet = { version = "57.1.0", default-features = false, features = [ "arrow", "async", "object_store", ] } +paste = "1.0.15" pbjson = { version = "0.8.0" } pbjson-types = "0.8" # Should match arrow-flight's version of prost. 
@@ -177,11 +182,14 @@ regex = "1.12" rstest = "0.26.1" serde_json = "1" sqlparser = { version = "0.59.0", default-features = false, features = ["std", "visitor"] } +strum = "0.27.2" +strum_macros = "0.27.2" tempfile = "3" testcontainers = { version = "0.25.2", features = ["default"] } testcontainers-modules = { version = "0.13" } tokio = { version = "1.48", features = ["macros", "rt", "sync"] } url = "2.5.7" +zstd = { version = "0.13", default-features = false } [workspace.lints.clippy] # Detects large stack-allocated futures that may cause stack overflow crashes (see threshold in clippy.toml) @@ -191,6 +199,8 @@ or_fun_call = "warn" unnecessary_lazy_evaluations = "warn" uninlined_format_args = "warn" inefficient_to_string = "warn" +# https://github.com/apache/datafusion/issues/18503 +needless_pass_by_value = "warn" [workspace.lints.rust] unexpected_cfgs = { level = "warn", check-cfg = [ diff --git a/README.md b/README.md index 5191496eaafe3..880adfb3ac392 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ [![Build Status][actions-badge]][actions-url] ![Commit Activity][commit-activity-badge] [![Open Issues][open-issues-badge]][open-issues-url] +[![Pending PRs][pending-pr-badge]][pending-pr-url] [![Discord chat][discord-badge]][discord-url] [![Linkedin][linkedin-badge]][linkedin-url] ![Crates.io MSRV][msrv-badge] @@ -39,6 +40,8 @@ [commit-activity-badge]: https://img.shields.io/github/commit-activity/m/apache/datafusion [open-issues-badge]: https://img.shields.io/github/issues-raw/apache/datafusion [open-issues-url]: https://github.com/apache/datafusion/issues +[pending-pr-badge]: https://img.shields.io/github/issues-search/apache/datafusion?query=is%3Apr+is%3Aopen+draft%3Afalse+review%3Arequired+status%3Asuccess&label=Pending%20PRs&logo=github +[pending-pr-url]: https://github.com/apache/datafusion/pulls?q=is%3Apr+is%3Aopen+draft%3Afalse+review%3Arequired+status%3Asuccess+sort%3Aupdated-desc [linkedin-badge]: https://img.shields.io/badge/Follow-Linkedin-blue [linkedin-url]: https://www.linkedin.com/company/apache-datafusion/ [msrv-badge]: https://img.shields.io/crates/msrv/datafusion?label=Min%20Rust%20Version @@ -129,7 +132,6 @@ Optional features: - `avro`: support for reading the [Apache Avro] format - `backtrace`: include backtrace information in error messages - `parquet_encryption`: support for using [Parquet Modular Encryption] -- `pyarrow`: conversions between PyArrow and DataFusion types - `serde`: enable arrow-schema's `serde` feature [apache avro]: https://avro.apache.org/ diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 870c826f55810..5f91175ca8baf 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -56,7 +56,7 @@ serde_json = { workspace = true } snmalloc-rs = { version = "0.3", optional = true } structopt = { version = "0.3", default-features = false } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } -tokio-util = { version = "0.7.16" } +tokio-util = { version = "0.7.17" } [dev-dependencies] datafusion-proto = { workspace = true } diff --git a/benchmarks/README.md b/benchmarks/README.md index 8fed85fa02b80..0b71628b2db12 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -119,7 +119,6 @@ You can also invoke the helper directly if you need to customise arguments furth ./benchmarks/compile_profile.py --profiles dev release --data /path/to/tpch_sf1 ``` - ## Benchmark with modified configurations ### Select join algorithm @@ -147,6 +146,19 @@ To verify that datafusion picked up your configuration, run the 
benchmarks with ## Comparing performance of main and a branch +For TPC-H: +```shell +./benchmarks/compare_tpch.sh main mybranch +``` + +For TPC-DS: +To get data in `DATA_DIR` for TPCDS, please follow the instructions in `./benchmarks/bench.sh data tpcds` +```shell +DATA_DIR=../../datafusion-benchmarks/tpcds/data/sf1/ ./benchmarks/compare_tpcds.sh main mybranch +``` + +Alternatively, you can compare manually following the example below: + ```shell git checkout main @@ -243,28 +255,11 @@ See the help for more details. You can enable `mimalloc` or `snmalloc` (to use either the mimalloc or snmalloc allocator) as features by passing them in as `--features`. For example: ```shell -cargo run --release --features "mimalloc" --bin tpch -- benchmark datafusion --iterations 3 --path ./data --format tbl --query 1 --batch-size 4096 -``` - -The benchmark program also supports CSV and Parquet input file formats and a utility is provided to convert from `tbl` -(generated by the `dbgen` utility) to CSV and Parquet. - -```bash -cargo run --release --bin tpch -- convert --input ./data --output /mnt/tpch-parquet --format parquet +cargo run --release --features "mimalloc" --bin dfbench tpch --iterations 3 --path ./data --format tbl --query 1 --batch-size 4096 ``` Or if you want to verify and run all the queries in the benchmark, you can just run `cargo test`. -#### Sorted Conversion - -The TPCH tables generated by the dbgen utility are sorted by their first column (their primary key for most tables, the `l_orderkey` column for the `lineitem` table.) - -To preserve this sorted order information during conversion (useful for benchmarking execution on pre-sorted data) include the `--sort` flag: - -```bash -cargo run --release --bin tpch -- convert --input ./data --output /mnt/tpch-sorted-parquet --format parquet --sort -``` - ### Comparing results between runs Any `dfbench` execution with `-o ` argument will produce a @@ -316,7 +311,6 @@ This will produce output like: └──────────────┴──────────────┴──────────────┴───────────────┘ ``` - # Benchmark Runner The `dfbench` program contains subcommands to run the various @@ -356,24 +350,28 @@ FLAGS: ``` # Profiling Memory Stats for each benchmark query + The `mem_profile` program wraps benchmark execution to measure memory usage statistics, such as peak RSS. It runs each benchmark query in a separate subprocess, capturing the child process’s stdout to print structured output. Subcommands supported by mem_profile are the subset of those in `dfbench`. -Currently supported benchmarks include: Clickbench, H2o, Imdb, SortTpch, Tpch +Currently supported benchmarks include: Clickbench, H2o, Imdb, SortTpch, Tpch, TPCDS Before running benchmarks, `mem_profile` automatically compiles the benchmark binary (`dfbench`) using `cargo build`. Note that the build profile used for `dfbench` is not tied to the profile used for running `mem_profile` itself. We can explicitly specify the desired build profile using the `--bench-profile` option (e.g. release-nonlto). By prebuilding the binary and running each query in a separate process, we can ensure accurate memory statistics. Currently, `mem_profile` only supports `mimalloc` as the memory allocator, since it relies on `mimalloc`'s API to collect memory statistics. -Because it runs the compiled binary directly from the target directory, make sure your working directory is the top-level datafusion/ directory, where the target/ is also located.
+Because it runs the compiled binary directly from the target directory, make sure your working directory is the top-level datafusion/ directory, where the target/ is also located. + +The benchmark subcommand (e.g., `tpch`) and all following arguments are passed directly to `dfbench`. Be sure to specify `--bench-profile` before the benchmark subcommand. -The benchmark subcommand (e.g., `tpch`) and all following arguments are passed directly to `dfbench`. Be sure to specify `--bench-profile` before the benchmark subcommand. -Example: +Example: + ```shell datafusion$ cargo run --profile release-nonlto --bin mem_profile -- --bench-profile release-nonlto tpch --path benchmarks/data/tpch_sf1 --partitions 4 --format parquet ``` + Example Output: + ``` Query Time (ms) Peak RSS Peak Commit Major Page Faults ---------------------------------------------------------------- @@ -402,19 +400,21 @@ Query Time (ms) Peak RSS Peak Commit Major Page Faults ``` ## Reported Metrics + When running benchmarks, `mem_profile` collects several memory-related statistics using the mimalloc API: -- Peak RSS (Resident Set Size): -The maximum amount of physical memory used by the process. -This is a process-level metric collected via OS-specific mechanisms and is not mimalloc-specific. +- Peak RSS (Resident Set Size): + The maximum amount of physical memory used by the process. + This is a process-level metric collected via OS-specific mechanisms and is not mimalloc-specific. - Peak Commit: -The peak amount of memory committed by the allocator (i.e., total virtual memory reserved). -This is mimalloc-specific. It gives a more allocator-aware view of memory usage than RSS. + The peak amount of memory committed by the allocator (i.e., total virtual memory reserved). + This is mimalloc-specific. It gives a more allocator-aware view of memory usage than RSS. - Major Page Faults: -The number of major page faults triggered during execution. -This metric is obtained from the operating system and is not mimalloc-specific. + The number of major page faults triggered during execution. + This metric is obtained from the operating system and is not mimalloc-specific. + # Writing a new benchmark ## Creating or downloading data outside of the benchmark @@ -603,6 +603,34 @@ This benchmarks is derived from the [TPC-H][1] version [2]: https://github.com/databricks/tpch-dbgen.git, [2.17.1]: https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf +## TPCDS + +Run the tpcds benchmark. + +For data, please clone the `datafusion-benchmarks` repo, which contains predefined parquet data at SF1. + +```shell +git clone https://github.com/apache/datafusion-benchmarks +``` + +Then run the benchmark with the following command: + +```shell +DATA_DIR=../../datafusion-benchmarks/tpcds/data/sf1/ ./benchmarks/bench.sh run tpcds +``` + +Alternatively, benchmark a specific query: + +```shell +DATA_DIR=../../datafusion-benchmarks/tpcds/data/sf1/ ./benchmarks/bench.sh run tpcds 30 +``` + +For more help: + +```shell +cargo run --release --bin dfbench -- tpcds --help +``` + ## External Aggregation Run the benchmark for aggregations with limited memory. @@ -762,7 +790,7 @@ Different queries are included to test nested loop joins under various workloads ## Hash Join -This benchmark focuses on the performance of queries with nested hash joins, minimizing other overheads such as scanning data sources or evaluating predicates.
+This benchmark focuses on the performance of queries with hash joins, minimizing other overheads such as scanning data sources or evaluating predicates. Several queries are included to test hash joins under various workloads. @@ -774,6 +802,19 @@ Several queries are included to test hash joins under various workloads. ./bench.sh run hj ``` +## Sort Merge Join + +This benchmark focuses on the performance of queries with sort merge joins, minimizing other overheads such as scanning data sources or evaluating predicates. + +Several queries are included to test sort merge joins under various workloads. + +### Example Run + +```bash +# No need to generate data: this benchmark uses table function `range()` as the data source + +./bench.sh run smj +``` ## Cancellation Test performance of cancelling queries. @@ -804,3 +845,41 @@ Getting results... cancelling thread done dropping runtime in 83.531417ms ``` + +## Sorted Data Benchmarks + +### Data Sorted ClickBench + +Benchmark for queries on pre-sorted data to test sort order optimization. +This benchmark uses a subset of the ClickBench dataset (hits.parquet, ~14GB) that has been pre-sorted by the EventTime column. The queries are designed to test DataFusion's performance when the data is already sorted, as is common in timeseries workloads. + +The benchmark includes queries that: +- Scan pre-sorted data with ORDER BY clauses that match the sort order +- Test reverse scans on sorted data +- Verify the performance results + +#### Generating Sorted Data + +The sorted dataset is automatically generated from the ClickBench partitioned dataset. You can configure the memory used during the sorting process with the `DATAFUSION_MEMORY_GB` environment variable. The default memory limit is 12GB. +```bash +./bench.sh data data_sorted_clickbench +``` + +To create the sorted dataset, for example with 16GB of memory, run: + +```bash +DATAFUSION_MEMORY_GB=16 ./bench.sh data data_sorted_clickbench +``` + +This command will: +1. Download the ClickBench partitioned dataset if not present +2. Sort hits.parquet by EventTime in ascending order +3. Save the sorted file as hits_sorted.parquet + +#### Running the Benchmark + +```bash +./bench.sh run data_sorted_clickbench +``` + +This runs queries against the pre-sorted dataset with the `--sorted-by EventTime` flag, which informs DataFusion that the data is pre-sorted, allowing it to optimize away redundant sort operations.
diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index dbfd319dd9ad4..d5fa52d7f00ee 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -87,6 +87,9 @@ tpch10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), tpch_csv10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), single csv file per table, hash join tpch_mem10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), query from memory +# TPC-DS Benchmarks +tpcds: TPCDS inspired benchmark on Scale Factor (SF) 1 (~1GB), single parquet file per table, hash join + # Extended TPC-H Benchmarks sort_tpch: Benchmark of sorting speed for end-to-end sort queries on TPC-H dataset (SF=1) sort_tpch10: Benchmark of sorting speed for end-to-end sort queries on TPC-H dataset (SF=10) @@ -99,6 +102,9 @@ clickbench_partitioned: ClickBench queries against partitioned (100 files) parqu clickbench_pushdown: ClickBench queries against partitioned (100 files) parquet w/ filter_pushdown enabled clickbench_extended: ClickBench \"inspired\" queries against a single parquet (DataFusion specific) +# Sorted Data Benchmarks (ORDER BY Optimization) +clickbench_sorted: ClickBench queries on pre-sorted data using prefer_existing_sort (tests sort elimination optimization) + # H2O.ai Benchmarks (Group By, Join, Window) h2o_small: h2oai benchmark with small dataset (1e7 rows) for groupby, default file format is csv h2o_medium: h2oai benchmark with medium dataset (1e8 rows) for groupby, default file format is csv @@ -126,6 +132,7 @@ imdb: Join Order Benchmark (JOB) using the IMDB dataset conver cancellation: How long cancelling a query takes nlj: Benchmark for simple nested loop joins, testing various join scenarios hj: Benchmark for simple hash joins, testing various join scenarios +smj: Benchmark for simple sort merge joins, testing various join scenarios compile_profile: Compile and execute TPC-H across selected Cargo profiles, reporting timing and binary size @@ -189,8 +196,8 @@ main() { echo "***************************" case "$BENCHMARK" in all) - data_tpch "1" - data_tpch "10" + data_tpch "1" "parquet" + data_tpch "10" "parquet" data_h2o "SMALL" data_h2o "MEDIUM" data_h2o "BIG" @@ -203,18 +210,25 @@ main() { # nlj uses range() function, no data generation needed ;; tpch) - data_tpch "1" + data_tpch "1" "parquet" ;; tpch_mem) - # same data as for tpch - data_tpch "1" + data_tpch "1" "parquet" + ;; + tpch_csv) + data_tpch "1" "csv" ;; tpch10) - data_tpch "10" + data_tpch "10" "parquet" ;; tpch_mem10) - # same data as for tpch10 - data_tpch "10" + data_tpch "10" "parquet" + ;; + tpch_csv10) + data_tpch "10" "csv" + ;; + tpcds) + data_tpcds ;; clickbench_1) data_clickbench_1 @@ -289,19 +303,19 @@ main() { ;; external_aggr) # same data as for tpch - data_tpch "1" + data_tpch "1" "parquet" ;; sort_tpch) # same data as for tpch - data_tpch "1" + data_tpch "1" "parquet" ;; sort_tpch10) # same data as for tpch10 - data_tpch "10" + data_tpch "10" "parquet" ;; topk_tpch) # same data as for tpch - data_tpch "1" + data_tpch "1" "parquet" ;; nlj) # nlj uses range() function, no data generation needed @@ -311,8 +325,15 @@ main() { # hj uses range() function, no data generation needed echo "HJ benchmark does not require data generation" ;; + smj) + # smj uses range() function, no data generation needed + echo "SMJ benchmark does not require data generation" + ;; compile_profile) - data_tpch "1" + data_tpch "1" "parquet" + ;; + clickbench_sorted) + clickbench_sorted ;; *) echo "Error: unknown benchmark '$BENCHMARK' for data generation" @@ -384,6 
+405,8 @@ main() { run_external_aggr run_nlj run_hj + run_tpcds + run_smj ;; tpch) run_tpch "1" "parquet" @@ -403,6 +426,9 @@ main() { tpch_mem10) run_tpch_mem "10" ;; + tpcds) + run_tpcds + ;; cancellation) run_cancellation ;; @@ -445,7 +471,7 @@ main() { h2o_medium_window) run_h2o_window "MEDIUM" "CSV" "window" ;; - h2o_big_window) + h2o_big_window) run_h2o_window "BIG" "CSV" "window" ;; h2o_small_parquet) @@ -494,9 +520,15 @@ main() { hj) run_hj ;; + smj) + run_smj + ;; compile_profile) run_compile_profile "${PROFILE_ARGS[@]}" ;; + clickbench_sorted) + run_clickbench_sorted + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for run" usage @@ -529,7 +561,7 @@ main() { # Creates TPCH data at a certain scale factor, if it doesn't already # exist # -# call like: data_tpch($scale_factor) +# call like: data_tpch($scale_factor, format) # # Creates data in $DATA_DIR/tpch_sf1 for scale factor 1 # Creates data in $DATA_DIR/tpch_sf10 for scale factor 10 @@ -540,20 +572,23 @@ data_tpch() { echo "Internal error: Scale factor not specified" exit 1 fi + FORMAT=$2 + if [ -z "$FORMAT" ] ; then + echo "Internal error: Format not specified" + exit 1 + fi TPCH_DIR="${DATA_DIR}/tpch_sf${SCALE_FACTOR}" - echo "Creating tpch dataset at Scale Factor ${SCALE_FACTOR} in ${TPCH_DIR}..." + echo "Creating tpch $FORMAT dataset at Scale Factor ${SCALE_FACTOR} in ${TPCH_DIR}..." # Ensure the target data directory exists mkdir -p "${TPCH_DIR}" - # Create 'tbl' (CSV format) data into $DATA_DIR if it does not already exist - FILE="${TPCH_DIR}/supplier.tbl" - if test -f "${FILE}"; then - echo " tbl files exist ($FILE exists)." - else - echo " creating tbl files with tpch_dbgen..." - docker run -v "${TPCH_DIR}":/data -it --rm ghcr.io/scalytics/tpch-docker:main -vf -s "${SCALE_FACTOR}" + # check if tpchgen-cli is installed + if ! command -v tpchgen-cli &> /dev/null + then + echo "tpchgen-cli could not be found, please install it via 'cargo install tpchgen-cli'" + exit 1 fi # Copy expected answers into the ./data/answers directory if it does not already exist @@ -566,27 +601,52 @@ data_tpch() { docker run -v "${TPCH_DIR}":/data -it --entrypoint /bin/bash --rm ghcr.io/scalytics/tpch-docker:main -c "cp -f /opt/tpch/2.18.0_rc2/dbgen/answers/* /data/answers/" fi - # Create 'parquet' files from tbl - FILE="${TPCH_DIR}/supplier" - if test -d "${FILE}"; then - echo " parquet files exist ($FILE exists)." - else - echo " creating parquet files using benchmark binary ..." - pushd "${SCRIPT_DIR}" > /dev/null - $CARGO_COMMAND --bin tpch -- convert --input "${TPCH_DIR}" --output "${TPCH_DIR}" --format parquet - popd > /dev/null + if [ "$FORMAT" = "parquet" ]; then + # Create 'parquet' files, one directory per file + FILE="${TPCH_DIR}/supplier" + if test -d "${FILE}"; then + echo " parquet files exist ($FILE exists)." + else + echo " creating parquet files using tpchgen-cli ..." + tpchgen-cli --scale-factor "${SCALE_FACTOR}" --format parquet --parquet-compression='ZSTD(1)' --parts=1 --output-dir "${TPCH_DIR}" + fi + return fi - # Create 'csv' files from tbl - FILE="${TPCH_DIR}/csv/supplier" - if test -d "${FILE}"; then - echo " csv files exist ($FILE exists)." - else - echo " creating csv files using benchmark binary ..." 
- pushd "${SCRIPT_DIR}" > /dev/null - $CARGO_COMMAND --bin tpch -- convert --input "${TPCH_DIR}" --output "${TPCH_DIR}/csv" --format csv - popd > /dev/null + # Create 'csv' files, one directory per file + if [ "$FORMAT" = "csv" ]; then + FILE="${TPCH_DIR}/csv/supplier" + if test -d "${FILE}"; then + echo " csv files exist ($FILE exists)." + else + echo " creating csv files using tpchgen-cli binary ..." + tpchgen-cli --scale-factor "${SCALE_FACTOR}" --format csv --parts=1 --output-dir "${TPCH_DIR}/csv" + fi + return + fi + + echo "Error: unknown format '$FORMAT' for tpch data generation, expected 'parquet' or 'csv'" + exit 1 +} + +# Downloads TPC-DS data +data_tpcds() { + TPCDS_DIR="${DATA_DIR}/tpcds_sf1" + + # Check if `web_site.parquet` exists in the TPCDS data directory to verify data presence + echo "Checking TPC-DS data directory: ${TPCDS_DIR}" + if [ ! -f "${TPCDS_DIR}/web_site.parquet" ]; then + mkdir -p "${TPCDS_DIR}" + # Download the DataFusion benchmarks repository zip if it is not already downloaded + if [ ! -f "${DATA_DIR}/datafusion-benchmarks.zip" ]; then + echo "Downloading DataFusion benchmarks repository zip to: ${DATA_DIR}/datafusion-benchmarks.zip" + wget --timeout=30 --tries=3 -O "${DATA_DIR}/datafusion-benchmarks.zip" https://github.com/apache/datafusion-benchmarks/archive/refs/heads/main.zip + fi + echo "Extracting TPC-DS parquet data to ${TPCDS_DIR}..." + unzip -o -j -d "${TPCDS_DIR}" "${DATA_DIR}/datafusion-benchmarks.zip" datafusion-benchmarks-main/tpcds/data/sf1/* + echo "TPC-DS data extracted." fi + echo "Done." } # Runs the tpch benchmark @@ -603,10 +663,10 @@ run_tpch() { echo "Running tpch benchmark..." FORMAT=$2 - debug_run $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format ${FORMAT} -o "${RESULTS_FILE}" ${QUERY_ARG} + debug_run $CARGO_COMMAND --bin dfbench -- tpch --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format ${FORMAT} -o "${RESULTS_FILE}" ${QUERY_ARG} } -# Runs the tpch in memory +# Runs the tpch in memory (needs tpch parquet data) run_tpch_mem() { SCALE_FACTOR=$1 if [ -z "$SCALE_FACTOR" ] ; then @@ -619,7 +679,27 @@ run_tpch_mem() { echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running tpch_mem benchmark..." # -m means in memory - debug_run $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" -m --format parquet -o "${RESULTS_FILE}" ${QUERY_ARG} + debug_run $CARGO_COMMAND --bin dfbench -- tpch --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" -m --format parquet -o "${RESULTS_FILE}" ${QUERY_ARG} +} + +# Runs the tpcds benchmark +run_tpcds() { + TPCDS_DIR="${DATA_DIR}/tpcds_sf1" + + # Check if TPCDS data directory and representative file exists + if [ ! -f "${TPCDS_DIR}/web_site.parquet" ]; then + echo "" >&2 + echo "Please prepare TPC-DS data first by following instructions:" >&2 + echo " ./bench.sh data tpcds" >&2 + echo "" >&2 + exit 1 + fi + + RESULTS_FILE="${RESULTS_DIR}/tpcds_sf1.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running tpcds benchmark..." 
+ + debug_run $CARGO_COMMAND --bin dfbench -- tpcds --iterations 5 --path "${TPCDS_DIR}" --query_path "../datafusion/core/tests/tpc-ds" --prefer_hash_join "${PREFER_HASH_JOIN}" -o "${RESULTS_FILE}" ${QUERY_ARG} } # Runs the compile profile benchmark helper @@ -1154,6 +1234,14 @@ run_hj() { debug_run $CARGO_COMMAND --bin dfbench -- hj --iterations 5 -o "${RESULTS_FILE}" ${QUERY_ARG} } +# Runs the smj benchmark +run_smj() { + RESULTS_FILE="${RESULTS_DIR}/smj.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running smj benchmark..." + debug_run $CARGO_COMMAND --bin dfbench -- smj --iterations 5 -o "${RESULTS_FILE}" ${QUERY_ARG} +} + compare_benchmarks() { BASE_RESULTS_DIR="${SCRIPT_DIR}/results" @@ -1189,6 +1277,113 @@ compare_benchmarks() { } +# Creates sorted ClickBench data from hits.parquet (full dataset) +# The data is sorted by EventTime in ascending order +# Uses datafusion-cli to reduce dependencies +clickbench_sorted() { + SORTED_FILE="${DATA_DIR}/hits_sorted.parquet" + ORIGINAL_FILE="${DATA_DIR}/hits.parquet" + + # Default memory limit is 12GB, can be overridden with DATAFUSION_MEMORY_GB env var + MEMORY_LIMIT_GB=${DATAFUSION_MEMORY_GB:-12} + + echo "Creating sorted ClickBench dataset from hits.parquet..." + echo "Configuration:" + echo " Memory limit: ${MEMORY_LIMIT_GB}G" + echo " Row group size: 64K rows" + echo " Compression: uncompressed" + + if [ ! -f "${ORIGINAL_FILE}" ]; then + echo "hits.parquet not found. Running data_clickbench_1 first..." + data_clickbench_1 + fi + + if [ -f "${SORTED_FILE}" ]; then + echo "Sorted hits.parquet already exists at ${SORTED_FILE}" + return 0 + fi + + echo "Sorting hits.parquet by EventTime (this may take several minutes)..." + + pushd "${DATAFUSION_DIR}" > /dev/null + echo "Building datafusion-cli..." + cargo build --release --bin datafusion-cli + DATAFUSION_CLI="${DATAFUSION_DIR}/target/release/datafusion-cli" + popd > /dev/null + + + START_TIME=$(date +%s) + echo "Start time: $(date '+%Y-%m-%d %H:%M:%S')" + echo "Using datafusion-cli to create sorted parquet file..." + "${DATAFUSION_CLI}" << EOF +-- Memory and performance configuration +SET datafusion.runtime.memory_limit = '${MEMORY_LIMIT_GB}G'; +SET datafusion.execution.spill_compression = 'uncompressed'; +SET datafusion.execution.sort_spill_reservation_bytes = 10485760; -- 10MB +SET datafusion.execution.batch_size = 8192; +SET datafusion.execution.target_partitions = 1; + +-- Parquet output configuration +SET datafusion.execution.parquet.max_row_group_size = 65536; +SET datafusion.execution.parquet.compression = 'uncompressed'; + +-- Execute sort and write +COPY (SELECT * FROM '${ORIGINAL_FILE}' ORDER BY "EventTime") +TO '${SORTED_FILE}' +STORED AS PARQUET; +EOF + + local result=$? 
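The `clickbench_sorted` helper above sorts the full hits.parquet by EventTime through a datafusion-cli heredoc, with the memory pool sized by `DATAFUSION_MEMORY_GB` (default 12 GB). A minimal sketch of overriding that limit, using the wrapper spelling printed by the function's own failure tip and assuming the default `./data` layout:

```bash
cd benchmarks

# Give the sort a 16 GB memory pool instead of the 12 GB default
DATAFUSION_MEMORY_GB=16 ./bench.sh data clickbench_sorted

# The sorted copy is written next to the original file
ls -lh data/hits.parquet data/hits_sorted.parquet
```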
+ + END_TIME=$(date +%s) + DURATION=$((END_TIME - START_TIME)) + echo "End time: $(date '+%Y-%m-%d %H:%M:%S')" + + if [ $result -eq 0 ]; then + echo "✓ Successfully created sorted ClickBench dataset" + + INPUT_SIZE=$(stat -f%z "${ORIGINAL_FILE}" 2>/dev/null || stat -c%s "${ORIGINAL_FILE}" 2>/dev/null) + OUTPUT_SIZE=$(stat -f%z "${SORTED_FILE}" 2>/dev/null || stat -c%s "${SORTED_FILE}" 2>/dev/null) + INPUT_MB=$((INPUT_SIZE / 1024 / 1024)) + OUTPUT_MB=$((OUTPUT_SIZE / 1024 / 1024)) + + echo " Input: ${INPUT_MB} MB" + echo " Output: ${OUTPUT_MB} MB" + + echo "" + echo "Time Statistics:" + echo " Total duration: ${DURATION} seconds ($(printf '%02d:%02d:%02d' $((DURATION/3600)) $((DURATION%3600/60)) $((DURATION%60))))" + echo " Throughput: $((INPUT_MB / DURATION)) MB/s" + + return 0 + else + echo "✗ Error: Failed to create sorted dataset" + echo "💡 Tip: Try increasing memory with: DATAFUSION_MEMORY_GB=16 ./bench.sh data clickbench_sorted" + return 1 + fi +} + +# Runs the sorted data benchmark with prefer_existing_sort configuration +run_clickbench_sorted() { + RESULTS_FILE="${RESULTS_DIR}/clickbench_sorted.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running sorted data benchmark with prefer_existing_sort optimization..." + + # Ensure sorted data exists + clickbench_sorted + + # Run benchmark with prefer_existing_sort configuration + # This allows DataFusion to optimize away redundant sorts while maintaining parallelism + debug_run $CARGO_COMMAND --bin dfbench -- clickbench \ + --iterations 5 \ + --path "${DATA_DIR}/hits_sorted.parquet" \ + --queries-path "${SCRIPT_DIR}/queries/clickbench/queries/sorted_data" \ + --sorted-by "EventTime" \ + -c datafusion.optimizer.prefer_existing_sort=true \ + -o "${RESULTS_FILE}" \ + ${QUERY_ARG} +} + setup_venv() { python3 -m venv "$VIRTUAL_ENV" PATH=$VIRTUAL_ENV/bin:$PATH python3 -m pip install -r requirements.txt diff --git a/benchmarks/compare_tpcds.sh b/benchmarks/compare_tpcds.sh new file mode 100755 index 0000000000000..48331a7c7510e --- /dev/null +++ b/benchmarks/compare_tpcds.sh @@ -0,0 +1,58 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
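`run_clickbench_sorted` above regenerates the sorted file if needed, then runs dfbench with `--sorted-by EventTime` and `-c datafusion.optimizer.prefer_existing_sort=true`. A hedged sketch of driving it through bench.sh; the trailing query argument assumes the usual QUERY_ARG forwarding used by the other runners.

```bash
cd benchmarks

# Run the sorted-data ClickBench queries; RESULTS_FILE is clickbench_sorted.json
./bench.sh run clickbench_sorted

# Optionally restrict to a single query, e.g. q0
./bench.sh run clickbench_sorted 0
```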
+ +# Compare TPC-DS benchmarks between two branches + +set -e + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +usage() { + echo "Usage: $0 " + echo "" + echo "Example: $0 main dev2" + echo "" + echo "Note: TPC-DS benchmarks are not currently implemented in bench.sh" + exit 1 +} + +BRANCH1=${1:-""} +BRANCH2=${2:-""} + +if [ -z "$BRANCH1" ] || [ -z "$BRANCH2" ]; then + usage +fi + +# Store current branch +CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD) + +echo "Comparing TPC-DS benchmarks: ${BRANCH1} vs ${BRANCH2}" + +# Run benchmark on first branch +git checkout "$BRANCH1" +./benchmarks/bench.sh run tpcds + +# Run benchmark on second branch +git checkout "$BRANCH2" +./benchmarks/bench.sh run tpcds + +# Compare results +./benchmarks/bench.sh compare "$BRANCH1" "$BRANCH2" + +# Return to original branch +git checkout "$CURRENT_BRANCH" \ No newline at end of file diff --git a/benchmarks/compare_tpch.sh b/benchmarks/compare_tpch.sh new file mode 100755 index 0000000000000..85e8da29ce41d --- /dev/null +++ b/benchmarks/compare_tpch.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Compare TPC-H benchmarks between two branches + +set -e + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +usage() { + echo "Usage: $0 " + echo "" + echo "Example: $0 main dev2" + exit 1 +} + +BRANCH1=${1:-""} +BRANCH2=${2:-""} + +if [ -z "$BRANCH1" ] || [ -z "$BRANCH2" ]; then + usage +fi + +# Store current branch +CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD) + +echo "Comparing TPC-H benchmarks: ${BRANCH1} vs ${BRANCH2}" + +# Run benchmark on first branch +git checkout "$BRANCH1" +./benchmarks/bench.sh run tpch + +# Run benchmark on second branch +git checkout "$BRANCH2" +./benchmarks/bench.sh run tpch + +# Compare results +./benchmarks/bench.sh compare "$BRANCH1" "$BRANCH2" + +# Return to original branch +git checkout "$CURRENT_BRANCH" \ No newline at end of file diff --git a/benchmarks/queries/clickbench/queries/sorted_data/q0.sql b/benchmarks/queries/clickbench/queries/sorted_data/q0.sql new file mode 100644 index 0000000000000..1170a383bcb22 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/sorted_data/q0.sql @@ -0,0 +1,3 @@ +-- Must set for ClickBench hits_partitioned dataset. 
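Both comparison scripts above follow the same pattern: check out each branch, run the matching `bench.sh run ...`, then call `bench.sh compare`. A hedged example from a clean working tree, using the placeholder branch names from the scripts' own usage text:

```bash
# TPC-H comparison between main and a feature branch
./benchmarks/compare_tpch.sh main dev2

# TPC-DS comparison takes the same two-argument form
./benchmarks/compare_tpcds.sh main dev2
```

Because the scripts run `git checkout`, any uncommitted changes should be stashed first; both scripts return to the original branch when they finish.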
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true +SELECT * FROM hits ORDER BY "EventTime" DESC limit 10; diff --git a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs index 816cae0e38555..d842d306c1f65 100644 --- a/benchmarks/src/bin/dfbench.rs +++ b/benchmarks/src/bin/dfbench.rs @@ -34,7 +34,7 @@ static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; use datafusion_benchmarks::{ - cancellation, clickbench, h2o, hj, imdb, nlj, sort_tpch, tpch, + cancellation, clickbench, h2o, hj, imdb, nlj, smj, sort_tpch, tpcds, tpch, }; #[derive(Debug, StructOpt)] @@ -46,9 +46,10 @@ enum Options { HJ(hj::RunOpt), Imdb(imdb::RunOpt), Nlj(nlj::RunOpt), + Smj(smj::RunOpt), SortTpch(sort_tpch::RunOpt), Tpch(tpch::RunOpt), - TpchConvert(tpch::ConvertOpt), + Tpcds(tpcds::RunOpt), } // Main benchmark runner entrypoint @@ -63,8 +64,9 @@ pub async fn main() -> Result<()> { Options::HJ(opt) => opt.run().await, Options::Imdb(opt) => Box::pin(opt.run()).await, Options::Nlj(opt) => opt.run().await, + Options::Smj(opt) => opt.run().await, Options::SortTpch(opt) => opt.run().await, Options::Tpch(opt) => Box::pin(opt.run()).await, - Options::TpchConvert(opt) => opt.run().await, + Options::Tpcds(opt) => Box::pin(opt.run()).await, } } diff --git a/benchmarks/src/bin/external_aggr.rs b/benchmarks/src/bin/external_aggr.rs index 46b6cc9a80b24..2bc2bd4458a53 100644 --- a/benchmarks/src/bin/external_aggr.rs +++ b/benchmarks/src/bin/external_aggr.rs @@ -33,17 +33,17 @@ use datafusion::datasource::listing::{ }; use datafusion::datasource::{MemTable, TableProvider}; use datafusion::error::Result; +use datafusion::execution::SessionStateBuilder; use datafusion::execution::memory_pool::FairSpillPool; -use datafusion::execution::memory_pool::{human_readable_size, units}; use datafusion::execution::runtime_env::RuntimeEnvBuilder; -use datafusion::execution::SessionStateBuilder; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::{collect, displayable}; use datafusion::prelude::*; use datafusion_benchmarks::util::{BenchmarkRun, CommonOpt, QueryResult}; use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; -use datafusion_common::{exec_err, DEFAULT_PARQUET_EXTENSION}; +use datafusion_common::{DEFAULT_PARQUET_EXTENSION, exec_err}; +use datafusion_common::{human_readable_size, units}; #[derive(Debug, StructOpt)] #[structopt( diff --git a/benchmarks/src/bin/mem_profile.rs b/benchmarks/src/bin/mem_profile.rs index 16fc3871bec86..025efefe062e1 100644 --- a/benchmarks/src/bin/mem_profile.rs +++ b/benchmarks/src/bin/mem_profile.rs @@ -199,21 +199,18 @@ fn run_query(args: &[String], results: &mut Vec) -> Result<()> { // Look for lines that contain execution time / memory stats while let Some(line) = iter.next() { - if let Some((query, duration_ms)) = parse_query_time(line) { - if let Some(next_line) = iter.peek() { - if let Some((peak_rss, peak_commit, page_faults)) = - parse_vm_line(next_line) - { - results.push(QueryResult { - query, - duration_ms, - peak_rss, - peak_commit, - page_faults, - }); - break; - } - } + if let Some((query, duration_ms)) = parse_query_time(line) + && let Some(next_line) = iter.peek() + && let Some((peak_rss, peak_commit, page_faults)) = parse_vm_line(next_line) + { + results.push(QueryResult { + query, + duration_ms, + peak_rss, + peak_commit, + page_faults, + }); + break; } } diff --git 
a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs deleted file mode 100644 index ca2bb8e57c0ec..0000000000000 --- a/benchmarks/src/bin/tpch.rs +++ /dev/null @@ -1,65 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! tpch binary only entrypoint - -use datafusion::error::Result; -use datafusion_benchmarks::tpch; -use structopt::StructOpt; - -#[cfg(all(feature = "snmalloc", feature = "mimalloc"))] -compile_error!( - "feature \"snmalloc\" and feature \"mimalloc\" cannot be enabled at the same time" -); - -#[cfg(feature = "snmalloc")] -#[global_allocator] -static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; - -#[cfg(feature = "mimalloc")] -#[global_allocator] -static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; - -#[derive(Debug, StructOpt)] -#[structopt(about = "benchmark command")] -enum BenchmarkSubCommandOpt { - #[structopt(name = "datafusion")] - DataFusionBenchmark(tpch::RunOpt), -} - -#[derive(Debug, StructOpt)] -#[structopt(name = "TPC-H", about = "TPC-H Benchmarks.")] -enum TpchOpt { - Benchmark(BenchmarkSubCommandOpt), - Convert(tpch::ConvertOpt), -} - -/// 'tpch' entry point, with tortured command line arguments. Please -/// use `dbbench` instead. 
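The standalone `tpch` binary deleted above is superseded by the unified dfbench entrypoint, and its `convert` subcommand is replaced by tpchgen-cli. A hedged before/after sketch of the run invocation, mirroring the `run_tpch` change in bench.sh:

```bash
# Old (removed): cargo run --release --bin tpch -- benchmark datafusion ...
# New:
cd benchmarks
cargo run --release --bin dfbench -- tpch \
  --iterations 5 \
  --path ./data/tpch_sf1 \
  --format parquet \
  --prefer_hash_join true \
  -o results/tpch_sf1.json
```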
-/// -/// Note: this is kept to be backwards compatible with the benchmark names prior to -/// -#[tokio::main] -async fn main() -> Result<()> { - env_logger::init(); - match TpchOpt::from_args() { - TpchOpt::Benchmark(BenchmarkSubCommandOpt::DataFusionBenchmark(opt)) => { - Box::pin(opt.run()).await - } - TpchOpt::Convert(opt) => opt.run().await, - } -} diff --git a/benchmarks/src/cancellation.rs b/benchmarks/src/cancellation.rs index fcf03fbc54550..1b4c04b409ccd 100644 --- a/benchmarks/src/cancellation.rs +++ b/benchmarks/src/cancellation.rs @@ -25,22 +25,22 @@ use arrow::array::Array; use arrow::datatypes::DataType; use arrow::record_batch::RecordBatch; use datafusion::common::{Result, ScalarValue}; -use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::file_format::FileFormat; +use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ListingOptions, ListingTableUrl}; -use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::execution::TaskContext; -use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::prelude::*; use datafusion_common::instant::Instant; use futures::TryStreamExt; use object_store::ObjectStore; -use parquet::arrow::async_writer::ParquetObjectWriter; use parquet::arrow::AsyncArrowWriter; +use parquet::arrow::async_writer::ParquetObjectWriter; +use rand::Rng; use rand::distr::Alphanumeric; use rand::rngs::ThreadRng; -use rand::Rng; use structopt::StructOpt; use tokio::runtime::Runtime; use tokio_util::sync::CancellationToken; diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs index a550503390c54..9036e7d9501ec 100644 --- a/benchmarks/src/clickbench.rs +++ b/benchmarks/src/clickbench.rs @@ -19,7 +19,7 @@ use std::fs; use std::io::ErrorKind; use std::path::{Path, PathBuf}; -use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats}; use datafusion::logical_expr::{ExplainFormat, ExplainOption}; use datafusion::{ error::{DataFusionError, Result}, @@ -78,6 +78,27 @@ pub struct RunOpt { /// If present, write results json here #[structopt(parse(from_os_str), short = "o", long = "output")] output_path: Option, + + /// Column name that the data is sorted by (e.g., "EventTime") + /// If specified, DataFusion will be informed that the data has this sort order + /// using CREATE EXTERNAL TABLE with WITH ORDER clause. + /// + /// Recommended to use with: -c datafusion.optimizer.prefer_existing_sort=true + /// This allows DataFusion to optimize away redundant sorts while maintaining + /// multi-core parallelism for other operations. + #[structopt(long = "sorted-by")] + sorted_by: Option, + + /// Sort order: ASC or DESC (default: ASC) + #[structopt(long = "sort-order", default_value = "ASC")] + sort_order: String, + + /// Configuration options in the format key=value + /// Can be specified multiple times. 
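A hedged invocation sketch for the new clickbench flags above: `--sorted-by` registers `hits` with a declared sort order, and `-c key=value` (repeatable, split on the first `=`) applies arbitrary session config settings. Flag spellings follow the struct definition and `run_clickbench_sorted`; the second config key and the paths are illustrative.

```bash
cd benchmarks
cargo run --release --bin dfbench -- clickbench \
  --path data/hits_sorted.parquet \
  --queries-path queries/clickbench/queries/sorted_data \
  --sorted-by EventTime \
  -c datafusion.optimizer.prefer_existing_sort=true \
  -c datafusion.execution.target_partitions=8 \
  -o results/clickbench_sorted.json
```

The sort order defaults to ASC, which matches how hits_sorted.parquet is produced; `--sort-order DESC` would only be needed for data sorted the other way.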
+ /// + /// Example: -c datafusion.optimizer.prefer_existing_sort=true + #[structopt(short = "c", long = "config")] + config_options: Vec, } /// Get the SQL file path @@ -125,6 +146,39 @@ impl RunOpt { // configure parquet options let mut config = self.common.config()?; + + if self.sorted_by.is_some() { + println!("ℹ️ Data is registered with sort order"); + + let has_prefer_sort = self + .config_options + .iter() + .any(|opt| opt.contains("prefer_existing_sort=true")); + + if !has_prefer_sort { + println!( + "ℹ️ Consider using -c datafusion.optimizer.prefer_existing_sort=true" + ); + println!("ℹ️ to optimize queries while maintaining parallelism"); + } + } + + // Apply user-provided configuration options + for config_opt in &self.config_options { + let parts: Vec<&str> = config_opt.splitn(2, '=').collect(); + if parts.len() != 2 { + return Err(exec_datafusion_err!( + "Invalid config option format: '{}'. Expected 'key=value'", + config_opt + )); + } + let key = parts[0]; + let value = parts[1]; + + println!("Setting config: {key} = {value}"); + config = config.set_str(key, value); + } + { let parquet_options = &mut config.options_mut().execution.parquet; // The hits_partitioned dataset specifies string columns @@ -136,10 +190,18 @@ impl RunOpt { parquet_options.pushdown_filters = true; parquet_options.reorder_filters = true; } + + if self.sorted_by.is_some() { + // We should compare the dynamic topk optimization when data is sorted, so we make the + // assumption that filter pushdown is also enabled in this case. + parquet_options.pushdown_filters = true; + parquet_options.reorder_filters = true; + } } let rt_builder = self.common.runtime_env_builder()?; let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + self.register_hits(&ctx).await?; let mut benchmark_run = BenchmarkRun::new(); @@ -214,17 +276,54 @@ impl RunOpt { } /// Registers the `hits.parquet` as a table named `hits` + /// If sorted_by is specified, uses CREATE EXTERNAL TABLE with WITH ORDER async fn register_hits(&self, ctx: &SessionContext) -> Result<()> { - let options = Default::default(); let path = self.path.as_os_str().to_str().unwrap(); - ctx.register_parquet("hits", path, options) - .await - .map_err(|e| { - DataFusionError::Context( - format!("Registering 'hits' as {path}"), - Box::new(e), - ) - }) + + // If sorted_by is specified, use CREATE EXTERNAL TABLE with WITH ORDER + if let Some(ref sort_column) = self.sorted_by { + println!( + "Registering table with sort order: {} {}", + sort_column, self.sort_order + ); + + // Escape column name with double quotes + let escaped_column = if sort_column.contains('"') { + sort_column.clone() + } else { + format!("\"{sort_column}\"") + }; + + // Build CREATE EXTERNAL TABLE DDL with WITH ORDER clause + // Schema will be automatically inferred from the Parquet file + let create_table_sql = format!( + "CREATE EXTERNAL TABLE hits \ + STORED AS PARQUET \ + LOCATION '{}' \ + WITH ORDER ({} {})", + path, + escaped_column, + self.sort_order.to_uppercase() + ); + + println!("Executing: {create_table_sql}"); + + // Execute the CREATE EXTERNAL TABLE statement + ctx.sql(&create_table_sql).await?.collect().await?; + + Ok(()) + } else { + // Original registration without sort order + let options = Default::default(); + ctx.register_parquet("hits", path, options) + .await + .map_err(|e| { + DataFusionError::Context( + format!("Registering 'hits' as {path}"), + Box::new(e), + ) + }) + } } fn iterations(&self) -> usize { diff --git a/benchmarks/src/h2o.rs 
b/benchmarks/src/h2o.rs index be74252031194..07a40447d4149 100644 --- a/benchmarks/src/h2o.rs +++ b/benchmarks/src/h2o.rs @@ -20,11 +20,11 @@ //! - [H2O AI Benchmark](https://duckdb.org/2023/04/14/h2oai.html) //! - [Extended window function benchmark](https://duckdb.org/2024/06/26/benchmarks-over-time.html#window-functions-benchmark) -use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt, print_memory_stats}; use datafusion::logical_expr::{ExplainFormat, ExplainOption}; use datafusion::{error::Result, prelude::SessionContext}; use datafusion_common::{ - exec_datafusion_err, instant::Instant, internal_err, DataFusionError, TableReference, + DataFusionError, TableReference, exec_datafusion_err, instant::Instant, internal_err, }; use std::path::{Path, PathBuf}; use structopt::StructOpt; diff --git a/benchmarks/src/hj.rs b/benchmarks/src/hj.rs index 505b322745485..562047f615bc8 100644 --- a/benchmarks/src/hj.rs +++ b/benchmarks/src/hj.rs @@ -19,7 +19,7 @@ use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; use datafusion::physical_plan::execute_stream; use datafusion::{error::Result, prelude::SessionContext}; use datafusion_common::instant::Instant; -use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError}; +use datafusion_common::{DataFusionError, exec_datafusion_err, exec_err}; use structopt::StructOpt; use futures::StreamExt; diff --git a/benchmarks/src/imdb/convert.rs b/benchmarks/src/imdb/convert.rs index e7949aa715c23..2c4e1270255bb 100644 --- a/benchmarks/src/imdb/convert.rs +++ b/benchmarks/src/imdb/convert.rs @@ -26,8 +26,8 @@ use structopt::StructOpt; use datafusion::common::not_impl_err; -use super::get_imdb_table_schema; use super::IMDB_TABLES; +use super::get_imdb_table_schema; #[derive(Debug, StructOpt)] pub struct ConvertOpt { diff --git a/benchmarks/src/imdb/run.rs b/benchmarks/src/imdb/run.rs index 11bd424ba6866..05f1870c5d45a 100644 --- a/benchmarks/src/imdb/run.rs +++ b/benchmarks/src/imdb/run.rs @@ -19,16 +19,16 @@ use std::path::PathBuf; use std::sync::Arc; use super::{ - get_imdb_table_schema, get_query_sql, IMDB_QUERY_END_ID, IMDB_QUERY_START_ID, - IMDB_TABLES, + IMDB_QUERY_END_ID, IMDB_QUERY_START_ID, IMDB_TABLES, get_imdb_table_schema, + get_query_sql, }; -use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::{self, pretty_format_batches}; +use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::file_format::csv::CsvFormat; use datafusion::datasource::file_format::parquet::ParquetFormat; -use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs index 07cffa5ae468e..a3bc221840ada 100644 --- a/benchmarks/src/lib.rs +++ b/benchmarks/src/lib.rs @@ -22,6 +22,8 @@ pub mod h2o; pub mod hj; pub mod imdb; pub mod nlj; +pub mod smj; pub mod sort_tpch; +pub mod tpcds; pub mod tpch; pub mod util; diff --git a/benchmarks/src/nlj.rs b/benchmarks/src/nlj.rs index 7d1e14f69439c..cbf5a03fbf93d 100644 --- a/benchmarks/src/nlj.rs +++ b/benchmarks/src/nlj.rs @@ -19,7 +19,7 @@ use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; use datafusion::physical_plan::execute_stream; use datafusion::{error::Result, prelude::SessionContext}; use 
datafusion_common::instant::Instant; -use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError}; +use datafusion_common::{DataFusionError, exec_datafusion_err, exec_err}; use structopt::StructOpt; use futures::StreamExt; @@ -268,8 +268,8 @@ impl RunOpt { let elapsed = start.elapsed(); println!( - "Query {query_name} iteration {i} returned {row_count} rows in {elapsed:?}" - ); + "Query {query_name} iteration {i} returned {row_count} rows in {elapsed:?}" + ); query_results.push(QueryResult { elapsed, row_count }); } diff --git a/benchmarks/src/smj.rs b/benchmarks/src/smj.rs new file mode 100644 index 0000000000000..53902e09302c2 --- /dev/null +++ b/benchmarks/src/smj.rs @@ -0,0 +1,524 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; +use datafusion::physical_plan::execute_stream; +use datafusion::{error::Result, prelude::SessionContext}; +use datafusion_common::instant::Instant; +use datafusion_common::{DataFusionError, exec_datafusion_err, exec_err}; +use structopt::StructOpt; + +use futures::StreamExt; + +/// Run the Sort Merge Join (SMJ) benchmark +/// +/// This micro-benchmark focuses on the performance characteristics of SMJs. +/// +/// It uses equality join predicates (to ensure SMJ is selected) and varies: +/// - Join type: Inner/Left/Right/Full/LeftSemi/LeftAnti/RightSemi/RightAnti +/// - Key cardinality: 1:1, 1:N, N:M relationships +/// - Filter selectivity: Low (1%), Medium (10%), High (50%) +/// - Input sizes: Small to large, balanced and skewed +/// +/// All inputs are pre-sorted in CTEs before the join to isolate join +/// performance from sort overhead. +#[derive(Debug, StructOpt, Clone)] +#[structopt(verbatim_doc_comment)] +pub struct RunOpt { + /// Query number (between 1 and 20). 
If not specified, runs all queries + #[structopt(short, long)] + query: Option, + + /// Common options + #[structopt(flatten)] + common: CommonOpt, + + /// If present, write results json here + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, +} + +/// Inline SQL queries for SMJ benchmarks +/// +/// Each query's comment includes: +/// - Join type +/// - Left row count × Right row count +/// - Key cardinality (rows per key) +/// - Filter selectivity (if applicable) +const SMJ_QUERIES: &[&str] = &[ + // Q1: INNER 100K x 100K | 1:1 + r#" + WITH t1_sorted AS ( + SELECT value as key FROM range(100000) ORDER BY value + ), + t2_sorted AS ( + SELECT value as key FROM range(100000) ORDER BY value + ) + SELECT t1_sorted.key as k1, t2_sorted.key as k2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q2: INNER 100K x 1M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q3: INNER 1M x 1M | 1:100 + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q4: INNER 100K x 1M | 1:10 | 1% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE t2_sorted.data % 100 = 0 + "#, + // Q5: INNER 1M x 1M | 1:100 | 10% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE t1_sorted.data <> t2_sorted.data AND t2_sorted.data % 10 = 0 + "#, + // Q6: LEFT 100K x 1M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10500 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted LEFT JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q7: LEFT 100K x 1M | 1:10 | 50% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted LEFT JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE t2_sorted.data IS NULL OR t2_sorted.data % 2 = 0 + "#, + // Q8: FULL 100K x 100K | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM 
range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 12500 as key, value as data + FROM range(100000) + ORDER BY key, data + ) + SELECT t1_sorted.key as k1, t1_sorted.data as d1, + t2_sorted.key as k2, t2_sorted.data as d2 + FROM t1_sorted FULL JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q9: FULL 100K x 1M | 1:10 | 10% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key as k1, t1_sorted.data as d1, + t2_sorted.key as k2, t2_sorted.data as d2 + FROM t1_sorted FULL JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE (t1_sorted.data IS NULL OR t2_sorted.data IS NULL + OR t1_sorted.data <> t2_sorted.data) + AND (t1_sorted.data IS NULL OR t1_sorted.data % 10 = 0) + "#, + // Q10: LEFT SEMI 100K x 1M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(1000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q11: LEFT SEMI 100K x 1M | 1:10 | 1% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 100 = 0 + ) + "#, + // Q12: LEFT SEMI 100K x 1M | 1:10 | 50% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 2 = 0 + ) + "#, + // Q13: LEFT SEMI 100K x 1M | 1:10 | 90% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data % 10 <> 0 + ) + "#, + // Q14: LEFT ANTI 100K x 1M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10500 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(1000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q15: LEFT ANTI 100K x 1M | 1:10 | partial match + r#" + WITH t1_sorted AS ( + SELECT value % 12000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(1000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q16: LEFT 
ANTI 100K x 100K | 1:1 | stress + r#" + WITH t1_sorted AS ( + SELECT value % 11000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(100000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q17: INNER 100K x 5M | 1:50 | 5% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(5000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE t2_sorted.data <> t1_sorted.data AND t2_sorted.data % 20 = 0 + "#, + // Q18: LEFT SEMI 100K x 5M | 1:50 | 2% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(5000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 50 = 0 + ) + "#, + // Q19: LEFT ANTI 100K x 5M | 1:50 | partial match + r#" + WITH t1_sorted AS ( + SELECT value % 15000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(5000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q20: INNER 1M x 10M | 1:100 + GROUP BY + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(10000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, count(*) as cnt + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + GROUP BY t1_sorted.key + "#, +]; + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running SMJ benchmarks with the following options: {self:#?}\n"); + + // Define query range + let query_range = match self.query { + Some(query_id) => { + if query_id >= 1 && query_id <= SMJ_QUERIES.len() { + query_id..=query_id + } else { + return exec_err!( + "Query {query_id} not found. 
Available queries: 1 to {}", + SMJ_QUERIES.len() + ); + } + } + None => 1..=SMJ_QUERIES.len(), + }; + + let mut config = self.common.config()?; + // Disable hash joins to force SMJ + config = config.set_bool("datafusion.optimizer.prefer_hash_join", false); + let rt_builder = self.common.runtime_env_builder()?; + let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + + let mut benchmark_run = BenchmarkRun::new(); + for query_id in query_range { + let query_index = query_id - 1; // Convert 1-based to 0-based index + + let sql = SMJ_QUERIES[query_index]; + benchmark_run.start_new_case(&format!("Query {query_id}")); + let query_run = self.benchmark_query(sql, &query_id.to_string(), &ctx).await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + return Err(DataFusionError::Context( + format!("SMJ benchmark Q{query_id} failed with error:"), + Box::new(e), + )); + } + } + } + + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + Ok(()) + } + + async fn benchmark_query( + &self, + sql: &str, + query_name: &str, + ctx: &SessionContext, + ) -> Result> { + let mut query_results = vec![]; + + // Validate that the query plan includes a Sort Merge Join + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + let plan_string = format!("{physical_plan:#?}"); + + if !plan_string.contains("SortMergeJoinExec") { + return Err(exec_datafusion_err!( + "Query {query_name} does not use Sort Merge Join. Physical plan: {plan_string}" + )); + } + + for i in 0..self.common.iterations { + let start = Instant::now(); + + let row_count = Self::execute_sql_without_result_buffering(sql, ctx).await?; + + let elapsed = start.elapsed(); + + println!( + "Query {query_name} iteration {i} returned {row_count} rows in {elapsed:?}" + ); + + query_results.push(QueryResult { elapsed, row_count }); + } + + Ok(query_results) + } + + /// Executes the SQL query and drops each result batch after evaluation, to + /// minimizes memory usage by not buffering results. 
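Because `benchmark_query` above rejects any query whose physical plan does not contain `SortMergeJoinExec`, the same check can be reproduced interactively. A hedged sketch with datafusion-cli: the CTE mirrors Q1 and `prefer_hash_join` is disabled the same way the benchmark does.

```bash
datafusion-cli <<'EOF' | grep -i SortMergeJoin
SET datafusion.optimizer.prefer_hash_join = false;
EXPLAIN
WITH t1_sorted AS (SELECT value as key FROM range(100000) ORDER BY value),
     t2_sorted AS (SELECT value as key FROM range(100000) ORDER BY value)
SELECT t1_sorted.key as k1, t2_sorted.key as k2
FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key;
EOF
```

If grep prints nothing, the planner chose a different join operator and the benchmark would refuse that query.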
+ /// + /// Returns the total result row count + async fn execute_sql_without_result_buffering( + sql: &str, + ctx: &SessionContext, + ) -> Result { + let mut row_count = 0; + + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + let mut stream = execute_stream(physical_plan, ctx.task_ctx())?; + + while let Some(batch) = stream.next().await { + row_count += batch?.num_rows(); + + // Evaluate the result and do nothing, the result will be dropped + // to reduce memory pressure + } + + Ok(row_count) + } +} diff --git a/benchmarks/src/sort_tpch.rs b/benchmarks/src/sort_tpch.rs index 09b5a676bbff1..2f3be76f050b9 100644 --- a/benchmarks/src/sort_tpch.rs +++ b/benchmarks/src/sort_tpch.rs @@ -36,11 +36,11 @@ use datafusion::execution::SessionStateBuilder; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::{displayable, execute_stream}; use datafusion::prelude::*; +use datafusion_common::DEFAULT_PARQUET_EXTENSION; use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; -use datafusion_common::DEFAULT_PARQUET_EXTENSION; -use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats}; #[derive(Debug, StructOpt)] pub struct RunOpt { diff --git a/datafusion/core/tests/schema_adapter/mod.rs b/benchmarks/src/tpcds/mod.rs similarity index 95% rename from datafusion/core/tests/schema_adapter/mod.rs rename to benchmarks/src/tpcds/mod.rs index 2f81a43f4736e..4829eb9fd348a 100644 --- a/datafusion/core/tests/schema_adapter/mod.rs +++ b/benchmarks/src/tpcds/mod.rs @@ -15,4 +15,5 @@ // specific language governing permissions and limitations // under the License. -mod schema_adapter_integration_tests; +mod run; +pub use run::RunOpt; diff --git a/benchmarks/src/tpcds/run.rs b/benchmarks/src/tpcds/run.rs new file mode 100644 index 0000000000000..3f579024ba519 --- /dev/null +++ b/benchmarks/src/tpcds/run.rs @@ -0,0 +1,356 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
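The new `tpcds` module wired into lib.rs above is exposed as a dfbench subcommand. A hedged direct invocation restricted to one query, with flag names taken from `run_tpcds` in bench.sh and the RunOpt definition that follows; it assumes the query files exist under datafusion/core/tests/tpc-ds.

```bash
cd benchmarks
cargo run --release --bin dfbench -- tpcds \
  --path ./data/tpcds_sf1 \
  --query_path ../datafusion/core/tests/tpc-ds \
  --query 3 \
  --iterations 3 \
  -o results/tpcds_q3.json
```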
+ +use std::fs; +use std::path::PathBuf; +use std::sync::Arc; + +use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats}; + +use arrow::record_batch::RecordBatch; +use arrow::util::pretty::{self, pretty_format_batches}; +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, +}; +use datafusion::datasource::{MemTable, TableProvider}; +use datafusion::error::Result; +use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::physical_plan::{collect, displayable}; +use datafusion::prelude::*; +use datafusion_common::instant::Instant; +use datafusion_common::utils::get_available_parallelism; +use datafusion_common::{DEFAULT_PARQUET_EXTENSION, plan_err}; + +use log::info; +use structopt::StructOpt; + +// hack to avoid `default_value is meaningless for bool` errors +type BoolDefaultTrue = bool; +pub const TPCDS_QUERY_START_ID: usize = 1; +pub const TPCDS_QUERY_END_ID: usize = 99; + +pub const TPCDS_TABLES: &[&str] = &[ + "call_center", + "customer_address", + "household_demographics", + "promotion", + "store_sales", + "web_page", + "catalog_page", + "customer_demographics", + "income_band", + "reason", + "store", + "web_returns", + "catalog_returns", + "customer", + "inventory", + "ship_mode", + "time_dim", + "web_sales", + "catalog_sales", + "date_dim", + "item", + "store_returns", + "warehouse", + "web_site", +]; + +/// Get the SQL statements from the specified query file +pub fn get_query_sql(base_query_path: &str, query: usize) -> Result> { + if query > 0 && query < 100 { + let filename = format!("{base_query_path}/{query}.sql"); + let mut errors = vec![]; + match fs::read_to_string(&filename) { + Ok(contents) => { + return Ok(contents + .split(';') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + .collect()); + } + Err(e) => errors.push(format!("{filename}: {e}")), + }; + + plan_err!("invalid query. Could not find query: {:?}", errors) + } else { + plan_err!("invalid query. Expected value between 1 and 99") + } +} + +/// Run the tpcds benchmark. +#[derive(Debug, StructOpt, Clone)] +#[structopt(verbatim_doc_comment)] +pub struct RunOpt { + /// Query number. If not specified, runs all queries + #[structopt(short, long)] + pub query: Option, + + /// Common options + #[structopt(flatten)] + common: CommonOpt, + + /// Path to data files + #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + path: PathBuf, + + /// Path to query files + #[structopt(parse(from_os_str), required = true, short = "Q", long = "query_path")] + query_path: PathBuf, + + /// Load the data into a MemTable before executing the query + #[structopt(short = "m", long = "mem-table")] + mem_table: bool, + + /// Path to machine readable output file + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, + + /// Whether to disable collection of statistics (and cost based optimizations) or not. + #[structopt(short = "S", long = "disable-statistics")] + disable_statistics: bool, + + /// If true then hash join used, if false then sort merge join + /// True by default. + #[structopt(short = "j", long = "prefer_hash_join", default_value = "true")] + prefer_hash_join: BoolDefaultTrue, + + /// If true then Piecewise Merge Join can be used, if false then it will opt for Nested Loop Join + /// False by default. 
+ #[structopt( + short = "w", + long = "enable_piecewise_merge_join", + default_value = "false" + )] + enable_piecewise_merge_join: BoolDefaultTrue, + + /// Mark the first column of each table as sorted in ascending order. + /// The tables should have been created with the `--sort` option for this to have any effect. + #[structopt(short = "t", long = "sorted")] + sorted: bool, +} + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running benchmarks with the following options: {self:?}"); + let query_range = match self.query { + Some(query_id) => query_id..=query_id, + None => TPCDS_QUERY_START_ID..=TPCDS_QUERY_END_ID, + }; + + let mut benchmark_run = BenchmarkRun::new(); + let mut config = self + .common + .config()? + .with_collect_statistics(!self.disable_statistics); + config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; + config.options_mut().optimizer.enable_piecewise_merge_join = + self.enable_piecewise_merge_join; + let rt_builder = self.common.runtime_env_builder()?; + let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + // register tables + self.register_tables(&ctx).await?; + + for query_id in query_range { + benchmark_run.start_new_case(&format!("Query {query_id}")); + let query_run = self.benchmark_query(query_id, &ctx).await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + benchmark_run.mark_failed(); + eprintln!("Query {query_id} failed: {e}"); + } + } + } + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + benchmark_run.maybe_print_failures(); + Ok(()) + } + + async fn benchmark_query( + &self, + query_id: usize, + ctx: &SessionContext, + ) -> Result> { + let mut millis = vec![]; + // run benchmark + let mut query_results = vec![]; + + let sql = &get_query_sql(self.query_path.to_str().unwrap(), query_id)?; + + if self.common.debug { + println!("=== SQL for query {query_id} ===\n{}\n", sql.join(";\n")); + } + + for i in 0..self.iterations() { + let start = Instant::now(); + + // query 15 is special, with 3 statements. the second statement is the one from which we + // want to capture the results + let mut result = vec![]; + + for query in sql { + result = self.execute_query(ctx, query).await?; + } + + let elapsed = start.elapsed(); + let ms = elapsed.as_secs_f64() * 1000.0; + millis.push(ms); + info!("output:\n\n{}\n\n", pretty_format_batches(&result)?); + let row_count = result.iter().map(|b| b.num_rows()).sum(); + println!( + "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" + ); + query_results.push(QueryResult { elapsed, row_count }); + } + + let avg = millis.iter().sum::() / millis.len() as f64; + println!("Query {query_id} avg time: {avg:.2} ms"); + + // Print memory stats using mimalloc (only when compiled with --features mimalloc_extended) + print_memory_stats(); + + Ok(query_results) + } + + async fn register_tables(&self, ctx: &SessionContext) -> Result<()> { + for table in TPCDS_TABLES { + let table_provider = { self.get_table(ctx, table).await? 
}; + + if self.mem_table { + println!("Loading table '{table}' into memory"); + let start = Instant::now(); + let memtable = + MemTable::load(table_provider, Some(self.partitions()), &ctx.state()) + .await?; + println!( + "Loaded table '{}' into memory in {} ms", + table, + start.elapsed().as_millis() + ); + ctx.register_table(*table, Arc::new(memtable))?; + } else { + ctx.register_table(*table, table_provider)?; + } + } + Ok(()) + } + + async fn execute_query( + &self, + ctx: &SessionContext, + sql: &str, + ) -> Result> { + let debug = self.common.debug; + let plan = ctx.sql(sql).await?; + let (state, plan) = plan.into_parts(); + + if debug { + println!("=== Logical plan ===\n{plan}\n"); + } + + let plan = state.optimize(&plan)?; + if debug { + println!("=== Optimized logical plan ===\n{plan}\n"); + } + let physical_plan = state.create_physical_plan(&plan).await?; + if debug { + println!( + "=== Physical plan ===\n{}\n", + displayable(physical_plan.as_ref()).indent(true) + ); + } + let result = collect(physical_plan.clone(), state.task_ctx()).await?; + if debug { + println!( + "=== Physical plan with metrics ===\n{}\n", + DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()) + .indent(true) + ); + if !result.is_empty() { + // do not call print_batches if there are no batches as the result is confusing + // and makes it look like there is a batch with no columns + pretty::print_batches(&result)?; + } + } + Ok(result) + } + + async fn get_table( + &self, + ctx: &SessionContext, + table: &str, + ) -> Result> { + let path = self.path.to_str().unwrap(); + let target_partitions = self.partitions(); + + // Obtain a snapshot of the SessionState + let state = ctx.state(); + let path = format!("{path}/{table}.parquet"); + + // Check if the file exists + if !std::path::Path::new(&path).exists() { + eprintln!("Warning registering {table}: Table file does not exist: {path}"); + } + + let format = ParquetFormat::default() + .with_options(ctx.state().table_options().parquet.clone()); + + let table_path = ListingTableUrl::parse(path)?; + let options = ListingOptions::new(Arc::new(format)) + .with_file_extension(DEFAULT_PARQUET_EXTENSION) + .with_target_partitions(target_partitions) + .with_collect_stat(state.config().collect_statistics()); + let schema = options.infer_schema(&state, &table_path).await?; + + if self.common.debug { + println!( + "Inferred schema from {table_path} for table '{table}':\n{schema:#?}\n" + ); + } + + let options = if self.sorted { + let key_column_name = schema.fields()[0].name(); + options + .with_file_sort_order(vec![vec![col(key_column_name).sort(true, false)]]) + } else { + options + }; + + let config = ListingTableConfig::new(table_path) + .with_listing_options(options) + .with_schema(schema); + + Ok(Arc::new(ListingTable::try_new(config)?)) + } + + fn iterations(&self) -> usize { + self.common.iterations + } + + fn partitions(&self) -> usize { + self.common + .partitions + .unwrap_or_else(get_available_parallelism) + } +} diff --git a/benchmarks/src/tpch/convert.rs b/benchmarks/src/tpch/convert.rs deleted file mode 100644 index 5219e09cd3052..0000000000000 --- a/benchmarks/src/tpch/convert.rs +++ /dev/null @@ -1,162 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::logical_expr::select_expr::SelectExpr; -use datafusion_common::instant::Instant; -use std::fs; -use std::path::{Path, PathBuf}; - -use datafusion::common::not_impl_err; - -use super::get_tbl_tpch_table_schema; -use super::TPCH_TABLES; -use datafusion::error::Result; -use datafusion::prelude::*; -use parquet::basic::Compression; -use parquet::file::properties::WriterProperties; -use structopt::StructOpt; - -/// Convert tpch .slt files to .parquet or .csv files -#[derive(Debug, StructOpt)] -pub struct ConvertOpt { - /// Path to csv files - #[structopt(parse(from_os_str), required = true, short = "i", long = "input")] - input_path: PathBuf, - - /// Output path - #[structopt(parse(from_os_str), required = true, short = "o", long = "output")] - output_path: PathBuf, - - /// Output file format: `csv` or `parquet` - #[structopt(short = "f", long = "format")] - file_format: String, - - /// Compression to use when writing Parquet files - #[structopt(short = "c", long = "compression", default_value = "zstd")] - compression: String, - - /// Number of partitions to produce - #[structopt(short = "n", long = "partitions", default_value = "1")] - partitions: usize, - - /// Batch size when reading CSV or Parquet files - #[structopt(short = "s", long = "batch-size", default_value = "8192")] - batch_size: usize, - - /// Sort each table by its first column in ascending order. - #[structopt(short = "t", long = "sort")] - sort: bool, -} - -impl ConvertOpt { - pub async fn run(self) -> Result<()> { - let compression = self.compression()?; - - let input_path = self.input_path.to_str().unwrap(); - let output_path = self.output_path.to_str().unwrap(); - - let output_root_path = Path::new(output_path); - for table in TPCH_TABLES { - let start = Instant::now(); - let schema = get_tbl_tpch_table_schema(table); - let key_column_name = schema.fields()[0].name(); - - let input_path = format!("{input_path}/{table}.tbl"); - let options = CsvReadOptions::new() - .schema(&schema) - .has_header(false) - .delimiter(b'|') - .file_extension(".tbl"); - let options = if self.sort { - // indicated that the file is already sorted by its first column to speed up the conversion - options - .file_sort_order(vec![vec![col(key_column_name).sort(true, false)]]) - } else { - options - }; - - let config = SessionConfig::new().with_batch_size(self.batch_size); - let ctx = SessionContext::new_with_config(config); - - // build plan to read the TBL file - let mut csv = ctx.read_csv(&input_path, options).await?; - - // Select all apart from the padding column - let selection = csv - .schema() - .iter() - .take(schema.fields.len() - 1) - .map(Expr::from) - .map(SelectExpr::from) - .collect::>(); - - csv = csv.select(selection)?; - // optionally, repartition the file - let partitions = self.partitions; - if partitions > 1 { - csv = csv.repartition(Partitioning::RoundRobinBatch(partitions))? 
- } - let csv = if self.sort { - csv.sort_by(vec![col(key_column_name)])? - } else { - csv - }; - - // create the physical plan - let csv = csv.create_physical_plan().await?; - - let output_path = output_root_path.join(table); - let output_path = output_path.to_str().unwrap().to_owned(); - fs::create_dir_all(&output_path)?; - println!( - "Converting '{}' to {} files in directory '{}'", - &input_path, self.file_format, &output_path - ); - match self.file_format.as_str() { - "csv" => ctx.write_csv(csv, output_path).await?, - "parquet" => { - let props = WriterProperties::builder() - .set_compression(compression) - .build(); - ctx.write_parquet(csv, output_path, Some(props)).await? - } - other => { - return not_impl_err!("Invalid output format: {other}"); - } - } - println!("Conversion completed in {} ms", start.elapsed().as_millis()); - } - - Ok(()) - } - - /// return the compression method to use when writing parquet - fn compression(&self) -> Result { - Ok(match self.compression.as_str() { - "none" => Compression::UNCOMPRESSED, - "snappy" => Compression::SNAPPY, - "brotli" => Compression::BROTLI(Default::default()), - "gzip" => Compression::GZIP(Default::default()), - "lz4" => Compression::LZ4, - "lz0" => Compression::LZO, - "zstd" => Compression::ZSTD(Default::default()), - other => { - return not_impl_err!("Invalid compression format: {other}"); - } - }) - } -} diff --git a/benchmarks/src/tpch/mod.rs b/benchmarks/src/tpch/mod.rs index 233ea94a05c1a..681aa0a403ee1 100644 --- a/benchmarks/src/tpch/mod.rs +++ b/benchmarks/src/tpch/mod.rs @@ -27,9 +27,6 @@ use std::fs; mod run; pub use run::RunOpt; -mod convert; -pub use convert::ConvertOpt; - pub const TPCH_TABLES: &[&str] = &[ "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region", ]; diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index cc59b78030360..65bb9594f00a6 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -19,16 +19,16 @@ use std::path::PathBuf; use std::sync::Arc; use super::{ - get_query_sql, get_tbl_tpch_table_schema, get_tpch_table_schema, TPCH_QUERY_END_ID, - TPCH_QUERY_START_ID, TPCH_TABLES, + TPCH_QUERY_END_ID, TPCH_QUERY_START_ID, TPCH_TABLES, get_query_sql, + get_tbl_tpch_table_schema, get_tpch_table_schema, }; -use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::{self, pretty_format_batches}; +use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::file_format::csv::CsvFormat; use datafusion::datasource::file_format::parquet::ParquetFormat; -use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; @@ -93,9 +93,9 @@ pub struct RunOpt { prefer_hash_join: BoolDefaultTrue, /// If true then Piecewise Merge Join can be used, if false then it will opt for Nested Loop Join - /// True by default. + /// False by default. 
#[structopt( - short = "j", + short = "w", long = "enable_piecewise_merge_join", default_value = "false" )] diff --git a/benchmarks/src/util/memory.rs b/benchmarks/src/util/memory.rs index 944239df31cfd..11b96ef227756 100644 --- a/benchmarks/src/util/memory.rs +++ b/benchmarks/src/util/memory.rs @@ -19,7 +19,7 @@ pub fn print_memory_stats() { #[cfg(all(feature = "mimalloc", feature = "mimalloc_extended"))] { - use datafusion::execution::memory_pool::human_readable_size; + use datafusion_common::human_readable_size; let mut peak_rss = 0; let mut peak_commit = 0; let mut page_faults = 0; diff --git a/benchmarks/src/util/options.rs b/benchmarks/src/util/options.rs index 6627a287dfcd4..b1d5bc99fb406 100644 --- a/benchmarks/src/util/options.rs +++ b/benchmarks/src/util/options.rs @@ -105,7 +105,7 @@ impl CommonOpt { return Err(DataFusionError::Configuration(format!( "Invalid memory pool type: {}", self.mem_pool_type - ))) + ))); } }; rt_builder = rt_builder diff --git a/benchmarks/src/util/run.rs b/benchmarks/src/util/run.rs index 764ea648ff725..df17674e62961 100644 --- a/benchmarks/src/util/run.rs +++ b/benchmarks/src/util/run.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion::{error::Result, DATAFUSION_VERSION}; +use datafusion::{DATAFUSION_VERSION, error::Result}; use datafusion_common::utils::get_available_parallelism; use serde::{Serialize, Serializer}; use serde_json::Value; diff --git a/ci/scripts/check_examples_docs.sh b/ci/scripts/check_examples_docs.sh new file mode 100755 index 0000000000000..37b0cc088df4c --- /dev/null +++ b/ci/scripts/check_examples_docs.sh @@ -0,0 +1,64 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -euo pipefail + +EXAMPLES_DIR="datafusion-examples/examples" +README="datafusion-examples/README.md" + +# ffi examples are skipped because they were not part of the recent example +# consolidation work and do not follow the new grouping and execution pattern. +# They are not documented in the README using the new structure, so including +# them here would cause false CI failures. +SKIP_LIST=("ffi") + +missing=0 + +skip() { + local value="$1" + for item in "${SKIP_LIST[@]}"; do + if [[ "$item" == "$value" ]]; then + return 0 + fi + done + return 1 +} + +# collect folder names +folders=$(find "$EXAMPLES_DIR" -mindepth 1 -maxdepth 1 -type d -exec basename {} \;) + +# collect group names from README headers +groups=$(grep "^### Group:" "$README" | sed -E 's/^### Group: `([^`]+)`.*/\1/') + +for folder in $folders; do + if skip "$folder"; then + echo "Skipped group: $folder" + continue + fi + + if ! 
echo "$groups" | grep -qx "$folder"; then + echo "Missing README entry for example group: $folder" + missing=1 + fi +done + +if [[ $missing -eq 1 ]]; then + echo "README is out of sync with examples" + exit 1 +fi diff --git a/ci/scripts/doc_prettier_check.sh b/ci/scripts/doc_prettier_check.sh new file mode 100755 index 0000000000000..d94a0d1c96171 --- /dev/null +++ b/ci/scripts/doc_prettier_check.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/$(basename "${BASH_SOURCE[0]}")" + +MODE="--check" +ACTION="Checking" +if [ $# -gt 0 ]; then + if [ "$1" = "--write" ]; then + MODE="--write" + ACTION="Formatting" + else + echo "Usage: $0 [--write]" >&2 + exit 1 + fi +fi + +echo "$SCRIPT_PATH: $ACTION documents with prettier" + +# Ensure `npx` is available +if ! command -v npx >/dev/null 2>&1; then + echo "npx is required to run the prettier check. Install Node.js (e.g., brew install node) and re-run." >&2 + exit 1 +fi + +# Ignore subproject CHANGELOG.md because it is machine generated +npx prettier@2.7.1 $MODE \ + '{datafusion,datafusion-cli,datafusion-examples,dev,docs}/**/*.md' \ + '!datafusion/CHANGELOG.md' \ + README.md \ + CONTRIBUTING.md +status=$? + +if [ $status -ne 0 ]; then + if [ "$MODE" = "--check" ]; then + echo "Prettier check failed. Re-run with --write (e.g., ./ci/scripts/doc_prettier_check.sh --write) to format files, commit the changes, and re-run the check." >&2 + else + echo "Prettier format failed. Files may have been modified; commit any changes and re-run." >&2 + fi + exit $status +fi diff --git a/ci/scripts/rust_clippy.sh b/ci/scripts/rust_clippy.sh index 6a00ad8109561..aa994bc2b8c8a 100755 --- a/ci/scripts/rust_clippy.sh +++ b/ci/scripts/rust_clippy.sh @@ -18,4 +18,4 @@ # under the License. set -ex -cargo clippy --all-targets --workspace --features avro,pyarrow,integration-tests,extended_tests -- -D warnings \ No newline at end of file +cargo clippy --all-targets --workspace --features avro,integration-tests,extended_tests -- -D warnings diff --git a/ci/scripts/rust_example.sh b/ci/scripts/rust_example.sh index c3efcf2cf2e92..7a5f7825b4e6d 100755 --- a/ci/scripts/rust_example.sh +++ b/ci/scripts/rust_example.sh @@ -25,12 +25,26 @@ export CARGO_PROFILE_CI_STRIP=true cd datafusion-examples/examples/ cargo build --profile ci --examples -files=$(ls .) -for filename in $files -do - example_name=`basename $filename ".rs"` - # Skip tests that rely on external storage and flight - if [ ! 
-d $filename ]; then - cargo run --profile ci --example $example_name - fi +SKIP_LIST=("external_dependency" "flight" "ffi") + +skip_example() { + local name="$1" + for skip in "${SKIP_LIST[@]}"; do + if [ "$name" = "$skip" ]; then + return 0 + fi + done + return 1 +} + +for dir in */; do + example_name=$(basename "$dir") + + if skip_example "$example_name"; then + echo "Skipping $example_name" + continue + fi + + echo "Running example group: $example_name" + cargo run --profile ci --example "$example_name" -- all done diff --git a/ci/scripts/typos_check.sh b/ci/scripts/typos_check.sh new file mode 100755 index 0000000000000..a3a4a893213f7 --- /dev/null +++ b/ci/scripts/typos_check.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -ex +# To use this script, you must have installed `typos`, for example: +# cargo install typos-cli --locked --version 1.37.0 +typos --config typos.toml diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index f3069b492352d..67cb10081ca47 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -37,10 +37,10 @@ backtrace = ["datafusion/backtrace"] [dependencies] arrow = { workspace = true } async-trait = { workspace = true } -aws-config = "1.8.7" +aws-config = "1.8.12" aws-credential-types = "1.2.7" chrono = { workspace = true } -clap = { version = "4.5.50", features = ["cargo", "derive"] } +clap = { version = "4.5.53", features = ["cargo", "derive"] } datafusion = { workspace = true, features = [ "avro", "compression", diff --git a/datafusion-cli/examples/cli-session-context.rs b/datafusion-cli/examples/cli-session-context.rs index bd2dbb736781f..6095072163870 100644 --- a/datafusion-cli/examples/cli-session-context.rs +++ b/datafusion-cli/examples/cli-session-context.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use datafusion::{ dataframe::DataFrame, error::DataFusionError, - execution::{context::SessionState, TaskContext}, + execution::{TaskContext, context::SessionState}, logical_expr::{LogicalPlan, LogicalPlanBuilder}, prelude::SessionContext, }; diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index 20d62eabc3901..63b055388fdbe 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -18,13 +18,13 @@ use std::any::Any; use std::sync::{Arc, Weak}; -use crate::object_storage::{get_object_store, AwsOptions, GcpOptions}; +use crate::object_storage::{AwsOptions, GcpOptions, get_object_store}; use datafusion::catalog::{CatalogProvider, CatalogProviderList, SchemaProvider}; use datafusion::common::plan_datafusion_err; -use datafusion::datasource::listing::ListingTableUrl; use datafusion::datasource::TableProvider; +use datafusion::datasource::listing::ListingTableUrl; use datafusion::error::Result; use 
datafusion::execution::context::SessionState; use datafusion::execution::session_state::SessionStateBuilder; @@ -152,10 +152,10 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider { async fn table(&self, name: &str) -> Result>> { let inner_table = self.inner.table(name).await; - if inner_table.is_ok() { - if let Some(inner_table) = inner_table? { - return Ok(Some(inner_table)); - } + if inner_table.is_ok() + && let Some(inner_table) = inner_table? + { + return Ok(Some(inner_table)); } // if the inner schema provider didn't have a table by @@ -219,12 +219,12 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider { } pub fn substitute_tilde(cur: String) -> String { - if let Some(usr_dir_path) = home_dir() { - if let Some(usr_dir) = usr_dir_path.to_str() { - if cur.starts_with('~') && !usr_dir.is_empty() { - return cur.replacen('~', usr_dir, 1); - } - } + if let Some(usr_dir_path) = home_dir() + && let Some(usr_dir) = usr_dir_path.to_str() + && cur.starts_with('~') + && !usr_dir.is_empty() + { + return cur.replacen('~', usr_dir, 1); } cur } @@ -359,10 +359,12 @@ mod tests { } else { "/home/user" }; - env::set_var( - if cfg!(windows) { "USERPROFILE" } else { "HOME" }, - test_home_path, - ); + unsafe { + env::set_var( + if cfg!(windows) { "USERPROFILE" } else { "HOME" }, + test_home_path, + ); + } let input = "~/Code/datafusion/benchmarks/data/tpch_sf1/part/part-0.parquet"; let expected = PathBuf::from(test_home_path) .join("Code") @@ -376,12 +378,16 @@ mod tests { .to_string(); let actual = substitute_tilde(input.to_string()); assert_eq!(actual, expected); - match original_home { - Some(home_path) => env::set_var( - if cfg!(windows) { "USERPROFILE" } else { "HOME" }, - home_path.to_str().unwrap(), - ), - None => env::remove_var(if cfg!(windows) { "USERPROFILE" } else { "HOME" }), + unsafe { + match original_home { + Some(home_path) => env::set_var( + if cfg!(windows) { "USERPROFILE" } else { "HOME" }, + home_path.to_str().unwrap(), + ), + None => { + env::remove_var(if cfg!(windows) { "USERPROFILE" } else { "HOME" }) + } + } } } } diff --git a/datafusion-cli/src/cli_context.rs b/datafusion-cli/src/cli_context.rs index 516929ebacf19..a6320f03fe4de 100644 --- a/datafusion-cli/src/cli_context.rs +++ b/datafusion-cli/src/cli_context.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use datafusion::{ dataframe::DataFrame, error::DataFusionError, - execution::{context::SessionState, TaskContext}, + execution::{TaskContext, context::SessionState}, logical_expr::LogicalPlan, prelude::SessionContext, }; diff --git a/datafusion-cli/src/command.rs b/datafusion-cli/src/command.rs index 3fbfe5680cfcd..8aaa8025d1c3a 100644 --- a/datafusion-cli/src/command.rs +++ b/datafusion-cli/src/command.rs @@ -19,7 +19,7 @@ use crate::cli_context::CliSessionContext; use crate::exec::{exec_and_print, exec_from_lines}; -use crate::functions::{display_all_functions, Function}; +use crate::functions::{Function, display_all_functions}; use crate::print_format::PrintFormat; use crate::print_options::PrintOptions; use clap::ValueEnum; diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index d079a88a6440e..2b8385ac2d89c 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -35,19 +35,19 @@ use datafusion::execution::memory_pool::MemoryConsumer; use datafusion::logical_expr::{DdlStatement, LogicalPlan}; use datafusion::physical_plan::execution_plan::EmissionType; use datafusion::physical_plan::spill::get_record_batch_memory_size; -use datafusion::physical_plan::{execute_stream, 
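The `datafusion-cli/src/catalog.rs` hunk above shows two Rust 2024 edition migrations that recur throughout this diff: nested `if let` blocks collapsed into let-chains, and `std::env::set_var` wrapped in `unsafe`. A self-contained sketch of both, using a hypothetical `home_prefix` helper loosely modeled on `substitute_tilde`:

```rust
use std::env;

/// Expand a leading `~` using $HOME (illustrative only).
fn home_prefix(path: &str) -> Option<String> {
    // Let-chain: every condition lives in one `if`, no nesting required.
    if let Some(home_os) = env::var_os("HOME")
        && let Some(home) = home_os.to_str()
        && path.starts_with('~')
        && !home.is_empty()
    {
        return Some(path.replacen('~', home, 1));
    }
    None
}

fn main() {
    // SAFETY: single-threaded example; `set_var` is an `unsafe fn` in the
    // 2024 edition because mutating the environment races with readers.
    unsafe { env::set_var("HOME", "/home/user") };
    assert_eq!(
        home_prefix("~/data/part-0.parquet").as_deref(),
        Some("/home/user/data/part-0.parquet")
    );
    println!("ok");
}
```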
ExecutionPlanProperties}; +use datafusion::physical_plan::{ExecutionPlanProperties, execute_stream}; use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser; use datafusion::sql::sqlparser::dialect::dialect_from_str; use futures::StreamExt; use log::warn; use object_store::Error::Generic; -use rustyline::error::ReadlineError; use rustyline::Editor; +use rustyline::error::ReadlineError; use std::collections::HashMap; use std::fs::File; -use std::io::prelude::*; use std::io::BufReader; +use std::io::prelude::*; use tokio::signal; /// run and execute SQL statements and commands, against a context with the given print options @@ -153,7 +153,7 @@ pub async fn exec_from_repl( } } else { eprintln!( - "'\\{}' is not a valid command", + "'\\{}' is not a valid command, you can use '\\?' to see all commands", &line[1..] ); } @@ -168,7 +168,10 @@ pub async fn exec_from_repl( } } } else { - eprintln!("'\\{}' is not a valid command", &line[1..]); + eprintln!( + "'\\{}' is not a valid command, you can use '\\?' to see all commands", + &line[1..] + ); } } Ok(line) => { @@ -334,7 +337,9 @@ impl StatementExecutor { if matches!(err.as_ref(), Generic { store, source: _ } if "S3".eq_ignore_ascii_case(store)) && self.statement_for_retry.is_some() => { - warn!("S3 region is incorrect, auto-detecting the correct region (this may be slow). Consider updating your region configuration."); + warn!( + "S3 region is incorrect, auto-detecting the correct region (this may be slow). Consider updating your region configuration." + ); let plan = create_plan(ctx, self.statement_for_retry.take().unwrap(), true) .await?; @@ -699,8 +704,7 @@ mod tests { #[tokio::test] async fn create_object_store_table_gcs() -> Result<()> { let service_account_path = "fake_service_account_path"; - let service_account_key = - "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\", \"private_key_id\":\"id\"}"; + let service_account_key = "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\", \"private_key_id\":\"id\"}"; let application_credentials_path = "fake_application_credentials_path"; let location = "gcs://bucket/path/file.parquet"; @@ -713,7 +717,9 @@ mod tests { assert!(err.to_string().contains("os error 2")); // for service_account_key - let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_key' '{service_account_key}') LOCATION '{location}'"); + let sql = format!( + "CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_key' '{service_account_key}') LOCATION '{location}'" + ); let err = create_external_table_test(location, &sql) .await .unwrap_err() @@ -748,8 +754,9 @@ mod tests { let location = "path/to/file.cvs"; // Test with format options - let sql = - format!("CREATE EXTERNAL TABLE test STORED AS CSV LOCATION '{location}' OPTIONS('format.has_header' 'true')"); + let sql = format!( + "CREATE EXTERNAL TABLE test STORED AS CSV LOCATION '{location}' OPTIONS('format.has_header' 'true')" + ); create_external_table_test(location, &sql).await.unwrap(); Ok(()) diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index d23b12469e385..a45d57e8e952d 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -27,9 +27,9 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use datafusion::catalog::{Session, TableFunctionImpl}; -use 
datafusion::common::{plan_err, Column}; -use datafusion::datasource::memory::MemorySourceConfig; +use datafusion::common::{Column, plan_err}; use datafusion::datasource::TableProvider; +use datafusion::datasource::memory::MemorySourceConfig; use datafusion::error::Result; use datafusion::execution::cache::cache_manager::CacheManager; use datafusion::logical_expr::Expr; @@ -581,3 +581,119 @@ impl TableFunctionImpl for MetadataCacheFunc { Ok(Arc::new(metadata_cache)) } } + +/// STATISTICS_CACHE table function +#[derive(Debug)] +struct StatisticsCacheTable { + schema: SchemaRef, + batch: RecordBatch, +} + +#[async_trait] +impl TableProvider for StatisticsCacheTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow::datatypes::SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> datafusion::logical_expr::TableType { + datafusion::logical_expr::TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(MemorySourceConfig::try_new_exec( + &[vec![self.batch.clone()]], + TableProvider::schema(self), + projection.cloned(), + )?) + } +} + +#[derive(Debug)] +pub struct StatisticsCacheFunc { + cache_manager: Arc, +} + +impl StatisticsCacheFunc { + pub fn new(cache_manager: Arc) -> Self { + Self { cache_manager } + } +} + +impl TableFunctionImpl for StatisticsCacheFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + if !exprs.is_empty() { + return plan_err!("statistics_cache should have no arguments"); + } + + let schema = Arc::new(Schema::new(vec![ + Field::new("path", DataType::Utf8, false), + Field::new( + "file_modified", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new("file_size_bytes", DataType::UInt64, false), + Field::new("e_tag", DataType::Utf8, true), + Field::new("version", DataType::Utf8, true), + Field::new("num_rows", DataType::Utf8, false), + Field::new("num_columns", DataType::UInt64, false), + Field::new("table_size_bytes", DataType::Utf8, false), + Field::new("statistics_size_bytes", DataType::UInt64, false), + ])); + + // construct record batch from metadata + let mut path_arr = vec![]; + let mut file_modified_arr = vec![]; + let mut file_size_bytes_arr = vec![]; + let mut e_tag_arr = vec![]; + let mut version_arr = vec![]; + let mut num_rows_arr = vec![]; + let mut num_columns_arr = vec![]; + let mut table_size_bytes_arr = vec![]; + let mut statistics_size_bytes_arr = vec![]; + + if let Some(file_statistics_cache) = self.cache_manager.get_file_statistic_cache() + { + for (path, entry) in file_statistics_cache.list_entries() { + path_arr.push(path.to_string()); + file_modified_arr + .push(Some(entry.object_meta.last_modified.timestamp_millis())); + file_size_bytes_arr.push(entry.object_meta.size); + e_tag_arr.push(entry.object_meta.e_tag); + version_arr.push(entry.object_meta.version); + num_rows_arr.push(entry.num_rows.to_string()); + num_columns_arr.push(entry.num_columns as u64); + table_size_bytes_arr.push(entry.table_size_bytes.to_string()); + statistics_size_bytes_arr.push(entry.statistics_size_bytes as u64); + } + } + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(path_arr)), + Arc::new(TimestampMillisecondArray::from(file_modified_arr)), + Arc::new(UInt64Array::from(file_size_bytes_arr)), + Arc::new(StringArray::from(e_tag_arr)), + Arc::new(StringArray::from(version_arr)), + Arc::new(StringArray::from(num_rows_arr)), + 
Arc::new(UInt64Array::from(num_columns_arr)), + Arc::new(StringArray::from(table_size_bytes_arr)), + Arc::new(UInt64Array::from(statistics_size_bytes_arr)), + ], + )?; + + let statistics_cache = StatisticsCacheTable { schema, batch }; + Ok(Arc::new(statistics_cache)) + } +} diff --git a/datafusion-cli/src/helper.rs b/datafusion-cli/src/helper.rs index 219637b3460e6..df7afc14048b9 100644 --- a/datafusion-cli/src/helper.rs +++ b/datafusion-cli/src/helper.rs @@ -67,7 +67,7 @@ impl CliHelper { return Ok(ValidationResult::Invalid(Some(format!( " 🤔 Invalid dialect: {}", self.dialect - )))) + )))); } }; let lines = split_from_semicolon(sql); @@ -121,10 +121,10 @@ impl Hinter for CliHelper { fn is_open_quote_for_location(line: &str, pos: usize) -> bool { let mut sql = line[..pos].to_string(); sql.push('\''); - if let Ok(stmts) = DFParser::parse_sql(&sql) { - if let Some(Statement::CreateExternalTable(_)) = stmts.back() { - return true; - } + if let Ok(stmts) = DFParser::parse_sql(&sql) + && let Some(Statement::CreateExternalTable(_)) = stmts.back() + { + return true; } false } diff --git a/datafusion-cli/src/highlighter.rs b/datafusion-cli/src/highlighter.rs index f4e57a2e3593a..912a13916a5bd 100644 --- a/datafusion-cli/src/highlighter.rs +++ b/datafusion-cli/src/highlighter.rs @@ -23,7 +23,7 @@ use std::{ }; use datafusion::sql::sqlparser::{ - dialect::{dialect_from_str, Dialect, GenericDialect}, + dialect::{Dialect, GenericDialect, dialect_from_str}, keywords::Keyword, tokenizer::{Token, Tokenizer}, }; @@ -94,8 +94,8 @@ impl Color { #[cfg(test)] mod tests { - use super::config::Dialect; use super::SyntaxHighlighter; + use super::config::Dialect; use rustyline::highlight::Highlighter; #[test] diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 09fa8ef15af84..8f69ae477904c 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -31,16 +31,17 @@ use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::logical_expr::ExplainFormat; use datafusion::prelude::SessionContext; use datafusion_cli::catalog::DynamicObjectStoreCatalog; -use datafusion_cli::functions::{MetadataCacheFunc, ParquetMetadataFunc}; +use datafusion_cli::functions::{ + MetadataCacheFunc, ParquetMetadataFunc, StatisticsCacheFunc, +}; use datafusion_cli::object_storage::instrumented::{ InstrumentedObjectStoreMode, InstrumentedObjectStoreRegistry, }; use datafusion_cli::{ - exec, + DATAFUSION_CLI_VERSION, exec, pool_type::PoolType, print_format::PrintFormat, print_options::{MaxRows, PrintOptions}, - DATAFUSION_CLI_VERSION, }; use clap::Parser; @@ -244,6 +245,14 @@ async fn main_inner() -> Result<()> { )), ); + // register `statistics_cache` table function to get the contents of the file statistics cache + ctx.register_udtf( + "statistics_cache", + Arc::new(StatisticsCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + let mut print_options = PrintOptions { format: args.format, quiet: args.quiet, @@ -423,7 +432,13 @@ pub fn extract_disk_limit(size: &str) -> Result { #[cfg(test)] mod tests { use super::*; - use datafusion::{common::test_util::batches_to_string, prelude::ParquetReadOptions}; + use datafusion::{ + common::test_util::batches_to_string, + execution::cache::{ + cache_manager::CacheManagerConfig, cache_unit::DefaultFileStatisticsCache, + }, + prelude::ParquetReadOptions, + }; use insta::assert_snapshot; fn assert_conversion(input: &str, expected: Result) { @@ -488,8 +503,7 @@ mod tests { ctx.register_udtf("parquet_metadata", 
Arc::new(ParquetMetadataFunc {})); // input with single quote - let sql = - "SELECT * FROM parquet_metadata('../datafusion/core/tests/data/fixed_size_list_array.parquet')"; + let sql = "SELECT * FROM parquet_metadata('../datafusion/core/tests/data/fixed_size_list_array.parquet')"; let df = ctx.sql(sql).await?; let rbs = df.collect().await?; @@ -502,8 +516,7 @@ mod tests { "#); // input with double quote - let sql = - "SELECT * FROM parquet_metadata(\"../datafusion/core/tests/data/fixed_size_list_array.parquet\")"; + let sql = "SELECT * FROM parquet_metadata(\"../datafusion/core/tests/data/fixed_size_list_array.parquet\")"; let df = ctx.sql(sql).await?; let rbs = df.collect().await?; assert_snapshot!(batches_to_string(&rbs), @r#" @@ -523,8 +536,7 @@ mod tests { ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); // input with string columns - let sql = - "SELECT * FROM parquet_metadata('../parquet-testing/data/data_index_bloom_encoding_stats.parquet')"; + let sql = "SELECT * FROM parquet_metadata('../parquet-testing/data/data_index_bloom_encoding_stats.parquet')"; let df = ctx.sql(sql).await?; let rbs = df.collect().await?; @@ -592,9 +604,9 @@ mod tests { +-----------------------------------+-----------------+---------------------+------+------------------+ | filename | file_size_bytes | metadata_size_bytes | hits | extra | +-----------------------------------+-----------------+---------------------+------+------------------+ - | alltypes_plain.parquet | 1851 | 6957 | 2 | page_index=false | - | alltypes_tiny_pages.parquet | 454233 | 267014 | 2 | page_index=true | - | lz4_raw_compressed_larger.parquet | 380836 | 996 | 2 | page_index=false | + | alltypes_plain.parquet | 1851 | 8882 | 2 | page_index=false | + | alltypes_tiny_pages.parquet | 454233 | 269266 | 2 | page_index=true | + | lz4_raw_compressed_larger.parquet | 380836 | 1347 | 2 | page_index=false | +-----------------------------------+-----------------+---------------------+------+------------------+ "); @@ -623,12 +635,110 @@ mod tests { +-----------------------------------+-----------------+---------------------+------+------------------+ | filename | file_size_bytes | metadata_size_bytes | hits | extra | +-----------------------------------+-----------------+---------------------+------+------------------+ - | alltypes_plain.parquet | 1851 | 6957 | 5 | page_index=false | - | alltypes_tiny_pages.parquet | 454233 | 267014 | 2 | page_index=true | - | lz4_raw_compressed_larger.parquet | 380836 | 996 | 3 | page_index=false | + | alltypes_plain.parquet | 1851 | 8882 | 5 | page_index=false | + | alltypes_tiny_pages.parquet | 454233 | 269266 | 2 | page_index=true | + | lz4_raw_compressed_larger.parquet | 380836 | 1347 | 3 | page_index=false | +-----------------------------------+-----------------+---------------------+------+------------------+ "); Ok(()) } + + /// Shows that the statistics cache is not enabled by default yet + /// See https://github.com/apache/datafusion/issues/19217 + #[tokio::test] + async fn test_statistics_cache_default() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + + ctx.register_udtf( + "statistics_cache", + Arc::new(StatisticsCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + + for filename in [ + "alltypes_plain", + "alltypes_tiny_pages", + "lz4_raw_compressed_larger", + ] { + ctx.sql( + format!( + "create external table {filename} + stored as parquet + location '../parquet-testing/data/{filename}.parquet'", + ) + .as_str(), + ) + 
.await? + .collect() + .await?; + } + + // When the cache manager creates a StatisticsCache by default, + // the contents will show up here + let sql = "SELECT split_part(path, '/', -1) as filename, file_size_bytes, num_rows, num_columns, table_size_bytes from statistics_cache() order by filename"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + assert_snapshot!(batches_to_string(&rbs),@r" + ++ + ++ + "); + + Ok(()) + } + + // Can be removed when https://github.com/apache/datafusion/issues/19217 is resolved + #[tokio::test] + async fn test_statistics_cache_override() -> Result<(), DataFusionError> { + // Install a specific StatisticsCache implementation + let file_statistics_cache = Arc::new(DefaultFileStatisticsCache::default()); + let cache_config = CacheManagerConfig::default() + .with_files_statistics_cache(Some(file_statistics_cache.clone())); + let runtime = RuntimeEnvBuilder::new() + .with_cache_manager(cache_config) + .build()?; + let config = SessionConfig::new().with_collect_statistics(true); + let ctx = SessionContext::new_with_config_rt(config, Arc::new(runtime)); + + ctx.register_udtf( + "statistics_cache", + Arc::new(StatisticsCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + + for filename in [ + "alltypes_plain", + "alltypes_tiny_pages", + "lz4_raw_compressed_larger", + ] { + ctx.sql( + format!( + "create external table {filename} + stored as parquet + location '../parquet-testing/data/{filename}.parquet'", + ) + .as_str(), + ) + .await? + .collect() + .await?; + } + + let sql = "SELECT split_part(path, '/', -1) as filename, file_size_bytes, num_rows, num_columns, table_size_bytes from statistics_cache() order by filename"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + assert_snapshot!(batches_to_string(&rbs),@r" + +-----------------------------------+-----------------+--------------+-------------+------------------+ + | filename | file_size_bytes | num_rows | num_columns | table_size_bytes | + +-----------------------------------+-----------------+--------------+-------------+------------------+ + | alltypes_plain.parquet | 1851 | Exact(8) | 11 | Absent | + | alltypes_tiny_pages.parquet | 454233 | Exact(7300) | 13 | Absent | + | lz4_raw_compressed_larger.parquet | 380836 | Exact(10000) | 1 | Absent | + +-----------------------------------+-----------------+--------------+-------------+------------------+ + "); + + Ok(()) + } } diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index e6e6be42c7ad0..3cee78a5b17cc 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -20,7 +20,7 @@ pub mod instrumented; use async_trait::async_trait; use aws_config::BehaviorVersion; use aws_credential_types::provider::{ - error::CredentialsError, ProvideCredentials, SharedCredentialsProvider, + ProvideCredentials, SharedCredentialsProvider, error::CredentialsError, }; use datafusion::{ common::{ @@ -33,12 +33,12 @@ use datafusion::{ }; use log::debug; use object_store::{ - aws::{AmazonS3Builder, AmazonS3ConfigKey, AwsCredential}, - gcp::GoogleCloudStorageBuilder, - http::HttpBuilder, ClientOptions, CredentialProvider, Error::Generic, ObjectStore, + aws::{AmazonS3Builder, AmazonS3ConfigKey, AwsCredential}, + gcp::GoogleCloudStorageBuilder, + http::HttpBuilder, }; use std::{ any::Any, @@ -124,14 +124,15 @@ pub async fn get_s3_object_store_builder( if let Some(endpoint) = endpoint { // Make a nicer error if the user hasn't allowed http and the 
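Putting the pieces of the new `statistics_cache` function together outside the test harness: a minimal sketch that installs a file statistics cache (required for now, per the issue referenced above), registers the UDTF the same way `datafusion-cli`'s `main` does, and queries it. Treating `datafusion-cli` as a library dependency, the `tokio` setup, and the table path are assumptions; the API calls are the ones used in the hunks above.

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::cache::cache_manager::CacheManagerConfig;
use datafusion::execution::cache::cache_unit::DefaultFileStatisticsCache;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_cli::functions::StatisticsCacheFunc;

#[tokio::main]
async fn main() -> Result<()> {
    // Install an explicit file statistics cache; none is created by default
    // yet (see apache/datafusion#19217).
    let file_statistics_cache = Arc::new(DefaultFileStatisticsCache::default());
    let cache_config = CacheManagerConfig::default()
        .with_files_statistics_cache(Some(file_statistics_cache.clone()));
    let runtime = RuntimeEnvBuilder::new()
        .with_cache_manager(cache_config)
        .build()?;
    let config = SessionConfig::new().with_collect_statistics(true);
    let ctx = SessionContext::new_with_config_rt(config, Arc::new(runtime));

    // Expose the cache contents through the new table function.
    ctx.register_udtf(
        "statistics_cache",
        Arc::new(StatisticsCacheFunc::new(
            ctx.task_ctx().runtime_env().cache_manager.clone(),
        )),
    );

    // Register a Parquet table (placeholder path), then inspect what was cached.
    ctx.sql("CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION 'data/t.parquet'")
        .await?
        .collect()
        .await?;
    ctx.sql("SELECT path, num_rows, table_size_bytes FROM statistics_cache()")
        .await?
        .show()
        .await?;
    Ok(())
}
```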
endpoint // is http as the default message is "URL scheme is not allowed" - if let Ok(endpoint_url) = Url::try_from(endpoint.as_str()) { - if !matches!(allow_http, Some(true)) && endpoint_url.scheme() == "http" { - return config_err!( - "Invalid endpoint: {endpoint}. \ + if let Ok(endpoint_url) = Url::try_from(endpoint.as_str()) + && !matches!(allow_http, Some(true)) + && endpoint_url.scheme() == "http" + { + return config_err!( + "Invalid endpoint: {endpoint}. \ HTTP is not allowed for S3 endpoints. \ To allow HTTP, set 'aws.allow_http' to true" - ); - } + ); } builder = builder.with_endpoint(endpoint); @@ -586,8 +587,10 @@ mod tests { let location = "s3://bucket/path/FAKE/file.parquet"; // Set it to a non-existent file to avoid reading the default configuration file - std::env::set_var("AWS_CONFIG_FILE", "data/aws.config"); - std::env::set_var("AWS_SHARED_CREDENTIALS_FILE", "data/aws.credentials"); + unsafe { + std::env::set_var("AWS_CONFIG_FILE", "data/aws.config"); + std::env::set_var("AWS_SHARED_CREDENTIALS_FILE", "data/aws.credentials"); + } // No options let table_url = ListingTableUrl::parse(location)?; @@ -716,7 +719,10 @@ mod tests { .await .unwrap_err(); - assert_eq!(err.to_string().lines().next().unwrap_or_default(), "Invalid or Unsupported Configuration: Invalid endpoint: http://endpoint33. HTTP is not allowed for S3 endpoints. To allow HTTP, set 'aws.allow_http' to true"); + assert_eq!( + err.to_string().lines().next().unwrap_or_default(), + "Invalid or Unsupported Configuration: Invalid endpoint: http://endpoint33. HTTP is not allowed for S3 endpoints. To allow HTTP, set 'aws.allow_http' to true" + ); // Now add `allow_http` to the options and check if it works let sql = format!( @@ -746,7 +752,9 @@ mod tests { let expected_region = "eu-central-1"; let location = "s3://test-bucket/path/file.parquet"; // Set it to a non-existent file to avoid reading the default configuration file - std::env::set_var("AWS_CONFIG_FILE", "data/aws.config"); + unsafe { + std::env::set_var("AWS_CONFIG_FILE", "data/aws.config"); + } let table_url = ListingTableUrl::parse(location)?; let aws_options = AwsOptions { @@ -767,8 +775,8 @@ mod tests { } #[tokio::test] - async fn s3_object_store_builder_overrides_region_when_resolve_region_enabled( - ) -> Result<()> { + async fn s3_object_store_builder_overrides_region_when_resolve_region_enabled() + -> Result<()> { if let Err(DataFusionError::Execution(e)) = check_aws_envs().await { // Skip test if AWS envs are not set eprintln!("{e}"); @@ -806,7 +814,9 @@ mod tests { let table_url = ListingTableUrl::parse(location)?; let scheme = table_url.scheme(); - let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.oss.endpoint' '{endpoint}') LOCATION '{location}'"); + let sql = format!( + "CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.oss.endpoint' '{endpoint}') LOCATION '{location}'" + ); let ctx = SessionContext::new(); ctx.register_table_options_extension_from_scheme(scheme); @@ -830,14 +840,15 @@ mod tests { #[tokio::test] async fn gcs_object_store_builder() -> Result<()> { let service_account_path = "fake_service_account_path"; - let service_account_key = - "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\"}"; + let service_account_key = "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\"}"; 
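The refactored endpoint check above is easy to lift out and exercise in isolation. A sketch assuming only the scheme/allow-http rule matters; the `validate_endpoint` helper name is made up, while the error text mirrors the one in the hunk:

```rust
use url::Url;

/// Reject plain-http S3 endpoints unless the user explicitly allowed them.
fn validate_endpoint(endpoint: &str, allow_http: Option<bool>) -> Result<(), String> {
    if let Ok(endpoint_url) = Url::try_from(endpoint)
        && !matches!(allow_http, Some(true))
        && endpoint_url.scheme() == "http"
    {
        return Err(format!(
            "Invalid endpoint: {endpoint}. HTTP is not allowed for S3 endpoints. \
             To allow HTTP, set 'aws.allow_http' to true"
        ));
    }
    Ok(())
}

fn main() {
    assert!(validate_endpoint("http://localhost:9000", None).is_err());
    assert!(validate_endpoint("http://localhost:9000", Some(true)).is_ok());
    assert!(validate_endpoint("https://s3.amazonaws.com/bucket", None).is_ok());
    println!("ok");
}
```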
let application_credentials_path = "fake_application_credentials_path"; let location = "gcs://bucket/path/file.parquet"; let table_url = ListingTableUrl::parse(location)?; let scheme = table_url.scheme(); - let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_path' '{service_account_path}', 'gcp.service_account_key' '{service_account_key}', 'gcp.application_credentials_path' '{application_credentials_path}') LOCATION '{location}'"); + let sql = format!( + "CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_path' '{service_account_path}', 'gcp.service_account_key' '{service_account_key}', 'gcp.application_credentials_path' '{application_credentials_path}') LOCATION '{location}'" + ); let ctx = SessionContext::new(); ctx.register_table_options_extension_from_scheme(scheme); diff --git a/datafusion-cli/src/object_storage/instrumented.rs b/datafusion-cli/src/object_storage/instrumented.rs index c4b63b417fe42..0d5e9dc2c5a84 100644 --- a/datafusion-cli/src/object_storage/instrumented.rs +++ b/datafusion-cli/src/object_storage/instrumented.rs @@ -20,8 +20,8 @@ use std::{ ops::AddAssign, str::FromStr, sync::{ - atomic::{AtomicU8, Ordering}, Arc, + atomic::{AtomicU8, Ordering}, }, time::Duration, }; @@ -31,18 +31,71 @@ use arrow::util::pretty::pretty_format_batches; use async_trait::async_trait; use chrono::Utc; use datafusion::{ - common::{instant::Instant, HashMap}, + common::{HashMap, instant::Instant}, error::DataFusionError, execution::object_store::{DefaultObjectStoreRegistry, ObjectStoreRegistry}, }; -use futures::stream::BoxStream; +use futures::stream::{BoxStream, Stream}; use object_store::{ - path::Path, GetOptions, GetRange, GetResult, ListResult, MultipartUpload, ObjectMeta, + GetOptions, GetRange, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, Result, + path::Path, }; use parking_lot::{Mutex, RwLock}; use url::Url; +/// A stream wrapper that measures the time until the first response(item or end of stream) is yielded +struct TimeToFirstItemStream { + inner: S, + start: Instant, + request_index: usize, + requests: Arc>>, + first_item_yielded: bool, +} + +impl TimeToFirstItemStream { + fn new( + inner: S, + start: Instant, + request_index: usize, + requests: Arc>>, + ) -> Self { + Self { + inner, + start, + request_index, + requests, + first_item_yielded: false, + } + } +} + +impl Stream for TimeToFirstItemStream +where + S: Stream> + Unpin, +{ + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let poll_result = std::pin::Pin::new(&mut self.inner).poll_next(cx); + + if !self.first_item_yielded && poll_result.is_ready() { + self.first_item_yielded = true; + let elapsed = self.start.elapsed(); + + let mut requests = self.requests.lock(); + if let Some(request) = requests.get_mut(self.request_index) { + request.duration = Some(elapsed); + } + } + + poll_result + } +} + /// The profiling mode to use for an [`InstrumentedObjectStore`] instance. Collecting profiling /// data will have a small negative impact on both CPU and memory usage. 
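The `TimeToFirstItemStream` wrapper above is DataFusion-specific (it writes the elapsed time back into the shared request log), but the underlying pattern is generic. A stripped-down sketch with hypothetical names, using only `futures`, `tokio`, and `std`, that records how long the wrapped stream took to yield its first item or end:

```rust
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};

use futures::{Stream, StreamExt};

/// Wraps a stream and records the time until the first item (or end of stream).
struct TimeToFirstItem<S> {
    inner: S,
    start: Instant,
    first: Option<Duration>,
}

impl<S> TimeToFirstItem<S> {
    fn new(inner: S) -> Self {
        Self { inner, start: Instant::now(), first: None }
    }
}

impl<S: Stream + Unpin> Stream for TimeToFirstItem<S> {
    type Item = S::Item;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let poll = Pin::new(&mut self.inner).poll_next(cx);
        if self.first.is_none() && poll.is_ready() {
            // First response: either an item or the end of the stream.
            self.first = Some(self.start.elapsed());
        }
        poll
    }
}

#[tokio::main]
async fn main() {
    let mut stream = TimeToFirstItem::new(futures::stream::iter(0..3));
    while let Some(item) = stream.next().await {
        println!("got {item}");
    }
    println!("time to first item: {:?}", stream.first);
}
```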
Default is `Disabled` #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] @@ -91,7 +144,7 @@ impl From for InstrumentedObjectStoreMode { pub struct InstrumentedObjectStore { inner: Arc, instrument_mode: AtomicU8, - requests: Mutex>, + requests: Arc>>, } impl InstrumentedObjectStore { @@ -100,7 +153,7 @@ impl InstrumentedObjectStore { Self { inner: object_store, instrument_mode, - requests: Mutex::new(Vec::new()), + requests: Arc::new(Mutex::new(Vec::new())), } } @@ -218,19 +271,31 @@ impl InstrumentedObjectStore { prefix: Option<&Path>, ) -> BoxStream<'static, Result> { let timestamp = Utc::now(); - let ret = self.inner.list(prefix); + let start = Instant::now(); + let inner_stream = self.inner.list(prefix); + + let request_index = { + let mut requests = self.requests.lock(); + requests.push(RequestDetails { + op: Operation::List, + path: prefix.cloned().unwrap_or_else(|| Path::from("")), + timestamp, + duration: None, + size: None, + range: None, + extra_display: None, + }); + requests.len() - 1 + }; - self.requests.lock().push(RequestDetails { - op: Operation::List, - path: prefix.cloned().unwrap_or_else(|| Path::from("")), - timestamp, - duration: None, // list returns a stream, so the duration isn't meaningful - size: None, - range: None, - extra_display: None, - }); + let wrapped_stream = TimeToFirstItemStream::new( + inner_stream, + start, + request_index, + Arc::clone(&self.requests), + ); - ret + Box::pin(wrapped_stream) } async fn instrumented_list_with_delimiter( @@ -758,6 +823,7 @@ impl ObjectStoreRegistry for InstrumentedObjectStoreRegistry { #[cfg(test)] mod tests { + use futures::StreamExt; use object_store::WriteMultipart; use super::*; @@ -782,9 +848,11 @@ mod tests { "TRaCe".parse().unwrap(), InstrumentedObjectStoreMode::Trace )); - assert!("does_not_exist" - .parse::() - .is_err()); + assert!( + "does_not_exist" + .parse::() + .is_err() + ); assert!(matches!(0.into(), InstrumentedObjectStoreMode::Disabled)); assert!(matches!(1.into(), InstrumentedObjectStoreMode::Summary)); @@ -896,13 +964,15 @@ mod tests { instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); assert!(instrumented.requests.lock().is_empty()); - let _ = instrumented.list(Some(&path)); + let mut stream = instrumented.list(Some(&path)); + // Consume at least one item from the stream to trigger duration measurement + let _ = stream.next().await; assert_eq!(instrumented.requests.lock().len(), 1); let request = instrumented.take_requests().pop().unwrap(); assert_eq!(request.op, Operation::List); assert_eq!(request.path, path); - assert!(request.duration.is_none()); + assert!(request.duration.is_some()); assert!(request.size.is_none()); assert!(request.range.is_none()); assert!(request.extra_display.is_none()); diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index 56bdb15a315d9..cfb8a32ffcfeb 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -247,12 +247,12 @@ mod tests { .with_schema(three_column_schema()) .with_batches(vec![]) .run(); - assert_snapshot!(output, @r#" + assert_snapshot!(output, @r" +---+---+---+ | a | b | c | +---+---+---+ +---+---+---+ - "#); + "); } #[test] @@ -262,11 +262,11 @@ mod tests { .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::No) .run(); - assert_snapshot!(output, @r#" + assert_snapshot!(output, @r" 1,4,7 2,5,8 3,6,9 - "#); + "); } #[test] @@ -276,12 +276,12 @@ mod tests { .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::Yes) .run(); - 
assert_snapshot!(output, @r#" + assert_snapshot!(output, @r" a,b,c 1,4,7 2,5,8 3,6,9 - "#); + "); } #[test] @@ -291,10 +291,10 @@ mod tests { .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::No) .run(); - assert_snapshot!(output, @" - 1\t4\t7 - 2\t5\t8 - 3\t6\t9 + assert_snapshot!(output, @r" + 1 4 7 + 2 5 8 + 3 6 9 ") } @@ -305,11 +305,11 @@ mod tests { .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::Yes) .run(); - assert_snapshot!(output, @" - a\tb\tc - 1\t4\t7 - 2\t5\t8 - 3\t6\t9 + assert_snapshot!(output, @r" + a b c + 1 4 7 + 2 5 8 + 3 6 9 "); } @@ -320,7 +320,7 @@ mod tests { .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::Ignored) .run(); - assert_snapshot!(output, @r#" + assert_snapshot!(output, @r" +---+---+---+ | a | b | c | +---+---+---+ @@ -328,7 +328,7 @@ mod tests { | 2 | 5 | 8 | | 3 | 6 | 9 | +---+---+---+ - "#); + "); } #[test] fn print_json() { @@ -337,9 +337,7 @@ mod tests { .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::Ignored) .run(); - assert_snapshot!(output, @r#" - [{"a":1,"b":4,"c":7},{"a":2,"b":5,"c":8},{"a":3,"b":6,"c":9}] - "#); + assert_snapshot!(output, @r#"[{"a":1,"b":4,"c":7},{"a":2,"b":5,"c":8},{"a":3,"b":6,"c":9}]"#); } #[test] @@ -363,11 +361,11 @@ mod tests { .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::No) .run(); - assert_snapshot!(output, @r#" + assert_snapshot!(output, @r" 1,4,7 2,5,8 3,6,9 - "#); + "); } #[test] fn print_automatic_with_header() { @@ -376,12 +374,12 @@ mod tests { .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::Yes) .run(); - assert_snapshot!(output, @r#" + assert_snapshot!(output, @r" a,b,c 1,4,7 2,5,8 3,6,9 - "#); + "); } #[test] @@ -396,7 +394,7 @@ mod tests { .with_maxrows(max_rows) .run(); allow_duplicates! { - assert_snapshot!(output, @r#" + assert_snapshot!(output, @r" +---+ | a | +---+ @@ -404,7 +402,7 @@ mod tests { | 2 | | 3 | +---+ - "#); + "); } } } @@ -416,7 +414,7 @@ mod tests { .with_batches(vec![one_column_batch()]) .with_maxrows(MaxRows::Limited(1)) .run(); - assert_snapshot!(output, @r#" + assert_snapshot!(output, @r" +---+ | a | +---+ @@ -425,7 +423,7 @@ mod tests { | . | | . | +---+ - "#); + "); } #[test] @@ -439,7 +437,7 @@ mod tests { ]) .with_maxrows(MaxRows::Limited(5)) .run(); - assert_snapshot!(output, @r#" + assert_snapshot!(output, @r" +---+ | a | +---+ @@ -452,7 +450,7 @@ mod tests { | . | | . 
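Most of the `insta` snapshot churn in this file is just dropping unnecessary `#` delimiters: `r"..."` suffices whenever the snapshot contains no double quotes. A small reminder of the difference, not tied to any particular test above:

```rust
fn main() {
    // `r"..."` cannot contain a `"`; `r#"..."#` can, which is why the JSON
    // snapshot keeps its `#` delimiters while the table snapshots drop them.
    let table = r"
    +---+---+---+
    | a | b | c |
    +---+---+---+";
    let json = r#"[{"a":1,"b":4,"c":7}]"#;
    assert!(table.contains("| a | b | c |"));
    assert!(json.starts_with('['));
    println!("ok");
}
```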
| +---+ - "#); + "); } #[test] @@ -464,7 +462,7 @@ mod tests { .with_format(PrintFormat::Table) .with_batches(vec![empty_batch.clone(), batch, empty_batch]) .run(); - assert_snapshot!(output, @r#" + assert_snapshot!(output, @r" +---+ | a | +---+ @@ -472,7 +470,7 @@ mod tests { | 2 | | 3 | +---+ - "#); + "); } #[test] @@ -486,12 +484,12 @@ mod tests { .with_batches(vec![empty_batch]) .with_header(WithHeader::Yes) .run(); - assert_snapshot!(output, @r#" + assert_snapshot!(output, @r" +---+ | a | +---+ +---+ - "#); + "); // No output for empty batch when schema contains no columns let empty_batch = RecordBatch::new_empty(Arc::new(Schema::empty())); diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index 93d1d450fd82b..5fbe27d805db0 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -28,8 +28,8 @@ use crate::print_format::PrintFormat; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion::common::instant::Instant; use datafusion::common::DataFusionError; +use datafusion::common::instant::Instant; use datafusion::error::Result; use datafusion::physical_plan::RecordBatchStream; @@ -55,8 +55,10 @@ impl FromStr for MaxRows { Ok(Self::Unlimited) } else { match maxrows.parse::() { - Ok(nrows) => Ok(Self::Limited(nrows)), - _ => Err(format!("Invalid maxrows {maxrows}. Valid inputs are natural numbers or \'none\', \'inf\', or \'infinite\' for no limit.")), + Ok(nrows) => Ok(Self::Limited(nrows)), + _ => Err(format!( + "Invalid maxrows {maxrows}. Valid inputs are natural numbers or \'none\', \'inf\', or \'infinite\' for no limit." + )), } } } diff --git a/datafusion-cli/tests/cli_integration.rs b/datafusion-cli/tests/cli_integration.rs index c1395aa4f562c..d6f8deedfe32c 100644 --- a/datafusion-cli/tests/cli_integration.rs +++ b/datafusion-cli/tests/cli_integration.rs @@ -20,7 +20,7 @@ use std::process::Command; use rstest::rstest; use async_trait::async_trait; -use insta::{glob, Settings}; +use insta::{Settings, glob}; use insta_cmd::{assert_cmd_snapshot, get_cargo_bin}; use std::path::PathBuf; use std::{env, fs}; @@ -111,7 +111,9 @@ async fn setup_minio_container() -> ContainerAsync { } Err(TestcontainersError::Client(e)) => { - panic!("Failed to start MinIO container. Ensure Docker is running and accessible: {e}"); + panic!( + "Failed to start MinIO container. 
Ensure Docker is running and accessible: {e}" + ); } Err(e) => { panic!("Failed to start MinIO container: {e}"); @@ -258,13 +260,15 @@ async fn test_cli() { glob!("sql/integration/*.sql", |path| { let input = fs::read_to_string(path).unwrap(); - assert_cmd_snapshot!(cli() - .env_clear() - .env("AWS_ACCESS_KEY_ID", "TEST-DataFusionLogin") - .env("AWS_SECRET_ACCESS_KEY", "TEST-DataFusionPassword") - .env("AWS_ENDPOINT", format!("http://localhost:{port}")) - .env("AWS_ALLOW_HTTP", "true") - .pass_stdin(input)) + assert_cmd_snapshot!( + cli() + .env_clear() + .env("AWS_ACCESS_KEY_ID", "TEST-DataFusionLogin") + .env("AWS_SECRET_ACCESS_KEY", "TEST-DataFusionPassword") + .env("AWS_ENDPOINT", format!("http://localhost:{port}")) + .env("AWS_ALLOW_HTTP", "true") + .pass_stdin(input) + ) }); } @@ -328,10 +332,12 @@ SELECT COUNT(*) FROM hits; "# ); - assert_cmd_snapshot!(cli() - .env("RUST_LOG", "warn") - .env_remove("AWS_ENDPOINT") - .pass_stdin(input)); + assert_cmd_snapshot!( + cli() + .env("RUST_LOG", "warn") + .env_remove("AWS_ENDPOINT") + .pass_stdin(input) + ); } /// Ensure backtrace will be printed, if executing `datafusion-cli` with a query @@ -450,7 +456,7 @@ SELECT * from CARS LIMIT 1; #[async_trait] trait MinioCommandExt { async fn with_minio(&mut self, container: &ContainerAsync) - -> &mut Self; + -> &mut Self; } #[async_trait] diff --git a/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap b/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap index 6b3a247dd7b82..1359cefbe71c7 100644 --- a/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap +++ b/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap @@ -7,7 +7,6 @@ info: - EXPLAIN SELECT 123 env: DATAFUSION_EXPLAIN_FORMAT: pgjson -snapshot_kind: text --- success: true exit_code: 0 diff --git a/datafusion-cli/tests/snapshots/cli_format@automatic.snap b/datafusion-cli/tests/snapshots/cli_format@automatic.snap index 2591f493e90a8..76b14d9a3a924 100644 --- a/datafusion-cli/tests/snapshots/cli_format@automatic.snap +++ b/datafusion-cli/tests/snapshots/cli_format@automatic.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_format@csv.snap b/datafusion-cli/tests/snapshots/cli_format@csv.snap index c41b042298eb0..2c969bd91d121 100644 --- a/datafusion-cli/tests/snapshots/cli_format@csv.snap +++ b/datafusion-cli/tests/snapshots/cli_format@csv.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_format@json.snap b/datafusion-cli/tests/snapshots/cli_format@json.snap index 8f804a337cce5..22a9cc4657a91 100644 --- a/datafusion-cli/tests/snapshots/cli_format@json.snap +++ b/datafusion-cli/tests/snapshots/cli_format@json.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_format@nd-json.snap b/datafusion-cli/tests/snapshots/cli_format@nd-json.snap index 7b4ce1e2530cf..513bcb7372ca6 100644 --- a/datafusion-cli/tests/snapshots/cli_format@nd-json.snap +++ b/datafusion-cli/tests/snapshots/cli_format@nd-json.snap @@ -1,5 +1,5 @@ --- 
-source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_format@table.snap b/datafusion-cli/tests/snapshots/cli_format@table.snap index 99914182462aa..8677847588385 100644 --- a/datafusion-cli/tests/snapshots/cli_format@table.snap +++ b/datafusion-cli/tests/snapshots/cli_format@table.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_format@tsv.snap b/datafusion-cli/tests/snapshots/cli_format@tsv.snap index 968268c31dd55..c56e60fcab155 100644 --- a/datafusion-cli/tests/snapshots/cli_format@tsv.snap +++ b/datafusion-cli/tests/snapshots/cli_format@tsv.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@batch_size.snap b/datafusion-cli/tests/snapshots/cli_quick_test@batch_size.snap index c27d527df0b6a..9fd07fa6f4e1b 100644 --- a/datafusion-cli/tests/snapshots/cli_quick_test@batch_size.snap +++ b/datafusion-cli/tests/snapshots/cli_quick_test@batch_size.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@default_explain_plan.snap b/datafusion-cli/tests/snapshots/cli_quick_test@default_explain_plan.snap index 46ee6be64f624..8620f6da84488 100644 --- a/datafusion-cli/tests/snapshots/cli_quick_test@default_explain_plan.snap +++ b/datafusion-cli/tests/snapshots/cli_quick_test@default_explain_plan.snap @@ -5,7 +5,6 @@ info: args: - "--command" - EXPLAIN SELECT 123 -snapshot_kind: text --- success: true exit_code: 0 diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@files.snap b/datafusion-cli/tests/snapshots/cli_quick_test@files.snap index 7c44e41729a17..df3a10b6bb54b 100644 --- a/datafusion-cli/tests/snapshots/cli_quick_test@files.snap +++ b/datafusion-cli/tests/snapshots/cli_quick_test@files.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@statements.snap b/datafusion-cli/tests/snapshots/cli_quick_test@statements.snap index 3b975bb6a927d..a394458768d1b 100644 --- a/datafusion-cli/tests/snapshots/cli_quick_test@statements.snap +++ b/datafusion-cli/tests/snapshots/cli_quick_test@statements.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 38f1f8b0e0cad..b0190dadf3c3f 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -35,18 +35,6 @@ rust-version = { workspace = true } [lints] workspace = true -[[example]] -name = "dataframe_to_s3" -path = "examples/external_dependency/dataframe-to-s3.rs" - -[[example]] -name = "query_aws_s3" -path = "examples/external_dependency/query-aws-s3.rs" - -[[example]] -name = "custom_file_casts" -path = "examples/custom_file_casts.rs" - [dev-dependencies] arrow = { workspace = true } # arrow_schema is required for record_batch! 
macro :sad: @@ -58,17 +46,22 @@ dashmap = { workspace = true } # note only use main datafusion crate for examples base64 = "0.22.1" datafusion = { workspace = true, default-features = true, features = ["parquet_encryption"] } -datafusion-ffi = { workspace = true } +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } datafusion-physical-expr-adapter = { workspace = true } datafusion-proto = { workspace = true } +datafusion-sql = { workspace = true } env_logger = { workspace = true } futures = { workspace = true } +insta = { workspace = true } log = { workspace = true } mimalloc = { version = "0.1", default-features = false } object_store = { workspace = true, features = ["aws", "http"] } prost = { workspace = true } rand = { workspace = true } serde_json = { workspace = true } +strum = { workspace = true } +strum_macros = { workspace = true } tempfile = { workspace = true } test-utils = { path = "../test-utils" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } @@ -76,7 +69,7 @@ tonic = "0.14" tracing = { version = "0.1" } tracing-subscriber = { version = "0.3" } url = { workspace = true } -uuid = "1.18" +uuid = "1.19" [target.'cfg(not(target_os = "windows"))'.dev-dependencies] nix = { version = "0.30.1", features = ["fs"] } diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 1befba6be66fd..8f38b38990363 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -39,60 +39,178 @@ git submodule update --init # Change to the examples directory cd datafusion-examples/examples -# Run the `dataframe` example: -# ... use the equivalent for other examples -cargo run --example dataframe +# Run all examples in a group +cargo run --example -- all + +# Run a specific example within a group +cargo run --example -- + +# Run all examples in the `dataframe` group +cargo run --example dataframe -- all + +# Run a single example from the `dataframe` group +# (apply the same pattern for any other group) +cargo run --example dataframe -- dataframe ``` -## Single Process - -- [`examples/udf/advanced_udaf.rs`](examples/udf/advanced_udaf.rs): Define and invoke a more complicated User Defined Aggregate Function (UDAF) -- [`examples/udf/advanced_udf.rs`](examples/udf/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) -- [`examples/udf/advanced_udwf.rs`](examples/udf/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) -- [`advanced_parquet_index.rs`](examples/advanced_parquet_index.rs): Creates a detailed secondary index that covers the contents of several parquet files -- [`examples/udf/async_udf.rs`](examples/udf/async_udf.rs): Define and invoke an asynchronous User Defined Scalar Function (UDF) -- [`analyzer_rule.rs`](examples/analyzer_rule.rs): Use a custom AnalyzerRule to change a query's semantics (row level access control) -- [`catalog.rs`](examples/catalog.rs): Register the table into a custom catalog -- [`composed_extension_codec`](examples/composed_extension_codec.rs): Example of using multiple extension codecs for serialization / deserialization -- [`csv_sql_streaming.rs`](examples/csv_sql_streaming.rs): Build and run a streaming query plan from a SQL statement against a local CSV file -- [`csv_json_opener.rs`](examples/csv_json_opener.rs): Use low level `FileOpener` APIs to read CSV/JSON into Arrow `RecordBatch`es -- [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource 
(TableProvider) -- [`custom_file_casts.rs`](examples/custom_file_casts.rs): Implement custom casting rules to adapt file schemas -- [`custom_file_format.rs`](examples/custom_file_format.rs): Write data to a custom file format -- [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 and writing back to s3 -- [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame API against parquet files, csv files, and in-memory data, including multiple subqueries. Also demonstrates the various methods to write out a DataFrame to a table, parquet file, csv file, and json file. -- [`examples/builtin_functions/date_time`](examples/builtin_functions/date_time.rs): Examples of date-time related functions and queries -- [`default_column_values.rs`](examples/default_column_values.rs): Implement custom default value handling for missing columns using field metadata and PhysicalExprAdapter -- [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results (Arrow ArrayRefs) into Rust structs -- [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify, analyze and coerce `Expr`s -- [`file_stream_provider.rs`](examples/file_stream_provider.rs): Run a query on `FileStreamProvider` which implements `StreamProvider` for reading and writing to arbitrary stream sources / sinks. -- [`flight/sql_server.rs`](examples/flight/sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from Flight and and FlightSQL (e.g. JDBC) clients -- [`examples/builtin_functions/function_factory.rs`](examples/builtin_functions/function_factory.rs): Register `CREATE FUNCTION` handler to implement SQL macros -- [`memory_pool_tracking.rs`](examples/memory_pool_tracking.rs): Demonstrates TrackConsumersPool for memory tracking and debugging with enhanced error messages -- [`memory_pool_execution_plan.rs`](examples/memory_pool_execution_plan.rs): Shows how to implement memory-aware ExecutionPlan with memory reservation and spilling -- [`optimizer_rule.rs`](examples/optimizer_rule.rs): Use a custom OptimizerRule to replace certain predicates -- [`parquet_embedded_index.rs`](examples/parquet_embedded_index.rs): Store a custom index inside a Parquet file and use it to speed up queries -- [`parquet_encrypted.rs`](examples/parquet_encrypted.rs): Read and write encrypted Parquet files using DataFusion -- [`parquet_encrypted_with_kms.rs`](examples/parquet_encrypted_with_kms.rs): Read and write encrypted Parquet files using an encryption factory -- [`parquet_index.rs`](examples/parquet_index.rs): Create an secondary index over several parquet files and use it to speed up queries -- [`parquet_exec_visitor.rs`](examples/parquet_exec_visitor.rs): Extract statistics by visiting an ExecutionPlan after execution -- [`parse_sql_expr.rs`](examples/parse_sql_expr.rs): Parse SQL text into DataFusion `Expr`. 
-- [`plan_to_sql.rs`](examples/plan_to_sql.rs): Generate SQL from DataFusion `Expr` and `LogicalPlan` -- [`planner_api.rs`](examples/planner_api.rs) APIs to manipulate logical and physical plans -- [`pruning.rs`](examples/pruning.rs): Use pruning to rule out files based on statistics -- [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 -- [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP -- [`examples/builtin_functions/regexp.rs`](examples/builtin_functions/regexp.rs): Examples of using regular expression functions -- [`remote_catalog.rs`](examples/regexp.rs): Examples of interfacing with a remote catalog (e.g. over a network) -- [`examples/udf/simple_udaf.rs`](examples/udf/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) -- [`examples/udf/simple_udf.rs`](examples/udf/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) -- [`examples/udf/simple_udtf.rs`](examples/udf/simple_udtf.rs): Define and invoke a User Defined Table Function (UDTF) -- [`examples/udf/simple_udfw.rs`](examples/udf/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) -- [`sql_analysis.rs`](examples/sql_analysis.rs): Analyse SQL queries with DataFusion structures -- [`sql_frontend.rs`](examples/sql_frontend.rs): Create LogicalPlans (only) from sql strings -- [`sql_dialect.rs`](examples/sql_dialect.rs): Example of implementing a custom SQL dialect on top of `DFParser` -- [`sql_query.rs`](examples/memtable.rs): Query data using SQL (in memory `RecordBatches`, local Parquet files) - -## Distributed - -- [`examples/flight/client.rs`](examples/flight/client.rs) and [`examples/flight/server.rs`](examples/flight/server.rs): Run DataFusion as a standalone process and execute SQL queries from a client using the Arrow Flight protocol. 
+## Builtin Functions Examples + +### Group: `builtin_functions` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| ---------------- | ----------------------------------------------------------------------------------------- | ---------------------------------------------------------- | +| date_time | [`builtin_functions/date_time.rs`](examples/builtin_functions/date_time.rs) | Examples of date-time related functions and queries | +| function_factory | [`builtin_functions/function_factory.rs`](examples/builtin_functions/function_factory.rs) | Register `CREATE FUNCTION` handler to implement SQL macros | +| regexp | [`builtin_functions/regexp.rs`](examples/builtin_functions/regexp.rs) | Examples of using regular expression functions | + +## Custom Data Source Examples + +### Group: `custom_data_source` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| --------------------- | ----------------------------------------------------------------------------------------------------- | --------------------------------------------- | +| csv_sql_streaming | [`custom_data_source/csv_sql_streaming.rs`](examples/custom_data_source/csv_sql_streaming.rs) | Run a streaming SQL query against CSV data | +| csv_json_opener | [`custom_data_source/csv_json_opener.rs`](examples/custom_data_source/csv_json_opener.rs) | Use low-level FileOpener APIs for CSV/JSON | +| custom_datasource | [`custom_data_source/custom_datasource.rs`](examples/custom_data_source/custom_datasource.rs) | Query a custom TableProvider | +| custom_file_casts | [`custom_data_source/custom_file_casts.rs`](examples/custom_data_source/custom_file_casts.rs) | Implement custom casting rules | +| custom_file_format | [`custom_data_source/custom_file_format.rs`](examples/custom_data_source/custom_file_format.rs) | Write to a custom file format | +| default_column_values | [`custom_data_source/default_column_values.rs`](examples/custom_data_source/default_column_values.rs) | Custom default values using metadata | +| file_stream_provider | [`custom_data_source/file_stream_provider.rs`](examples/custom_data_source/file_stream_provider.rs) | Read/write via FileStreamProvider for streams | + +## Data IO Examples + +### Group: `data_io` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| -------------------- | ----------------------------------------------------------------------------------------- | ------------------------------------------------------ | +| catalog | [`data_io/catalog.rs`](examples/data_io/catalog.rs) | Register tables into a custom catalog | +| json_shredding | [`data_io/json_shredding.rs`](examples/data_io/json_shredding.rs) | Implement filter rewriting for JSON shredding | +| parquet_adv_idx | [`data_io/parquet_advanced_index.rs`](examples/data_io/parquet_advanced_index.rs) | Create a secondary index across multiple parquet files | +| parquet_emb_idx | [`data_io/parquet_embedded_index.rs`](examples/data_io/parquet_embedded_index.rs) | Store a custom index inside Parquet files | +| parquet_enc | [`data_io/parquet_encrypted.rs`](examples/data_io/parquet_encrypted.rs) | Read & write encrypted Parquet files | +| parquet_enc_with_kms | [`data_io/parquet_encrypted_with_kms.rs`](examples/data_io/parquet_encrypted_with_kms.rs) | Encrypted Parquet I/O using a KMS-backed factory | +| parquet_exec_visitor | [`data_io/parquet_exec_visitor.rs`](examples/data_io/parquet_exec_visitor.rs) | Extract statistics by visiting an ExecutionPlan | +| parquet_idx | 
[`data_io/parquet_index.rs`](examples/data_io/parquet_index.rs) | Create a secondary index | +| query_http_csv | [`data_io/query_http_csv.rs`](examples/data_io/query_http_csv.rs) | Query CSV files via HTTP | +| remote_catalog | [`data_io/remote_catalog.rs`](examples/data_io/remote_catalog.rs) | Interact with a remote catalog | + +## DataFrame Examples + +### Group: `dataframe` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| --------------------- | ----------------------------------------------------------------------------------- | ------------------------------------------------------ | +| dataframe | [`dataframe/dataframe.rs`](examples/dataframe/dataframe.rs) | Query DataFrames from various sources and write output | +| deserialize_to_struct | [`dataframe/deserialize_to_struct.rs`](examples/dataframe/deserialize_to_struct.rs) | Convert Arrow arrays into Rust structs | + +## Execution Monitoring Examples + +### Group: `execution_monitoring` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| ------------------ | ------------------------------------------------------------------------------------------------------------------- | ---------------------------------------- | +| mem_pool_exec_plan | [`execution_monitoring/memory_pool_execution_plan.rs`](examples/execution_monitoring/memory_pool_execution_plan.rs) | Memory-aware ExecutionPlan with spilling | +| mem_pool_tracking | [`execution_monitoring/memory_pool_tracking.rs`](examples/execution_monitoring/memory_pool_tracking.rs) | Demonstrates memory tracking | +| tracing | [`execution_monitoring/tracing.rs`](examples/execution_monitoring/tracing.rs) | Demonstrates tracing integration | + +## External Dependency Examples + +### Group: `external_dependency` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| --------------- | ------------------------------------------------------------------------------------------- | ---------------------------------------- | +| dataframe_to_s3 | [`external_dependency/dataframe_to_s3.rs`](examples/external_dependency/dataframe_to_s3.rs) | Query DataFrames and write results to S3 | +| query_aws_s3 | [`external_dependency/query_aws_s3.rs`](examples/external_dependency/query_aws_s3.rs) | Query S3-backed data using object_store | + +## Flight Examples + +### Group: `flight` + +#### Category: Distributed + +| Subcommand | File Path | Description | +| ---------- | ------------------------------------------------------- | ------------------------------------------------------ | +| server | [`flight/server.rs`](examples/flight/server.rs) | Run DataFusion server accepting FlightSQL/JDBC queries | +| client | [`flight/client.rs`](examples/flight/client.rs) | Execute SQL queries via Arrow Flight protocol | +| sql_server | [`flight/sql_server.rs`](examples/flight/sql_server.rs) | Standalone SQL server for JDBC clients | + +## Proto Examples + +### Group: `proto` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| ------------------------ | --------------------------------------------------------------------------------- | --------------------------------------------------------------- | +| composed_extension_codec | [`proto/composed_extension_codec.rs`](examples/proto/composed_extension_codec.rs) | Use multiple extension codecs for serialization/deserialization | + +## Query Planning Examples + +### Group: `query_planning` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| 
-------------- | ------------------------------------------------------------------------------- | ------------------------------------------------------ | +| analyzer_rule | [`query_planning/analyzer_rule.rs`](examples/query_planning/analyzer_rule.rs) | Custom AnalyzerRule to change query semantics | +| expr_api | [`query_planning/expr_api.rs`](examples/query_planning/expr_api.rs) | Create, execute, analyze, and coerce Exprs | +| optimizer_rule | [`query_planning/optimizer_rule.rs`](examples/query_planning/optimizer_rule.rs) | Replace predicates via a custom OptimizerRule | +| parse_sql_expr | [`query_planning/parse_sql_expr.rs`](examples/query_planning/parse_sql_expr.rs) | Parse SQL into DataFusion Expr | +| plan_to_sql | [`query_planning/plan_to_sql.rs`](examples/query_planning/plan_to_sql.rs) | Generate SQL from expressions or plans | +| planner_api | [`query_planning/planner_api.rs`](examples/query_planning/planner_api.rs) | APIs for logical and physical plan manipulation | +| pruning | [`query_planning/pruning.rs`](examples/query_planning/pruning.rs) | Use pruning to skip irrelevant files | +| thread_pools | [`query_planning/thread_pools.rs`](examples/query_planning/thread_pools.rs) | Configure custom thread pools for DataFusion execution | + +## Relation Planner Examples + +### Group: `relation_planner` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| --------------- | ------------------------------------------------------------------------------------- | ------------------------------------------ | +| match_recognize | [`relation_planner/match_recognize.rs`](examples/relation_planner/match_recognize.rs) | Implement MATCH_RECOGNIZE pattern matching | +| pivot_unpivot | [`relation_planner/pivot_unpivot.rs`](examples/relation_planner/pivot_unpivot.rs) | Implement PIVOT / UNPIVOT | +| table_sample | [`relation_planner/table_sample.rs`](examples/relation_planner/table_sample.rs) | Implement TABLESAMPLE | + +## SQL Ops Examples + +### Group: `sql_ops` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| ----------------- | ----------------------------------------------------------------------- | -------------------------------------------------- | +| analysis | [`sql_ops/analysis.rs`](examples/sql_ops/analysis.rs) | Analyze SQL queries | +| custom_sql_parser | [`sql_ops/custom_sql_parser.rs`](examples/sql_ops/custom_sql_parser.rs) | Implement a custom SQL parser to extend DataFusion | +| frontend | [`sql_ops/frontend.rs`](examples/sql_ops/frontend.rs) | Build LogicalPlans from SQL | +| query | [`sql_ops/query.rs`](examples/sql_ops/query.rs) | Query data using SQL | + +## UDF Examples + +### Group: `udf` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| ---------- | ------------------------------------------------------- | ----------------------------------------------- | +| adv_udaf | [`udf/advanced_udaf.rs`](examples/udf/advanced_udaf.rs) | Advanced User Defined Aggregate Function (UDAF) | +| adv_udf | [`udf/advanced_udf.rs`](examples/udf/advanced_udf.rs) | Advanced User Defined Scalar Function (UDF) | +| adv_udwf | [`udf/advanced_udwf.rs`](examples/udf/advanced_udwf.rs) | Advanced User Defined Window Function (UDWF) | +| async_udf | [`udf/async_udf.rs`](examples/udf/async_udf.rs) | Asynchronous User Defined Scalar Function | +| udaf | [`udf/simple_udaf.rs`](examples/udf/simple_udaf.rs) | Simple UDAF example | +| udf | [`udf/simple_udf.rs`](examples/udf/simple_udf.rs) | Simple UDF example | +| udtf | 
[`udf/simple_udtf.rs`](examples/udf/simple_udtf.rs) | Simple UDTF example | +| udwf | [`udf/simple_udwf.rs`](examples/udf/simple_udwf.rs) | Simple UDWF example | diff --git a/datafusion-examples/examples/builtin_functions/date_time.rs b/datafusion-examples/examples/builtin_functions/date_time.rs index 178cba979cb95..08d4bc6e29978 100644 --- a/datafusion-examples/examples/builtin_functions/date_time.rs +++ b/datafusion-examples/examples/builtin_functions/date_time.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use std::sync::Arc; use arrow::array::{Date32Array, Int32Array}; @@ -179,12 +181,13 @@ async fn query_make_date() -> Result<()> { // invalid column values will result in an error let result = ctx - .sql("select make_date(2024, null, 23)") + .sql("select make_date(2024, '', 23)") .await? .collect() .await; - let expected = "Execution error: Unable to parse date from null/empty value"; + let expected = + "Arrow error: Cast error: Cannot cast string '' to value of Int32 type"; assert_contains!(result.unwrap_err().to_string(), expected); // invalid date values will also result in an error @@ -194,7 +197,7 @@ async fn query_make_date() -> Result<()> { .collect() .await; - let expected = "Execution error: Unable to parse date from 2024, 1, 32"; + let expected = "Execution error: Day value '32' is out of range"; assert_contains!(result.unwrap_err().to_string(), expected); Ok(()) diff --git a/datafusion-examples/examples/builtin_functions/function_factory.rs b/datafusion-examples/examples/builtin_functions/function_factory.rs index 5d41e7a260713..7eff0d0b5c484 100644 --- a/datafusion-examples/examples/builtin_functions/function_factory.rs +++ b/datafusion-examples/examples/builtin_functions/function_factory.rs @@ -15,9 +15,11 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::datatypes::DataType; use datafusion::common::tree_node::{Transformed, TreeNode}; -use datafusion::common::{exec_datafusion_err, exec_err, internal_err, DataFusionError}; +use datafusion::common::{DataFusionError, exec_datafusion_err, exec_err, internal_err}; use datafusion::error::Result; use datafusion::execution::context::{ FunctionFactory, RegisterFunction, SessionContext, SessionState, diff --git a/datafusion-examples/examples/builtin_functions/main.rs b/datafusion-examples/examples/builtin_functions/main.rs index 3399c395bfd62..f9e0a44a09a34 100644 --- a/datafusion-examples/examples/builtin_functions/main.rs +++ b/datafusion-examples/examples/builtin_functions/main.rs @@ -19,7 +19,13 @@ //! //! These examples demonstrate miscellaneous function-related features. //! +//! ## Usage +//! ```bash +//! cargo run --example builtin_functions -- [all|date_time|function_factory|regexp] +//! ``` +//! //! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module //! - `date_time` — examples of date-time related functions and queries //! - `function_factory` — register `CREATE FUNCTION` handler to implement SQL macros //! 
- `regexp` — examples of using regular expression functions @@ -28,46 +34,39 @@ mod date_time; mod function_factory; mod regexp; -use std::str::FromStr; - use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] enum ExampleKind { + All, DateTime, FunctionFactory, Regexp, } -impl AsRef for ExampleKind { - fn as_ref(&self) -> &str { - match self { - Self::DateTime => "date_time", - Self::FunctionFactory => "function_factory", - Self::Regexp => "regexp", - } - } -} - -impl FromStr for ExampleKind { - type Err = DataFusionError; - - fn from_str(s: &str) -> Result { - match s { - "date_time" => Ok(Self::DateTime), - "function_factory" => Ok(Self::FunctionFactory), - "regexp" => Ok(Self::Regexp), - _ => Err(DataFusionError::Execution(format!("Unknown example: {s}"))), - } - } -} - impl ExampleKind { - const ALL: [Self; 3] = [Self::DateTime, Self::FunctionFactory, Self::Regexp]; - const EXAMPLE_NAME: &str = "builtin_functions"; - fn variants() -> Vec<&'static str> { - Self::ALL.iter().map(|x| x.as_ref()).collect() + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::DateTime => date_time::date_time().await?, + ExampleKind::FunctionFactory => function_factory::function_factory().await?, + ExampleKind::Regexp => regexp::regexp().await?, + } + Ok(()) } } @@ -76,19 +75,14 @@ async fn main() -> Result<()> { let usage = format!( "Usage: cargo run --example {} -- [{}]", ExampleKind::EXAMPLE_NAME, - ExampleKind::variants().join("|") + ExampleKind::VARIANTS.join("|") ); - let arg = std::env::args().nth(1).ok_or_else(|| { - eprintln!("{usage}"); - DataFusionError::Execution("Missing argument".to_string()) - })?; - - match arg.parse::()? { - ExampleKind::DateTime => date_time::date_time().await?, - ExampleKind::FunctionFactory => function_factory::function_factory().await?, - ExampleKind::Regexp => regexp::regexp().await?, - } + let example: ExampleKind = std::env::args() + .nth(1) + .ok_or_else(|| DataFusionError::Execution(format!("Missing argument. {usage}")))? + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; - Ok(()) + example.run().await } diff --git a/datafusion-examples/examples/builtin_functions/regexp.rs b/datafusion-examples/examples/builtin_functions/regexp.rs index 13c0786930283..e8376cd0c94eb 100644 --- a/datafusion-examples/examples/builtin_functions/regexp.rs +++ b/datafusion-examples/examples/builtin_functions/regexp.rs @@ -16,9 +16,14 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
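The `main.rs` refactor above replaces the hand-written `FromStr` and `AsRef` implementations with `strum` derives and adds an `All` variant that runs every example through a boxed recursive call. A minimal, self-contained sketch of that dispatcher pattern, assuming the `strum`, `strum_macros`, and `tokio` crates (the `Foo` and `Bar` variants are placeholders, not real examples):

```rust
use strum::{IntoEnumIterator, VariantNames};
use strum_macros::{Display, EnumIter, EnumString, VariantNames};

#[derive(EnumIter, EnumString, Display, VariantNames)]
#[strum(serialize_all = "snake_case")]
enum ExampleKind {
    All,
    Foo,
    Bar,
}

impl ExampleKind {
    /// Every variant except `All` maps to a concrete example.
    fn runnable() -> impl Iterator<Item = Self> {
        Self::iter().filter(|v| !matches!(v, Self::All))
    }

    async fn run(&self) -> Result<(), String> {
        match self {
            Self::All => {
                for example in Self::runnable() {
                    println!("Running example: {example}");
                    // Box::pin breaks the async recursion (`All` calls `run` again).
                    Box::pin(example.run()).await?;
                }
            }
            Self::Foo => println!("running the (placeholder) foo example"),
            Self::Bar => println!("running the (placeholder) bar example"),
        }
        Ok(())
    }
}

#[tokio::main]
async fn main() -> Result<(), String> {
    // `VARIANTS` comes from the `VariantNames` derive and drives the usage string.
    let usage = format!(
        "Usage: cargo run --example <group> -- [{}]",
        ExampleKind::VARIANTS.join("|")
    );
    let example: ExampleKind = std::env::args()
        .nth(1)
        .ok_or_else(|| format!("Missing argument. {usage}"))?
        .parse()
        .map_err(|_| format!("Unknown example. {usage}"))?;
    example.run().await
}
```

Deriving `EnumString`, `Display`, and `VariantNames` keeps the CLI argument parsing, the error messages, and the usage string in sync with the enum, which is the point of dropping the manual implementations.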
+ +use std::{fs::File, io::Write}; + use datafusion::common::{assert_batches_eq, assert_contains}; use datafusion::error::Result; use datafusion::prelude::*; +use tempfile::tempdir; /// This example demonstrates how to use the regexp_* functions /// @@ -30,12 +35,30 @@ use datafusion::prelude::*; /// https://docs.rs/regex/latest/regex/#grouping-and-flags pub async fn regexp() -> Result<()> { let ctx = SessionContext::new(); - ctx.register_csv( - "examples", - "datafusion/physical-expr/tests/data/regex.csv", - CsvReadOptions::new(), - ) - .await?; + // content from file 'datafusion/physical-expr/tests/data/regex.csv' + let csv_data = r#"values,patterns,replacement,flags +abc,^(a),bb\1bb,i +ABC,^(A).*,B,i +aBc,(b|d),e,i +AbC,(B|D),e, +aBC,^(b|c),d, +4000,\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b,xyz, +4010,\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b,xyz, +Düsseldorf,[\p{Letter}-]+,München, +Москва,[\p{L}-]+,Moscow, +Köln,[a-zA-Z]ö[a-zA-Z]{2},Koln, +اليوم,^\p{Arabic}+$,Today,"#; + let dir = tempdir()?; + let file_path = dir.path().join("regex.csv"); + { + let mut file = File::create(&file_path)?; + // write CSV data + file.write_all(csv_data.as_bytes())?; + } // scope closes the file + let file_path = file_path.to_str().unwrap(); + + ctx.register_csv("examples", file_path, CsvReadOptions::new()) + .await?; // // @@ -111,11 +134,11 @@ pub async fn regexp() -> Result<()> { assert_batches_eq!( &[ - "+---------------------------------------------------+----------------------------------------------------+", - "| regexp_like(Utf8(\"John Smith\"),Utf8(\"^.*Smith$\")) | regexp_like(Utf8(\"Smith Jones\"),Utf8(\"^Smith.*$\")) |", - "+---------------------------------------------------+----------------------------------------------------+", - "| true | true |", - "+---------------------------------------------------+----------------------------------------------------+", + "+---------------------------------------------------+----------------------------------------------------+", + "| regexp_like(Utf8(\"John Smith\"),Utf8(\"^.*Smith$\")) | regexp_like(Utf8(\"Smith Jones\"),Utf8(\"^Smith.*$\")) |", + "+---------------------------------------------------+----------------------------------------------------+", + "| true | true |", + "+---------------------------------------------------+----------------------------------------------------+", ], &result ); @@ -241,11 +264,11 @@ pub async fn regexp() -> Result<()> { assert_batches_eq!( &[ - "+----------------------------------------------------+-----------------------------------------------------+", - "| regexp_match(Utf8(\"John Smith\"),Utf8(\"^.*Smith$\")) | regexp_match(Utf8(\"Smith Jones\"),Utf8(\"^Smith.*$\")) |", - "+----------------------------------------------------+-----------------------------------------------------+", - "| [John Smith] | [Smith Jones] |", - "+----------------------------------------------------+-----------------------------------------------------+", + "+----------------------------------------------------+-----------------------------------------------------+", + "| regexp_match(Utf8(\"John Smith\"),Utf8(\"^.*Smith$\")) | regexp_match(Utf8(\"Smith Jones\"),Utf8(\"^Smith.*$\")) |", + "+----------------------------------------------------+-----------------------------------------------------+", + "| [John Smith] | [Smith Jones] |", + "+----------------------------------------------------+-----------------------------------------------------+", ], &result ); @@ -267,21 +290,21 @@ pub async fn regexp() -> Result<()> { assert_batches_eq!( &[ 
- "+---------------------------------------------------------------------------------------------------------+", - "| regexp_replace(examples.values,examples.patterns,examples.replacement,concat(Utf8(\"g\"),examples.flags)) |", - "+---------------------------------------------------------------------------------------------------------+", - "| bbabbbc |", - "| B |", - "| aec |", - "| AbC |", - "| aBC |", - "| 4000 |", - "| xyz |", - "| München |", - "| Moscow |", - "| Koln |", - "| Today |", - "+---------------------------------------------------------------------------------------------------------+", + "+---------------------------------------------------------------------------------------------------------+", + "| regexp_replace(examples.values,examples.patterns,examples.replacement,concat(Utf8(\"g\"),examples.flags)) |", + "+---------------------------------------------------------------------------------------------------------+", + "| bbabbbc |", + "| B |", + "| aec |", + "| AbC |", + "| aBC |", + "| 4000 |", + "| xyz |", + "| München |", + "| Moscow |", + "| Koln |", + "| Today |", + "+---------------------------------------------------------------------------------------------------------+", ], &result ); @@ -295,11 +318,11 @@ pub async fn regexp() -> Result<()> { assert_batches_eq!( &[ - "+------------------------------------------------------------------------+", - "| regexp_replace(Utf8(\"foobarbaz\"),Utf8(\"b(..)\"),Utf8(\"X\\1Y\"),Utf8(\"g\")) |", - "+------------------------------------------------------------------------+", - "| fooXarYXazY |", - "+------------------------------------------------------------------------+", + "+------------------------------------------------------------------------+", + "| regexp_replace(Utf8(\"foobarbaz\"),Utf8(\"b(..)\"),Utf8(\"X\\1Y\"),Utf8(\"g\")) |", + "+------------------------------------------------------------------------+", + "| fooXarYXazY |", + "+------------------------------------------------------------------------+", ], &result ); diff --git a/datafusion-examples/examples/csv_json_opener.rs b/datafusion-examples/examples/custom_data_source/csv_json_opener.rs similarity index 80% rename from datafusion-examples/examples/csv_json_opener.rs rename to datafusion-examples/examples/custom_data_source/csv_json_opener.rs index ef2a3eaca0c88..7b2e321362632 100644 --- a/datafusion-examples/examples/csv_json_opener.rs +++ b/datafusion-examples/examples/custom_data_source/csv_json_opener.rs @@ -15,9 +15,12 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::common::config::CsvOptions; use datafusion::{ assert_batches_eq, datasource::{ @@ -31,18 +34,15 @@ use datafusion::{ test_util::aggr_test_schema, }; -use datafusion::datasource::{ - physical_plan::FileScanConfigBuilder, table_schema::TableSchema, -}; +use datafusion::datasource::physical_plan::FileScanConfigBuilder; use futures::StreamExt; -use object_store::{local::LocalFileSystem, memory::InMemory, ObjectStore}; +use object_store::{ObjectStore, local::LocalFileSystem, memory::InMemory}; /// This example demonstrates using the low level [`FileStream`] / [`FileOpener`] APIs to directly /// read data from (CSV/JSON) into Arrow RecordBatches. 
/// /// If you want to query data in CSV or JSON files, see the [`dataframe.rs`] and [`sql_query.rs`] examples -#[tokio::main] -async fn main() -> Result<()> { +pub async fn csv_json_opener() -> Result<()> { csv_opener().await?; json_opener().await?; Ok(()) @@ -57,23 +57,29 @@ async fn csv_opener() -> Result<()> { let path = std::path::Path::new(&path).canonicalize()?; - let scan_config = FileScanConfigBuilder::new( - ObjectStoreUrl::local_filesystem(), - Arc::clone(&schema), - Arc::new(CsvSource::default()), - ) - .with_projection_indices(Some(vec![12, 0])) - .with_limit(Some(5)) - .with_file(PartitionedFile::new(path.display().to_string(), 10)) - .build(); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; - let config = CsvSource::new(true, b',', b'"') + let source = CsvSource::new(Arc::clone(&schema)) + .with_csv_options(options) .with_comment(Some(b'#')) - .with_schema(TableSchema::from_file_schema(schema)) - .with_batch_size(8192) - .with_projection(&scan_config); + .with_batch_size(8192); + + let scan_config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_projection_indices(Some(vec![12, 0]))? + .with_limit(Some(5)) + .with_file(PartitionedFile::new(path.display().to_string(), 10)) + .build(); - let opener = config.create_file_opener(object_store, &scan_config, 0); + let opener = + scan_config + .file_source() + .create_file_opener(object_store, &scan_config, 0)?; let mut result = vec![]; let mut stream = @@ -125,10 +131,9 @@ async fn json_opener() -> Result<()> { let scan_config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), - schema, - Arc::new(JsonSource::default()), + Arc::new(JsonSource::new(schema)), ) - .with_projection_indices(Some(vec![1, 0])) + .with_projection_indices(Some(vec![1, 0]))? .with_limit(Some(5)) .with_file(PartitionedFile::new(path.to_string(), 10)) .build(); diff --git a/datafusion-examples/examples/csv_sql_streaming.rs b/datafusion-examples/examples/custom_data_source/csv_sql_streaming.rs similarity index 96% rename from datafusion-examples/examples/csv_sql_streaming.rs rename to datafusion-examples/examples/custom_data_source/csv_sql_streaming.rs index 99264bbcb486d..554382ea9549e 100644 --- a/datafusion-examples/examples/csv_sql_streaming.rs +++ b/datafusion-examples/examples/custom_data_source/csv_sql_streaming.rs @@ -15,14 +15,15 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
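The `regexp` example above now inlines its CSV fixture and writes it to a temporary directory instead of reading a file checked into the repository. A standalone sketch of that pattern, assuming only the `datafusion`, `tempfile`, and `tokio` crates (the table name and CSV contents are illustrative):

```rust
use std::{fs::File, io::Write};

use datafusion::error::Result;
use datafusion::prelude::*;
use tempfile::tempdir;

#[tokio::main]
async fn main() -> Result<()> {
    // Inline fixture data instead of depending on a file inside the repository.
    let csv_data = "values,patterns\nabc,^(a)\nABC,^(A).*\n";

    let dir = tempdir()?;
    let file_path = dir.path().join("regex.csv");
    {
        // The scope closes (and flushes) the file before DataFusion reads it.
        let mut file = File::create(&file_path)?;
        file.write_all(csv_data.as_bytes())?;
    }

    let ctx = SessionContext::new();
    ctx.register_csv("examples", file_path.to_str().unwrap(), CsvReadOptions::new())
        .await?;
    ctx.sql("SELECT * FROM examples").await?.show().await?;
    Ok(())
}
```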
+ use datafusion::common::test_util::datafusion_test_data; use datafusion::error::Result; use datafusion::prelude::*; /// This example demonstrates executing a simple query against an Arrow data source (CSV) and /// fetching results with streaming aggregation and streaming window -#[tokio::main] -async fn main() -> Result<()> { +pub async fn csv_sql_streaming() -> Result<()> { // create local execution context let ctx = SessionContext::new(); diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_data_source/custom_datasource.rs similarity index 95% rename from datafusion-examples/examples/custom_datasource.rs rename to datafusion-examples/examples/custom_data_source/custom_datasource.rs index bc865fac5a338..b276ae32cf247 100644 --- a/datafusion-examples/examples/custom_datasource.rs +++ b/datafusion-examples/examples/custom_data_source/custom_datasource.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use std::any::Any; use std::collections::{BTreeMap, HashMap}; use std::fmt::{self, Debug, Formatter}; @@ -22,10 +24,10 @@ use std::sync::{Arc, Mutex}; use std::time::Duration; use async_trait::async_trait; -use datafusion::arrow::array::{UInt64Builder, UInt8Builder}; +use datafusion::arrow::array::{UInt8Builder, UInt64Builder}; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::datasource::{provider_as_source, TableProvider, TableType}; +use datafusion::datasource::{TableProvider, TableType, provider_as_source}; use datafusion::error::Result; use datafusion::execution::context::TaskContext; use datafusion::logical_expr::LogicalPlanBuilder; @@ -33,8 +35,8 @@ use datafusion::physical_expr::EquivalenceProperties; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::memory::MemoryStream; use datafusion::physical_plan::{ - project_schema, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, - PlanProperties, SendableRecordBatchStream, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, project_schema, }; use datafusion::prelude::*; @@ -42,8 +44,7 @@ use datafusion::catalog::Session; use tokio::time::timeout; /// This example demonstrates executing a simple query against a custom datasource -#[tokio::main] -async fn main() -> Result<()> { +pub async fn custom_datasource() -> Result<()> { // create our custom datasource and adding some users let db = CustomDataSource::default(); db.populate_users(); @@ -195,6 +196,7 @@ struct CustomExec { } impl CustomExec { + #[expect(clippy::needless_pass_by_value)] fn new( projections: Option<&Vec>, schema: SchemaRef, diff --git a/datafusion-examples/examples/custom_file_casts.rs b/datafusion-examples/examples/custom_data_source/custom_file_casts.rs similarity index 89% rename from datafusion-examples/examples/custom_file_casts.rs rename to datafusion-examples/examples/custom_data_source/custom_file_casts.rs index 4d97ecd91dc64..895b6f52b6e1e 100644 --- a/datafusion-examples/examples/custom_file_casts.rs +++ b/datafusion-examples/examples/custom_data_source/custom_file_casts.rs @@ -15,23 +15,25 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
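`custom_datasource` is converted above into a `pub async fn` that implements a `TableProvider` by hand. For contrast, a minimal sketch using the built-in `MemTable` provider (assuming `datafusion` and `tokio`) shows the same register-and-query flow without a custom implementation:

```rust
use std::sync::Arc;

use datafusion::arrow::array::{Int32Array, RecordBatch};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    // Build a single in-memory partition with one column of user ids.
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;
    let table = MemTable::try_new(schema, vec![vec![batch]])?;

    // Registration and querying look the same for a hand-written TableProvider.
    let ctx = SessionContext::new();
    ctx.register_table("users", Arc::new(table))?;
    ctx.sql("SELECT count(*) FROM users").await?.show().await?;
    Ok(())
}
```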
+ use std::sync::Arc; -use arrow::array::{record_batch, RecordBatch}; -use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; +use arrow::array::{RecordBatch, record_batch}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::assert_batches_eq; +use datafusion::common::Result; use datafusion::common::not_impl_err; use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion::common::{Result, ScalarValue}; use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, }; use datafusion::execution::context::SessionContext; use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::parquet::arrow::ArrowWriter; -use datafusion::physical_expr::expressions::CastExpr; use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_expr::expressions::{CastColumnExpr, CastExpr}; use datafusion::prelude::SessionConfig; use datafusion_physical_expr_adapter::{ DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, @@ -44,9 +46,7 @@ use object_store::{ObjectStore, PutPayload}; // This example enforces that casts must be strictly widening: if the file type is Int64 and the table type is Int32, it will error // before even reading the data. // Without this custom cast rule DataFusion would happily do the narrowing cast, potentially erroring only if it found a row with data it could not cast. - -#[tokio::main] -async fn main() -> Result<()> { +pub async fn custom_file_casts() -> Result<()> { println!("=== Creating example data ==="); // Create a logical / table schema with an Int32 column @@ -192,18 +192,21 @@ impl PhysicalExprAdapter for CustomCastsPhysicalExprAdapter { ); } } + if let Some(cast) = expr.as_any().downcast_ref::() { + let input_data_type = + cast.expr().data_type(&self.physical_file_schema)?; + let output_data_type = cast.data_type(&self.physical_file_schema)?; + if !CastExpr::check_bigger_cast( + cast.target_field().data_type(), + &input_data_type, + ) { + return not_impl_err!( + "Unsupported CAST from {input_data_type} to {output_data_type}" + ); + } + } Ok(Transformed::no(expr)) }) .data() } - - fn with_partition_values( - &self, - partition_values: Vec<(FieldRef, ScalarValue)>, - ) -> Arc { - Arc::new(Self { - inner: self.inner.with_partition_values(partition_values), - ..self.clone() - }) - } } diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_data_source/custom_file_format.rs similarity index 93% rename from datafusion-examples/examples/custom_file_format.rs rename to datafusion-examples/examples/custom_data_source/custom_file_format.rs index 67fe642fd46ee..6817beec41188 100644 --- a/datafusion-examples/examples/custom_file_format.rs +++ b/datafusion-examples/examples/custom_data_source/custom_file_format.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
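`custom_file_casts` above rejects narrowing file-to-table casts via `CastExpr::check_bigger_cast`. The sketch below illustrates the same strictly-widening rule with a hand-rolled check over a few `DataType` pairs; it is only an illustration of the idea, not DataFusion's implementation, which covers far more types:

```rust
use datafusion::arrow::datatypes::DataType;

/// Illustrative only: a tiny "is this cast strictly widening?" check.
fn is_widening(from: &DataType, to: &DataType) -> bool {
    use DataType::*;
    from == to
        || matches!(
            (from, to),
            (Int8, Int16 | Int32 | Int64)
                | (Int16, Int32 | Int64)
                | (Int32, Int64)
                | (Float32, Float64)
                | (Utf8, LargeUtf8)
        )
}

fn main() {
    // Int32 -> Int64 widens, so it would be allowed.
    assert!(is_widening(&DataType::Int32, &DataType::Int64));
    // Int64 -> Int32 narrows, so the adapter in this diff would reject it.
    assert!(!is_widening(&DataType::Int64, &DataType::Int32));
    println!("widening checks passed");
}
```

Failing early here, before any data is read, is what distinguishes this from relying on a runtime cast error when a narrowing value is eventually encountered.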
+ use std::{any::Any, sync::Arc}; use arrow::{ @@ -25,12 +27,13 @@ use datafusion::{ catalog::Session, common::{GetExt, Statistics}, datasource::{ + MemTable, file_format::{ - csv::CsvFormatFactory, file_compression_type::FileCompressionType, - FileFormat, FileFormatFactory, + FileFormat, FileFormatFactory, csv::CsvFormatFactory, + file_compression_type::FileCompressionType, }, physical_plan::{FileScanConfig, FileSinkConfig, FileSource}, - MemTable, + table_schema::TableSchema, }, error::Result, execution::session_state::SessionStateBuilder, @@ -47,6 +50,42 @@ use tempfile::tempdir; /// TSVFileFormatFactory is responsible for creating instances of TSVFileFormat. /// The former, once registered with the SessionState, will then be used /// to facilitate SQL operations on TSV files, such as `COPY TO` shown here. +pub async fn custom_file_format() -> Result<()> { + // Create a new context with the default configuration + let mut state = SessionStateBuilder::new().with_default_features().build(); + + // Register the custom file format + let file_format = Arc::new(TSVFileFactory::new()); + state.register_file_format(file_format, true)?; + + // Create a new context with the custom file format + let ctx = SessionContext::new_with_state(state); + + let mem_table = create_mem_table(); + ctx.register_table("mem_table", mem_table)?; + + let temp_dir = tempdir().unwrap(); + let table_save_path = temp_dir.path().join("mem_table.tsv"); + + let d = ctx + .sql(&format!( + "COPY mem_table TO '{}' STORED AS TSV;", + table_save_path.display(), + )) + .await?; + + let results = d.collect().await?; + println!( + "Number of inserted rows: {:?}", + (results[0] + .column_by_name("count") + .unwrap() + .as_primitive::() + .value(0)) + ); + + Ok(()) +} #[derive(Debug)] /// Custom file format that reads and writes TSV files @@ -128,8 +167,8 @@ impl FileFormat for TSVFileFormat { .await } - fn file_source(&self) -> Arc { - self.csv_file_format.file_source() + fn file_source(&self, table_schema: TableSchema) -> Arc { + self.csv_file_format.file_source(table_schema) } } @@ -180,44 +219,6 @@ impl GetExt for TSVFileFactory { } } -#[tokio::main] -async fn main() -> Result<()> { - // Create a new context with the default configuration - let mut state = SessionStateBuilder::new().with_default_features().build(); - - // Register the custom file format - let file_format = Arc::new(TSVFileFactory::new()); - state.register_file_format(file_format, true).unwrap(); - - // Create a new context with the custom file format - let ctx = SessionContext::new_with_state(state); - - let mem_table = create_mem_table(); - ctx.register_table("mem_table", mem_table).unwrap(); - - let temp_dir = tempdir().unwrap(); - let table_save_path = temp_dir.path().join("mem_table.tsv"); - - let d = ctx - .sql(&format!( - "COPY mem_table TO '{}' STORED AS TSV;", - table_save_path.display(), - )) - .await?; - - let results = d.collect().await?; - println!( - "Number of inserted rows: {:?}", - (results[0] - .column_by_name("count") - .unwrap() - .as_primitive::() - .value(0)) - ); - - Ok(()) -} - // create a simple mem table fn create_mem_table() -> Arc { let fields = vec![ diff --git a/datafusion-examples/examples/default_column_values.rs b/datafusion-examples/examples/custom_data_source/default_column_values.rs similarity index 63% rename from datafusion-examples/examples/default_column_values.rs rename to datafusion-examples/examples/custom_data_source/default_column_values.rs index d3a7d2ec67f3c..81d74cfbecabd 100644 --- 
a/datafusion-examples/examples/default_column_values.rs +++ b/datafusion-examples/examples/custom_data_source/default_column_values.rs @@ -15,18 +15,19 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use std::any::Any; use std::collections::HashMap; use std::sync::Arc; use arrow::array::RecordBatch; -use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion::assert_batches_eq; use datafusion::catalog::memory::DataSourceExec; use datafusion::catalog::{Session, TableProvider}; -use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion::common::DFSchema; use datafusion::common::{Result, ScalarValue}; use datafusion::datasource::listing::PartitionedFile; @@ -37,12 +38,12 @@ use datafusion::logical_expr::utils::conjunction; use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType}; use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; -use datafusion::physical_expr::expressions::{CastExpr, Column, Literal}; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::{lit, SessionConfig}; +use datafusion::prelude::{SessionConfig, lit}; use datafusion_physical_expr_adapter::{ DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, + replace_columns_with_literals, }; use futures::StreamExt; use object_store::memory::InMemory; @@ -52,25 +53,22 @@ use object_store::{ObjectStore, PutPayload}; // Metadata key for storing default values in field metadata const DEFAULT_VALUE_METADATA_KEY: &str = "example.default_value"; -// Example showing how to implement custom default value handling for missing columns -// using field metadata and PhysicalExprAdapter. -// -// This example demonstrates how to: -// 1. Store default values in field metadata using a constant key -// 2. Create a custom PhysicalExprAdapter that reads these defaults -// 3. Inject default values for missing columns in filter predicates -// 4. Use the DefaultPhysicalExprAdapter as a fallback for standard schema adaptation -// 5. Wrap string default values in cast expressions for proper type conversion -// -// Important: PhysicalExprAdapter is specifically designed for rewriting filter predicates -// that get pushed down to file scans. For handling missing columns in projections, -// other mechanisms in DataFusion are used (like SchemaAdapter). -// -// The metadata-based approach provides a flexible way to store default values as strings -// and cast them to the appropriate types at query time. - -#[tokio::main] -async fn main() -> Result<()> { +/// Example showing how to implement custom default value handling for missing columns +/// using field metadata and PhysicalExprAdapter. +/// +/// This example demonstrates how to: +/// 1. Store default values in field metadata using a constant key +/// 2. Create a custom PhysicalExprAdapter that reads these defaults +/// 3. Inject default values for missing columns in filter predicates using `replace_columns_with_literals` +/// 4. Use the DefaultPhysicalExprAdapter as a fallback for standard schema adaptation +/// 5. 
Convert string default values to proper types using `ScalarValue::cast_to()` at planning time +/// +/// Important: PhysicalExprAdapter handles rewriting both filter predicates and projection +/// expressions for file scans, including handling missing columns. +/// +/// The metadata-based approach provides a flexible way to store default values as strings +/// and cast them to the appropriate types at planning time, avoiding runtime overhead. +pub async fn default_column_values() -> Result<()> { println!("=== Creating example data with missing columns and default values ==="); // Create sample data where the logical schema has more columns than the physical schema @@ -85,11 +83,10 @@ async fn main() -> Result<()> { .build(); let mut writer = - ArrowWriter::try_new(&mut buf, physical_schema.clone(), Some(props)) - .expect("creating writer"); + ArrowWriter::try_new(&mut buf, physical_schema.clone(), Some(props))?; - writer.write(&batch).expect("Writing batch"); - writer.close().unwrap(); + writer.write(&batch)?; + writer.close()?; buf }; let path = Path::from("example.parquet"); @@ -138,12 +135,14 @@ async fn main() -> Result<()> { println!("\n=== Key Insight ==="); println!("This example demonstrates how PhysicalExprAdapter works:"); println!("1. Physical schema only has 'id' and 'name' columns"); - println!("2. Logical schema has 'id', 'name', 'status', and 'priority' columns with defaults"); - println!("3. Our custom adapter intercepts filter expressions on missing columns"); - println!("4. Default values from metadata are injected as cast expressions"); + println!( + "2. Logical schema has 'id', 'name', 'status', and 'priority' columns with defaults" + ); + println!( + "3. Our custom adapter uses replace_columns_with_literals to inject default values" + ); + println!("4. Default values from metadata are cast to proper types at planning time"); println!("5. The DefaultPhysicalExprAdapter handles other schema adaptations"); - println!("\nNote: PhysicalExprAdapter is specifically for filter predicates."); - println!("For projection columns, different mechanisms handle missing columns."); Ok(()) } @@ -207,7 +206,7 @@ impl TableProvider for DefaultValueTableProvider { } fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } fn table_type(&self) -> TableType { @@ -228,14 +227,14 @@ impl TableProvider for DefaultValueTableProvider { filters: &[Expr], limit: Option, ) -> Result> { - let schema = self.schema.clone(); + let schema = Arc::clone(&self.schema); let df_schema = DFSchema::try_from(schema.clone())?; let filter = state.create_physical_expr( conjunction(filters.iter().cloned()).unwrap_or_else(|| lit(true)), &df_schema, )?; - let parquet_source = ParquetSource::default() + let parquet_source = ParquetSource::new(schema.clone()) .with_predicate(filter) .with_pushdown_filters(true); @@ -257,10 +256,9 @@ impl TableProvider for DefaultValueTableProvider { let file_scan_config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("memory://")?, - self.schema.clone(), Arc::new(parquet_source), ) - .with_projection_indices(projection.cloned()) + .with_projection_indices(projection.cloned())? 
.with_limit(limit) .with_file_group(file_group) .with_expr_adapter(Some(Arc::new(DefaultValuePhysicalExprAdapterFactory) as _)); @@ -282,14 +280,15 @@ impl PhysicalExprAdapterFactory for DefaultValuePhysicalExprAdapterFactory { physical_file_schema: SchemaRef, ) -> Arc { let default_factory = DefaultPhysicalExprAdapterFactory; - let default_adapter = default_factory - .create(logical_file_schema.clone(), physical_file_schema.clone()); + let default_adapter = default_factory.create( + Arc::clone(&logical_file_schema), + Arc::clone(&physical_file_schema), + ); Arc::new(DefaultValuePhysicalExprAdapter { logical_file_schema, physical_file_schema, default_adapter, - partition_values: Vec::new(), }) } } @@ -301,98 +300,36 @@ struct DefaultValuePhysicalExprAdapter { logical_file_schema: SchemaRef, physical_file_schema: SchemaRef, default_adapter: Arc, - partition_values: Vec<(FieldRef, ScalarValue)>, } impl PhysicalExprAdapter for DefaultValuePhysicalExprAdapter { fn rewrite(&self, expr: Arc) -> Result> { - // First try our custom default value injection for missing columns - let rewritten = expr - .transform(|expr| { - self.inject_default_values( - expr, - &self.logical_file_schema, - &self.physical_file_schema, - ) - }) - .data()?; - - // Then apply the default adapter as a fallback to handle standard schema differences - // like type casting, partition column handling, etc. - let default_adapter = if !self.partition_values.is_empty() { - self.default_adapter - .with_partition_values(self.partition_values.clone()) - } else { - self.default_adapter.clone() - }; - - default_adapter.rewrite(rewritten) - } - - fn with_partition_values( - &self, - partition_values: Vec<(FieldRef, ScalarValue)>, - ) -> Arc { - Arc::new(DefaultValuePhysicalExprAdapter { - logical_file_schema: self.logical_file_schema.clone(), - physical_file_schema: self.physical_file_schema.clone(), - default_adapter: self.default_adapter.clone(), - partition_values, - }) - } -} - -impl DefaultValuePhysicalExprAdapter { - fn inject_default_values( - &self, - expr: Arc, - logical_file_schema: &Schema, - physical_file_schema: &Schema, - ) -> Result>> { - if let Some(column) = expr.as_any().downcast_ref::() { - let column_name = column.name(); - - // Check if this column exists in the physical schema - if physical_file_schema.index_of(column_name).is_err() { - // Column is missing from physical schema, check if logical schema has a default - if let Ok(logical_field) = - logical_file_schema.field_with_name(column_name) - { - if let Some(default_value_str) = - logical_field.metadata().get(DEFAULT_VALUE_METADATA_KEY) - { - // Create a string literal and wrap it in a cast expression - let default_literal = self.create_default_value_expr( - default_value_str, - logical_field.data_type(), - )?; - return Ok(Transformed::yes(default_literal)); - } - } + // Pre-compute replacements for missing columns with default values + let mut replacements = HashMap::new(); + for field in self.logical_file_schema.fields() { + // Skip columns that exist in physical schema + if self.physical_file_schema.index_of(field.name()).is_ok() { + continue; } - } - - // No transformation needed - Ok(Transformed::no(expr)) - } - fn create_default_value_expr( - &self, - value_str: &str, - data_type: &DataType, - ) -> Result> { - // Create a string literal with the default value - let string_literal = - Arc::new(Literal::new(ScalarValue::Utf8(Some(value_str.to_string())))); - - // If the target type is already Utf8, return the string literal directly - if 
matches!(data_type, DataType::Utf8) { - return Ok(string_literal); + // Check if this missing column has a default value in metadata + if let Some(default_str) = field.metadata().get(DEFAULT_VALUE_METADATA_KEY) { + // Create a Utf8 ScalarValue from the string and cast it to the target type + let string_value = ScalarValue::Utf8(Some(default_str.to_string())); + let typed_value = string_value.cast_to(field.data_type())?; + replacements.insert(field.name().as_str(), typed_value); + } } - // Otherwise, wrap the string literal in a cast expression - let cast_expr = Arc::new(CastExpr::new(string_literal, data_type.clone(), None)); + // Replace columns with their default literals if any + let rewritten = if !replacements.is_empty() { + let refs: HashMap<_, _> = replacements.iter().map(|(k, v)| (*k, v)).collect(); + replace_columns_with_literals(expr, &refs)? + } else { + expr + }; - Ok(cast_expr) + // Apply the default adapter as a fallback for other schema adaptations + self.default_adapter.rewrite(rewritten) } } diff --git a/datafusion-examples/examples/file_stream_provider.rs b/datafusion-examples/examples/custom_data_source/file_stream_provider.rs similarity index 90% rename from datafusion-examples/examples/file_stream_provider.rs rename to datafusion-examples/examples/custom_data_source/file_stream_provider.rs index e6c59d57e98de..936da0a33d47b 100644 --- a/datafusion-examples/examples/file_stream_provider.rs +++ b/datafusion-examples/examples/custom_data_source/file_stream_provider.rs @@ -15,6 +15,31 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + +/// Demonstrates how to use [`FileStreamProvider`] and [`StreamTable`] to stream data +/// from a file-like source (FIFO) into DataFusion for continuous querying. +/// +/// On non-Windows systems, this example creates a named pipe (FIFO) and +/// writes rows into it asynchronously while DataFusion reads the data +/// through a `FileStreamProvider`. +/// +/// This illustrates how to integrate dynamically updated data sources +/// with DataFusion without needing to reload the entire dataset each time. +/// +/// This example does not work on Windows. 
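The rewritten `default_column_values` adapter above turns a metadata default from a string into a typed literal once at planning time with `ScalarValue::cast_to`, instead of wrapping a `CastExpr` that would be evaluated per batch. A small sketch of that conversion, assuming the `datafusion` crate; the `typed_default` helper is illustrative:

```rust
use datafusion::arrow::datatypes::DataType;
use datafusion::common::{Result, ScalarValue};

/// Illustrative helper: parse a string default into a literal of the column's type.
fn typed_default(default_str: &str, target: &DataType) -> Result<ScalarValue> {
    // One cast at planning time, instead of a per-batch CastExpr at execution time.
    ScalarValue::Utf8(Some(default_str.to_string())).cast_to(target)
}

fn main() -> Result<()> {
    let value = typed_default("42", &DataType::Int32)?;
    assert_eq!(value, ScalarValue::Int32(Some(42)));
    println!("default resolved to {value}");
    Ok(())
}
```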
+pub async fn file_stream_provider() -> datafusion::error::Result<()> { + #[cfg(target_os = "windows")] + { + println!("file_stream_provider example does not work on windows"); + Ok(()) + } + #[cfg(not(target_os = "windows"))] + { + non_windows::main().await + } +} + #[cfg(not(target_os = "windows"))] mod non_windows { use datafusion::assert_batches_eq; @@ -22,8 +47,8 @@ mod non_windows { use std::fs::{File, OpenOptions}; use std::io::Write; use std::path::PathBuf; - use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; use std::thread; use std::time::Duration; @@ -34,9 +59,9 @@ mod non_windows { use tempfile::TempDir; use tokio::task::JoinSet; - use datafusion::common::{exec_err, Result}; - use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; + use datafusion::common::{Result, exec_err}; use datafusion::datasource::TableProvider; + use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::logical_expr::SortExpr; use datafusion::prelude::{SessionConfig, SessionContext}; @@ -186,16 +211,3 @@ mod non_windows { Ok(()) } } - -#[tokio::main] -async fn main() -> datafusion::error::Result<()> { - #[cfg(target_os = "windows")] - { - println!("file_stream_provider example does not work on windows"); - Ok(()) - } - #[cfg(not(target_os = "windows"))] - { - non_windows::main().await - } -} diff --git a/datafusion-examples/examples/custom_data_source/main.rs b/datafusion-examples/examples/custom_data_source/main.rs new file mode 100644 index 0000000000000..b5dcf10f5cdaa --- /dev/null +++ b/datafusion-examples/examples/custom_data_source/main.rs @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # These examples are all related to extending or defining how DataFusion reads data +//! +//! These examples demonstrate how DataFusion reads data. +//! +//! ## Usage +//! ```bash +//! cargo run --example custom_data_source -- [all|csv_json_opener|csv_sql_streaming|custom_datasource|custom_file_casts|custom_file_format|default_column_values|file_stream_provider] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! - `csv_json_opener` — use low level FileOpener APIs to read CSV/JSON into Arrow RecordBatches +//! - `csv_sql_streaming` — build and run a streaming query plan from a SQL statement against a local CSV file +//! - `custom_datasource` — run queries against a custom datasource (TableProvider) +//! - `custom_file_casts` — implement custom casting rules to adapt file schemas +//! - `custom_file_format` — write data to a custom file format +//! 
- `default_column_values` — implement custom default value handling for missing columns using field metadata and PhysicalExprAdapter +//! - `file_stream_provider` — run a query on FileStreamProvider which implements StreamProvider for reading and writing to arbitrary stream sources/sinks + +mod csv_json_opener; +mod csv_sql_streaming; +mod custom_datasource; +mod custom_file_casts; +mod custom_file_format; +mod default_column_values; +mod file_stream_provider; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + CsvJsonOpener, + CsvSqlStreaming, + CustomDatasource, + CustomFileCasts, + CustomFileFormat, + DefaultColumnValues, + FileStreamProvider, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "custom_data_source"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::CsvJsonOpener => csv_json_opener::csv_json_opener().await?, + ExampleKind::CsvSqlStreaming => { + csv_sql_streaming::csv_sql_streaming().await? + } + ExampleKind::CustomDatasource => { + custom_datasource::custom_datasource().await? + } + ExampleKind::CustomFileCasts => { + custom_file_casts::custom_file_casts().await? + } + ExampleKind::CustomFileFormat => { + custom_file_format::custom_file_format().await? + } + ExampleKind::DefaultColumnValues => { + default_column_values::default_column_values().await? + } + ExampleKind::FileStreamProvider => { + file_stream_provider::file_stream_provider().await? + } + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .ok_or_else(|| DataFusionError::Execution(format!("Missing argument. {usage}")))? + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/catalog.rs b/datafusion-examples/examples/data_io/catalog.rs similarity index 98% rename from datafusion-examples/examples/catalog.rs rename to datafusion-examples/examples/data_io/catalog.rs index 229867cdfc5bb..d2ddff82e32db 100644 --- a/datafusion-examples/examples/catalog.rs +++ b/datafusion-examples/examples/data_io/catalog.rs @@ -15,15 +15,17 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! //! Simple example of a catalog/schema implementation. 
use async_trait::async_trait; use datafusion::{ arrow::util::pretty, catalog::{CatalogProvider, CatalogProviderList, SchemaProvider}, datasource::{ - file_format::{csv::CsvFormat, FileFormat}, - listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, TableProvider, + file_format::{FileFormat, csv::CsvFormat}, + listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, }, error::Result, execution::context::SessionState, @@ -34,8 +36,8 @@ use std::{any::Any, collections::HashMap, path::Path, sync::Arc}; use std::{fs::File, io::Write}; use tempfile::TempDir; -#[tokio::main] -async fn main() -> Result<()> { +/// Register the table into a custom catalog +pub async fn catalog() -> Result<()> { env_logger::builder() .filter_level(log::LevelFilter::Info) .init(); @@ -134,12 +136,14 @@ struct DirSchemaOpts<'a> { dir: &'a Path, format: Arc, } + /// Schema where every file with extension `ext` in a given `dir` is a table. #[derive(Debug)] struct DirSchema { ext: String, tables: RwLock>>, } + impl DirSchema { async fn create(state: &SessionState, opts: DirSchemaOpts<'_>) -> Result> { let DirSchemaOpts { ext, dir, format } = opts; @@ -172,6 +176,7 @@ impl DirSchema { ext: ext.to_string(), })) } + #[allow(unused)] fn name(&self) -> &str { &self.ext @@ -198,6 +203,7 @@ impl SchemaProvider for DirSchema { let tables = self.tables.read().unwrap(); tables.contains_key(name) } + fn register_table( &self, name: String, @@ -223,6 +229,7 @@ impl SchemaProvider for DirSchema { struct DirCatalog { schemas: RwLock>>, } + impl DirCatalog { fn new() -> Self { Self { @@ -230,10 +237,12 @@ impl DirCatalog { } } } + impl CatalogProvider for DirCatalog { fn as_any(&self) -> &dyn Any { self } + fn register_schema( &self, name: &str, @@ -260,11 +269,13 @@ impl CatalogProvider for DirCatalog { } } } + /// Catalog lists holds multiple catalog providers. Each context has a single catalog list. #[derive(Debug)] struct CustomCatalogProviderList { catalogs: RwLock>>, } + impl CustomCatalogProviderList { fn new() -> Self { Self { @@ -272,10 +283,12 @@ impl CustomCatalogProviderList { } } } + impl CatalogProviderList for CustomCatalogProviderList { fn as_any(&self) -> &dyn Any { self } + fn register_catalog( &self, name: String, diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/data_io/json_shredding.rs similarity index 76% rename from datafusion-examples/examples/json_shredding.rs rename to datafusion-examples/examples/data_io/json_shredding.rs index 5ef8b59b64200..d2ffacc9464c2 100644 --- a/datafusion-examples/examples/json_shredding.rs +++ b/datafusion-examples/examples/data_io/json_shredding.rs @@ -15,17 +15,19 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+ use std::any::Any; use std::sync::Arc; use arrow::array::{RecordBatch, StringArray}; -use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::assert_batches_eq; use datafusion::common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; -use datafusion::common::{assert_contains, exec_datafusion_err, Result}; +use datafusion::common::{Result, assert_contains, exec_datafusion_err}; use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, }; @@ -37,7 +39,7 @@ use datafusion::logical_expr::{ use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_expr::PhysicalExpr; -use datafusion::physical_expr::{expressions, ScalarFunctionExpr}; +use datafusion::physical_expr::{ScalarFunctionExpr, expressions}; use datafusion::prelude::SessionConfig; use datafusion::scalar::ScalarValue; use datafusion_physical_expr_adapter::{ @@ -63,8 +65,7 @@ use object_store::{ObjectStore, PutPayload}; // 1. Push down predicates for better filtering // 2. Avoid expensive JSON parsing at query time // 3. Leverage columnar storage benefits for the materialized fields -#[tokio::main] -async fn main() -> Result<()> { +pub async fn json_shredding() -> Result<()> { println!("=== Creating example data with flat columns and underscore prefixes ==="); // Create sample data with flat columns using underscore prefixes @@ -232,7 +233,7 @@ impl ScalarUDFImpl for JsonGetStr { _ => { return Err(exec_datafusion_err!( "json_get_str first argument must be a string" - )) + )); } }; // We expect a string array that contains JSON strings @@ -248,7 +249,7 @@ impl ScalarUDFImpl for JsonGetStr { _ => { return Err(exec_datafusion_err!( "json_get_str second argument must be a string array" - )) + )); } }; let values = json_array @@ -276,14 +277,14 @@ impl PhysicalExprAdapterFactory for ShreddedJsonRewriterFactory { physical_file_schema: SchemaRef, ) -> Arc { let default_factory = DefaultPhysicalExprAdapterFactory; - let default_adapter = default_factory - .create(logical_file_schema.clone(), physical_file_schema.clone()); + let default_adapter = default_factory.create( + Arc::clone(&logical_file_schema), + Arc::clone(&physical_file_schema), + ); Arc::new(ShreddedJsonRewriter { - logical_file_schema, physical_file_schema, default_adapter, - partition_values: Vec::new(), }) } } @@ -292,10 +293,8 @@ impl PhysicalExprAdapterFactory for ShreddedJsonRewriterFactory { /// and wraps DefaultPhysicalExprAdapter for standard schema adaptation #[derive(Debug)] struct ShreddedJsonRewriter { - logical_file_schema: SchemaRef, physical_file_schema: SchemaRef, default_adapter: Arc, - partition_values: Vec<(FieldRef, ScalarValue)>, } impl PhysicalExprAdapter for ShreddedJsonRewriter { @@ -306,27 +305,8 @@ impl PhysicalExprAdapter for ShreddedJsonRewriter { .data()?; // Then apply the default adapter as a fallback to handle standard schema differences - // like type casting, missing columns, and partition column handling - let default_adapter = if !self.partition_values.is_empty() { - self.default_adapter - .with_partition_values(self.partition_values.clone()) - } else { - self.default_adapter.clone() - }; - - default_adapter.rewrite(rewritten) - } - - fn with_partition_values( - &self, - partition_values: Vec<(FieldRef, ScalarValue)>, - ) -> Arc { - Arc::new(ShreddedJsonRewriter { - logical_file_schema: 
self.logical_file_schema.clone(), - physical_file_schema: self.physical_file_schema.clone(), - default_adapter: self.default_adapter.clone(), - partition_values, - }) + // like type casting and missing columns + self.default_adapter.rewrite(rewritten) } } @@ -336,44 +316,43 @@ impl ShreddedJsonRewriter { expr: Arc, physical_file_schema: &Schema, ) -> Result>> { - if let Some(func) = expr.as_any().downcast_ref::() { - if func.name() == "json_get_str" && func.args().len() == 2 { - // Get the key from the first argument - if let Some(literal) = func.args()[0] + if let Some(func) = expr.as_any().downcast_ref::() + && func.name() == "json_get_str" + && func.args().len() == 2 + { + // Get the key from the first argument + if let Some(literal) = func.args()[0] + .as_any() + .downcast_ref::() + && let ScalarValue::Utf8(Some(field_name)) = literal.value() + { + // Get the column from the second argument + if let Some(column) = func.args()[1] .as_any() - .downcast_ref::() + .downcast_ref::() { - if let ScalarValue::Utf8(Some(field_name)) = literal.value() { - // Get the column from the second argument - if let Some(column) = func.args()[1] - .as_any() - .downcast_ref::() - { - let column_name = column.name(); - // Check if there's a flat column with underscore prefix - let flat_column_name = format!("_{column_name}.{field_name}"); - - if let Ok(flat_field_index) = - physical_file_schema.index_of(&flat_column_name) - { - let flat_field = - physical_file_schema.field(flat_field_index); - - if flat_field.data_type() == &DataType::Utf8 { - // Replace the whole expression with a direct column reference - let new_expr = Arc::new(expressions::Column::new( - &flat_column_name, - flat_field_index, - )) - as Arc; - - return Ok(Transformed { - data: new_expr, - tnr: TreeNodeRecursion::Stop, - transformed: true, - }); - } - } + let column_name = column.name(); + // Check if there's a flat column with underscore prefix + let flat_column_name = format!("_{column_name}.{field_name}"); + + if let Ok(flat_field_index) = + physical_file_schema.index_of(&flat_column_name) + { + let flat_field = physical_file_schema.field(flat_field_index); + + if flat_field.data_type() == &DataType::Utf8 { + // Replace the whole expression with a direct column reference + let new_expr = Arc::new(expressions::Column::new( + &flat_column_name, + flat_field_index, + )) + as Arc; + + return Ok(Transformed { + data: new_expr, + tnr: TreeNodeRecursion::Stop, + transformed: true, + }); } } } diff --git a/datafusion-examples/examples/data_io/main.rs b/datafusion-examples/examples/data_io/main.rs new file mode 100644 index 0000000000000..496a588d4087a --- /dev/null +++ b/datafusion-examples/examples/data_io/main.rs @@ -0,0 +1,124 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
# Examples of data formats and I/O +//! +//! These examples demonstrate working with different data formats and I/O sources. +//! +//! ## Usage +//! ```bash +//! cargo run --example data_io -- [all|catalog|json_shredding|parquet_adv_idx|parquet_emb_idx|parquet_enc_with_kms|parquet_enc|parquet_exec_visitor|parquet_idx|query_http_csv|remote_catalog] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! - `catalog` — register the table into a custom catalog +//! - `json_shredding` — implement custom filter rewriting for JSON shredding +//! - `parquet_adv_idx` — create a detailed secondary index that covers the contents of several parquet files +//! - `parquet_emb_idx` — store a custom index inside a Parquet file and use it to speed up queries +//! - `parquet_enc_with_kms` — read and write encrypted Parquet files using an encryption factory +//! - `parquet_enc` — read and write encrypted Parquet files using DataFusion +//! - `parquet_exec_visitor` — extract statistics by visiting an ExecutionPlan after execution +//! - `parquet_idx` — create a secondary index over several parquet files and use it to speed up queries +//! - `query_http_csv` — configure `object_store` and run a query against files via HTTP +//! - `remote_catalog` — interface with a remote catalog (e.g. over a network) + +mod catalog; +mod json_shredding; +mod parquet_advanced_index; +mod parquet_embedded_index; +mod parquet_encrypted; +mod parquet_encrypted_with_kms; +mod parquet_exec_visitor; +mod parquet_index; +mod query_http_csv; +mod remote_catalog; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + Catalog, + JsonShredding, + ParquetAdvIdx, + ParquetEmbIdx, + ParquetEnc, + ParquetEncWithKms, + ParquetExecVisitor, + ParquetIdx, + QueryHttpCsv, + RemoteCatalog, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "data_io"; + + fn runnable() -> impl Iterator<Item = Self> { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::Catalog => catalog::catalog().await?, + ExampleKind::JsonShredding => json_shredding::json_shredding().await?, + ExampleKind::ParquetAdvIdx => { + parquet_advanced_index::parquet_advanced_index().await? + } + ExampleKind::ParquetEmbIdx => { + parquet_embedded_index::parquet_embedded_index().await? + } + ExampleKind::ParquetEncWithKms => { + parquet_encrypted_with_kms::parquet_encrypted_with_kms().await? + } + ExampleKind::ParquetEnc => parquet_encrypted::parquet_encrypted().await?, + ExampleKind::ParquetExecVisitor => { + parquet_exec_visitor::parquet_exec_visitor().await?
+ } + ExampleKind::ParquetIdx => parquet_index::parquet_index().await?, + ExampleKind::QueryHttpCsv => query_http_csv::query_http_csv().await?, + ExampleKind::RemoteCatalog => remote_catalog::remote_catalog().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .ok_or_else(|| DataFusionError::Execution(format!("Missing argument. {usage}")))? + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/advanced_parquet_index.rs b/datafusion-examples/examples/data_io/parquet_advanced_index.rs similarity index 98% rename from datafusion-examples/examples/advanced_parquet_index.rs rename to datafusion-examples/examples/data_io/parquet_advanced_index.rs index 371c18de354ce..3f4ebe7a92055 100644 --- a/datafusion-examples/examples/advanced_parquet_index.rs +++ b/datafusion-examples/examples/data_io/parquet_advanced_index.rs @@ -15,40 +15,42 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use std::any::Any; use std::collections::{HashMap, HashSet}; use std::fs::File; use std::ops::Range; use std::path::{Path, PathBuf}; -use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use datafusion::catalog::Session; use datafusion::common::{ - internal_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, + DFSchema, DataFusionError, Result, ScalarValue, internal_datafusion_err, }; +use datafusion::datasource::TableProvider; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::parquet::ParquetAccessPlan; use datafusion::datasource::physical_plan::{ FileScanConfigBuilder, ParquetFileReaderFactory, ParquetSource, }; -use datafusion::datasource::TableProvider; use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::logical_expr::utils::conjunction; use datafusion::logical_expr::{TableProviderFilterPushDown, TableType}; +use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::arrow::arrow_reader::{ ArrowReaderOptions, ParquetRecordBatchReaderBuilder, RowSelection, RowSelector, }; use datafusion::parquet::arrow::async_reader::{AsyncFileReader, ParquetObjectReader}; -use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::metadata::ParquetMetaData; use datafusion::parquet::file::properties::{EnabledStatistics, WriterProperties}; use datafusion::parquet::schema::types::ColumnPath; -use datafusion::physical_expr::utils::{Guarantee, LiteralGuarantee}; use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_expr::utils::{Guarantee, LiteralGuarantee}; use datafusion::physical_optimizer::pruning::PruningPredicate; -use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::prelude::*; use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; @@ -56,8 +58,8 @@ use arrow::datatypes::SchemaRef; use async_trait::async_trait; use bytes::Bytes; use datafusion::datasource::memory::DataSourceExec; -use futures::future::BoxFuture; use futures::FutureExt; +use futures::future::BoxFuture; use object_store::ObjectStore; 
use tempfile::TempDir; use url::Url; @@ -155,8 +157,7 @@ use url::Url; /// /// [`ListingTable`]: datafusion::datasource::listing::ListingTable /// [Page Index](https://github.com/apache/parquet-format/blob/master/PageIndex.md) -#[tokio::main] -async fn main() -> Result<()> { +pub async fn parquet_advanced_index() -> Result<()> { // the object store is used to read the parquet files (in this case, it is // a local file system, but in a real system it could be S3, GCS, etc) let object_store: Arc = @@ -239,6 +240,7 @@ pub struct IndexTableProvider { /// if true, use row selections in addition to row group selections use_row_selections: AtomicBool, } + impl IndexTableProvider { /// Create a new IndexTableProvider /// * `object_store` - the object store implementation to use for reading files @@ -491,19 +493,18 @@ impl TableProvider for IndexTableProvider { .with_file(indexed_file); let file_source = Arc::new( - ParquetSource::default() + ParquetSource::new(schema.clone()) // provide the predicate so the DataSourceExec can try and prune // row groups internally .with_predicate(predicate) // provide the factory to create parquet reader without re-reading metadata .with_parquet_file_reader_factory(Arc::new(reader_factory)), ); - let file_scan_config = - FileScanConfigBuilder::new(object_store_url, schema, file_source) - .with_limit(limit) - .with_projection_indices(projection.cloned()) - .with_file(partitioned_file) - .build(); + let file_scan_config = FileScanConfigBuilder::new(object_store_url, file_source) + .with_limit(limit) + .with_projection_indices(projection.cloned())? + .with_file(partitioned_file) + .build(); // Finally, put it all together into a DataSourceExec Ok(DataSourceExec::from_data_source(file_scan_config)) @@ -540,6 +541,7 @@ impl CachedParquetFileReaderFactory { metadata: HashMap::new(), } } + /// Add the pre-parsed information about the file to the factor fn with_file(mut self, indexed_file: &IndexedFile) -> Self { self.metadata.insert( diff --git a/datafusion-examples/examples/parquet_embedded_index.rs b/datafusion-examples/examples/data_io/parquet_embedded_index.rs similarity index 95% rename from datafusion-examples/examples/parquet_embedded_index.rs rename to datafusion-examples/examples/data_io/parquet_embedded_index.rs index 3cbe189147752..bcaca2ed5c85b 100644 --- a/datafusion-examples/examples/parquet_embedded_index.rs +++ b/datafusion-examples/examples/data_io/parquet_embedded_index.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! //! Embedding and using a custom index in Parquet files //! //! 
# Background @@ -116,11 +118,11 @@ use arrow::record_batch::RecordBatch; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion::catalog::{Session, TableProvider}; -use datafusion::common::{exec_err, HashMap, HashSet, Result}; +use datafusion::common::{HashMap, HashSet, Result, exec_err}; +use datafusion::datasource::TableType; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::memory::DataSourceExec; use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; -use datafusion::datasource::TableType; use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::logical_expr::{Operator, TableProviderFilterPushDown}; use datafusion::parquet::arrow::ArrowWriter; @@ -130,12 +132,37 @@ use datafusion::parquet::file::reader::{FileReader, SerializedFileReader}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::*; use datafusion::scalar::ScalarValue; -use std::fs::{read_dir, File}; +use std::fs::{File, read_dir}; use std::io::{Read, Seek, SeekFrom, Write}; use std::path::{Path, PathBuf}; use std::sync::Arc; use tempfile::TempDir; +/// Store a custom index inside a Parquet file and use it to speed up queries +pub async fn parquet_embedded_index() -> Result<()> { + // 1. Create temp dir and write 3 Parquet files with different category sets + let tmp = TempDir::new()?; + let dir = tmp.path(); + write_file_with_index(&dir.join("a.parquet"), &["foo", "bar", "foo"])?; + write_file_with_index(&dir.join("b.parquet"), &["baz", "qux"])?; + write_file_with_index(&dir.join("c.parquet"), &["foo", "quux", "quux"])?; + + // 2. Register our custom TableProvider + let field = Field::new("category", DataType::Utf8, false); + let schema_ref = Arc::new(Schema::new(vec![field])); + let provider = Arc::new(DistinctIndexTable::try_new(dir, schema_ref.clone())?); + + let ctx = SessionContext::new(); + ctx.register_table("t", provider)?; + + // 3. Run a query: only files containing 'foo' get scanned. The rest are pruned. + // based on the distinct index. + let df = ctx.sql("SELECT * FROM t WHERE category = 'foo'").await?; + df.show().await?; + + Ok(()) +} + /// An index of distinct values for a single column /// /// In this example the index is a simple set of strings, but in a real @@ -392,21 +419,15 @@ impl TableProvider for DistinctIndexTable { // equality analysis or write your own custom logic. 
let mut target: Option<&str> = None; - if filters.len() == 1 { - if let Expr::BinaryExpr(expr) = &filters[0] { - if expr.op == Operator::Eq { - if let ( - Expr::Column(c), - Expr::Literal(ScalarValue::Utf8(Some(v)), _), - ) = (&*expr.left, &*expr.right) - { - if c.name == "category" { - println!("Filtering for category: {v}"); - target = Some(v); - } - } - } - } + if filters.len() == 1 + && let Expr::BinaryExpr(expr) = &filters[0] + && expr.op == Operator::Eq + && let (Expr::Column(c), Expr::Literal(ScalarValue::Utf8(Some(v)), _)) = + (&*expr.left, &*expr.right) + && c.name == "category" + { + println!("Filtering for category: {v}"); + target = Some(v); } // Determine which files to scan let files_to_scan: Vec<_> = self @@ -426,8 +447,10 @@ impl TableProvider for DistinctIndexTable { // Build ParquetSource to actually read the files let url = ObjectStoreUrl::parse("file://")?; - let source = Arc::new(ParquetSource::default().with_enable_page_index(true)); - let mut builder = FileScanConfigBuilder::new(url, self.schema.clone(), source); + let source = Arc::new( + ParquetSource::new(self.schema.clone()).with_enable_page_index(true), + ); + let mut builder = FileScanConfigBuilder::new(url, source); for file in files_to_scan { let path = self.dir.join(file); let len = std::fs::metadata(&path)?.len(); @@ -450,28 +473,3 @@ impl TableProvider for DistinctIndexTable { Ok(vec![TableProviderFilterPushDown::Inexact; fs.len()]) } } - -#[tokio::main] -async fn main() -> Result<()> { - // 1. Create temp dir and write 3 Parquet files with different category sets - let tmp = TempDir::new()?; - let dir = tmp.path(); - write_file_with_index(&dir.join("a.parquet"), &["foo", "bar", "foo"])?; - write_file_with_index(&dir.join("b.parquet"), &["baz", "qux"])?; - write_file_with_index(&dir.join("c.parquet"), &["foo", "quux", "quux"])?; - - // 2. Register our custom TableProvider - let field = Field::new("category", DataType::Utf8, false); - let schema_ref = Arc::new(Schema::new(vec![field])); - let provider = Arc::new(DistinctIndexTable::try_new(dir, schema_ref.clone())?); - - let ctx = SessionContext::new(); - ctx.register_table("t", provider)?; - - // 3. Run a query: only files containing 'foo' get scanned. The rest are pruned. - // based on the distinct index. - let df = ctx.sql("SELECT * FROM t WHERE category = 'foo'").await?; - df.show().await?; - - Ok(()) -} diff --git a/datafusion-examples/examples/parquet_encrypted.rs b/datafusion-examples/examples/data_io/parquet_encrypted.rs similarity index 94% rename from datafusion-examples/examples/parquet_encrypted.rs rename to datafusion-examples/examples/data_io/parquet_encrypted.rs index 690d9f2a5f140..f88ab91321e91 100644 --- a/datafusion-examples/examples/parquet_encrypted.rs +++ b/datafusion-examples/examples/data_io/parquet_encrypted.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+ use datafusion::common::DataFusionError; use datafusion::config::{ConfigFileEncryptionProperties, TableParquetOptions}; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; @@ -25,8 +27,8 @@ use datafusion::prelude::{ParquetReadOptions, SessionContext}; use std::sync::Arc; use tempfile::TempDir; -#[tokio::main] -async fn main() -> datafusion::common::Result<()> { +/// Read and write encrypted Parquet files using DataFusion +pub async fn parquet_encrypted() -> datafusion::common::Result<()> { // The SessionContext is the main high level API for interacting with DataFusion let ctx = SessionContext::new(); @@ -73,7 +75,9 @@ async fn main() -> datafusion::common::Result<()> { let encrypted_parquet_df = ctx.read_parquet(tempfile_str, read_options).await?; // Show information from the dataframe - println!("\n\n==============================================================================="); + println!( + "\n\n===============================================================================" + ); println!("Encrypted Parquet DataFrame:"); query_dataframe(&encrypted_parquet_df).await?; diff --git a/datafusion-examples/examples/parquet_encrypted_with_kms.rs b/datafusion-examples/examples/data_io/parquet_encrypted_with_kms.rs similarity index 99% rename from datafusion-examples/examples/parquet_encrypted_with_kms.rs rename to datafusion-examples/examples/data_io/parquet_encrypted_with_kms.rs index 45bfd183773a0..1a9bf56c09b35 100644 --- a/datafusion-examples/examples/parquet_encrypted_with_kms.rs +++ b/datafusion-examples/examples/data_io/parquet_encrypted_with_kms.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use arrow_schema::SchemaRef; use async_trait::async_trait; @@ -53,8 +55,7 @@ const ENCRYPTION_FACTORY_ID: &str = "example.mock_kms_encryption"; /// which is not a secure way to store encryption keys. /// For production use, it is recommended to use a key-management service (KMS) to encrypt /// data encryption keys. -#[tokio::main] -async fn main() -> Result<()> { +pub async fn parquet_encrypted_with_kms() -> Result<()> { let ctx = SessionContext::new(); // Register an `EncryptionFactory` implementation to be used for Parquet encryption diff --git a/datafusion-examples/examples/parquet_exec_visitor.rs b/datafusion-examples/examples/data_io/parquet_exec_visitor.rs similarity index 83% rename from datafusion-examples/examples/parquet_exec_visitor.rs rename to datafusion-examples/examples/data_io/parquet_exec_visitor.rs index 84f92d4f450e1..d38fe9e171205 100644 --- a/datafusion-examples/examples/parquet_exec_visitor.rs +++ b/datafusion-examples/examples/data_io/parquet_exec_visitor.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+ use std::sync::Arc; use datafusion::datasource::file_format::parquet::ParquetFormat; @@ -25,13 +27,12 @@ use datafusion::error::DataFusionError; use datafusion::execution::context::SessionContext; use datafusion::physical_plan::metrics::MetricValue; use datafusion::physical_plan::{ - execute_stream, visit_execution_plan, ExecutionPlan, ExecutionPlanVisitor, + ExecutionPlan, ExecutionPlanVisitor, execute_stream, visit_execution_plan, }; use futures::StreamExt; /// Example of collecting metrics after execution by visiting the `ExecutionPlan` -#[tokio::main] -async fn main() { +pub async fn parquet_exec_visitor() -> datafusion::common::Result<()> { let ctx = SessionContext::new(); let test_data = datafusion::test_util::parquet_test_data(); @@ -51,8 +52,8 @@ async fn main() { ) .await; - let df = ctx.sql("SELECT * FROM my_table").await.unwrap(); - let plan = df.create_physical_plan().await.unwrap(); + let df = ctx.sql("SELECT * FROM my_table").await?; + let plan = df.create_physical_plan().await?; // Create empty visitor let mut visitor = ParquetExecVisitor { @@ -63,12 +64,12 @@ async fn main() { // Make sure you execute the plan to collect actual execution statistics. // For example, in this example the `file_scan_config` is known without executing // but the `bytes_scanned` would be None if we did not execute. - let mut batch_stream = execute_stream(plan.clone(), ctx.task_ctx()).unwrap(); + let mut batch_stream = execute_stream(plan.clone(), ctx.task_ctx())?; while let Some(batch) = batch_stream.next().await { println!("Batch rows: {}", batch.unwrap().num_rows()); } - visit_execution_plan(plan.as_ref(), &mut visitor).unwrap(); + visit_execution_plan(plan.as_ref(), &mut visitor)?; println!( "ParquetExecVisitor bytes_scanned: {:?}", @@ -78,6 +79,8 @@ async fn main() { "ParquetExecVisitor file_groups: {:?}", visitor.file_groups.unwrap() ); + + Ok(()) } /// Define a struct with fields to hold the execution information you want to @@ -97,18 +100,17 @@ impl ExecutionPlanVisitor for ParquetExecVisitor { /// or `post_visit` (visit each node after its children/inputs) fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result { // If needed match on a specific `ExecutionPlan` node type - if let Some(data_source_exec) = plan.as_any().downcast_ref::() { - if let Some((file_config, _)) = + if let Some(data_source_exec) = plan.as_any().downcast_ref::() + && let Some((file_config, _)) = data_source_exec.downcast_to_file_source::() - { - self.file_groups = Some(file_config.file_groups.clone()); - - let metrics = match data_source_exec.metrics() { - None => return Ok(true), - Some(metrics) => metrics, - }; - self.bytes_scanned = metrics.sum_by_name("bytes_scanned"); - } + { + self.file_groups = Some(file_config.file_groups.clone()); + + let metrics = match data_source_exec.metrics() { + None => return Ok(true), + Some(metrics) => metrics, + }; + self.bytes_scanned = metrics.sum_by_name("bytes_scanned"); } Ok(true) } diff --git a/datafusion-examples/examples/parquet_index.rs b/datafusion-examples/examples/data_io/parquet_index.rs similarity index 97% rename from datafusion-examples/examples/parquet_index.rs rename to datafusion-examples/examples/data_io/parquet_index.rs index a1dd1f1ffd10d..e11a303f442a4 100644 --- a/datafusion-examples/examples/parquet_index.rs +++ b/datafusion-examples/examples/data_io/parquet_index.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+ use arrow::array::{ Array, ArrayRef, AsArray, BooleanArray, Int32Array, RecordBatch, StringArray, UInt64Array, @@ -25,19 +27,19 @@ use async_trait::async_trait; use datafusion::catalog::Session; use datafusion::common::pruning::PruningStatistics; use datafusion::common::{ - internal_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, + DFSchema, DataFusionError, Result, ScalarValue, internal_datafusion_err, }; +use datafusion::datasource::TableProvider; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::memory::DataSourceExec; use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; -use datafusion::datasource::TableProvider; use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::logical_expr::{ - utils::conjunction, TableProviderFilterPushDown, TableType, + TableProviderFilterPushDown, TableType, utils::conjunction, }; use datafusion::parquet::arrow::arrow_reader::statistics::StatisticsConverter; use datafusion::parquet::arrow::{ - arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter, + ArrowWriter, arrow_reader::ParquetRecordBatchReaderBuilder, }; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_optimizer::pruning::PruningPredicate; @@ -50,8 +52,8 @@ use std::fs; use std::fs::{DirEntry, File}; use std::ops::Range; use std::path::{Path, PathBuf}; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use tempfile::TempDir; use url::Url; @@ -102,8 +104,7 @@ use url::Url; /// ``` /// /// [`ListingTable`]: datafusion::datasource::listing::ListingTable -#[tokio::main] -async fn main() -> Result<()> { +pub async fn parquet_index() -> Result<()> { // Demo data has three files, each with schema // * file_name (string) // * value (int32) @@ -242,10 +243,11 @@ impl TableProvider for IndexTableProvider { let files = self.index.get_files(predicate.clone())?; let object_store_url = ObjectStoreUrl::parse("file://")?; - let source = Arc::new(ParquetSource::default().with_predicate(predicate)); + let source = + Arc::new(ParquetSource::new(self.schema()).with_predicate(predicate)); let mut file_scan_config_builder = - FileScanConfigBuilder::new(object_store_url, self.schema(), source) - .with_projection_indices(projection.cloned()) + FileScanConfigBuilder::new(object_store_url, source) + .with_projection_indices(projection.cloned())? .with_limit(limit); // Transform to the format needed to pass to DataSourceExec @@ -509,7 +511,7 @@ impl ParquetMetadataIndexBuilder { // Get the schema of the file. A real system might have to handle the // case where the schema of the file is not the same as the schema of - // the other files e.g. using SchemaAdapter. + // the other files e.g. using PhysicalExprAdapterFactory. if self.file_schema.is_none() { self.file_schema = Some(reader.schema().clone()); } diff --git a/datafusion-examples/examples/query-http-csv.rs b/datafusion-examples/examples/data_io/query_http_csv.rs similarity index 91% rename from datafusion-examples/examples/query-http-csv.rs rename to datafusion-examples/examples/data_io/query_http_csv.rs index fa3fd2ac068df..71421e6270ccb 100644 --- a/datafusion-examples/examples/query-http-csv.rs +++ b/datafusion-examples/examples/data_io/query_http_csv.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+ use datafusion::error::Result; use datafusion::prelude::*; use object_store::http::HttpBuilder; use std::sync::Arc; use url::Url; -/// This example demonstrates executing a simple query against an Arrow data source (CSV) and -/// fetching results -#[tokio::main] -async fn main() -> Result<()> { +/// Configure `object_store` and run a query against files via HTTP +pub async fn query_http_csv() -> Result<()> { // create local execution context let ctx = SessionContext::new(); diff --git a/datafusion-examples/examples/remote_catalog.rs b/datafusion-examples/examples/data_io/remote_catalog.rs similarity index 98% rename from datafusion-examples/examples/remote_catalog.rs rename to datafusion-examples/examples/data_io/remote_catalog.rs index 74575554ec0af..10ec26b1d5c05 100644 --- a/datafusion-examples/examples/remote_catalog.rs +++ b/datafusion-examples/examples/data_io/remote_catalog.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! /// This example shows how to implement the DataFusion [`CatalogProvider`] API /// for catalogs that are remote (require network access) and/or offer only /// asynchronous APIs such as [Polaris], [Unity], and [Hive]. @@ -39,15 +41,15 @@ use datafusion::common::{assert_batches_eq, internal_datafusion_err, plan_err}; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::execution::SendableRecordBatchStream; use datafusion::logical_expr::{Expr, TableType}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::prelude::{DataFrame, SessionContext}; use futures::TryStreamExt; use std::any::Any; use std::sync::Arc; -#[tokio::main] -async fn main() -> Result<()> { +/// Interfacing with a remote catalog (e.g. over a network) +pub async fn remote_catalog() -> Result<()> { // As always, we create a session context to interact with DataFusion let ctx = SessionContext::new(); diff --git a/datafusion-examples/examples/dataframe/cache_factory.rs b/datafusion-examples/examples/dataframe/cache_factory.rs new file mode 100644 index 0000000000000..a6c465720c626 --- /dev/null +++ b/datafusion-examples/examples/dataframe/cache_factory.rs @@ -0,0 +1,233 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. 
+ +use std::fmt::Debug; +use std::hash::Hash; +use std::sync::Arc; +use std::sync::RwLock; + +use arrow::array::RecordBatch; +use async_trait::async_trait; +use datafusion::catalog::memory::MemorySourceConfig; +use datafusion::common::DFSchemaRef; +use datafusion::error::Result; +use datafusion::execution::SessionState; +use datafusion::execution::SessionStateBuilder; +use datafusion::execution::context::QueryPlanner; +use datafusion::execution::session_state::CacheFactory; +use datafusion::logical_expr::Extension; +use datafusion::logical_expr::LogicalPlan; +use datafusion::logical_expr::UserDefinedLogicalNode; +use datafusion::logical_expr::UserDefinedLogicalNodeCore; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::collect_partitioned; +use datafusion::physical_planner::DefaultPhysicalPlanner; +use datafusion::physical_planner::ExtensionPlanner; +use datafusion::physical_planner::PhysicalPlanner; +use datafusion::prelude::ParquetReadOptions; +use datafusion::prelude::SessionContext; +use datafusion::prelude::*; +use datafusion_common::HashMap; + +/// This example demonstrates how to leverage [CacheFactory] to implement custom caching strategies for dataframes in DataFusion. +/// By default, [DataFrame::cache] in Datafusion is eager and creates an in-memory table. This example shows a basic alternative implementation for lazy caching. +/// Specifically, it implements: +/// - A [CustomCacheFactory] that creates a logical node [CacheNode] representing the cache operation. +/// - A [CacheNodePlanner] (an [ExtensionPlanner]) that understands [CacheNode] and performs caching. +/// - A [CacheNodeQueryPlanner] that installs [CacheNodePlanner]. +/// - A simple in-memory [CacheManager] that stores cached [RecordBatch]es. Note that the implementation for this example is very naive and only implements put, but for real production use cases cache eviction and drop should also be implemented. +pub async fn cache_dataframe_with_custom_logic() -> Result<()> { + let testdata = datafusion::test_util::parquet_test_data(); + let filename = &format!("{testdata}/alltypes_plain.parquet"); + + let session_state = SessionStateBuilder::new() + .with_cache_factory(Some(Arc::new(CustomCacheFactory {}))) + .with_query_planner(Arc::new(CacheNodeQueryPlanner::default())) + .build(); + let ctx = SessionContext::new_with_state(session_state); + + // Read the parquet files and show its schema using 'describe' + let parquet_df = ctx + .read_parquet(filename, ParquetReadOptions::default()) + .await?; + + let df_cached = parquet_df + .select_columns(&["id", "bool_col", "timestamp_col"])? + .filter(col("id").gt(lit(1)))? 
+ .cache() + .await?; + + let df1 = df_cached.clone().filter(col("bool_col").is_true())?; + let df2 = df1.clone().sort(vec![col("id").sort(true, false)])?; + + // should see log for caching only once + df_cached.show().await?; + df1.show().await?; + df2.show().await?; + + Ok(()) +} + +#[derive(Debug)] +struct CustomCacheFactory {} + +impl CacheFactory for CustomCacheFactory { + fn create( + &self, + plan: LogicalPlan, + _session_state: &SessionState, + ) -> Result { + Ok(LogicalPlan::Extension(Extension { + node: Arc::new(CacheNode { input: plan }), + })) + } +} + +#[derive(PartialEq, Eq, PartialOrd, Hash, Debug)] +struct CacheNode { + input: LogicalPlan, +} + +impl UserDefinedLogicalNodeCore for CacheNode { + fn name(&self) -> &str { + "CacheNode" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "CacheNode") + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + mut inputs: Vec, + ) -> Result { + assert_eq!(inputs.len(), 1, "input size must be one"); + Ok(Self { + input: inputs.swap_remove(0), + }) + } +} + +struct CacheNodePlanner { + cache_manager: Arc>, +} + +#[async_trait] +impl ExtensionPlanner for CacheNodePlanner { + async fn plan_extension( + &self, + _planner: &dyn PhysicalPlanner, + node: &dyn UserDefinedLogicalNode, + logical_inputs: &[&LogicalPlan], + physical_inputs: &[Arc], + session_state: &SessionState, + ) -> Result>> { + if let Some(cache_node) = node.as_any().downcast_ref::() { + assert_eq!(logical_inputs.len(), 1, "Inconsistent number of inputs"); + assert_eq!(physical_inputs.len(), 1, "Inconsistent number of inputs"); + if self + .cache_manager + .read() + .unwrap() + .get(&cache_node.input) + .is_none() + { + let ctx = session_state.task_ctx(); + println!("caching in memory"); + let batches = + collect_partitioned(physical_inputs[0].clone(), ctx).await?; + self.cache_manager + .write() + .unwrap() + .put(cache_node.input.clone(), batches); + } else { + println!("fetching directly from cache manager"); + } + Ok(self + .cache_manager + .read() + .unwrap() + .get(&cache_node.input) + .map(|batches| { + let exec: Arc = MemorySourceConfig::try_new_exec( + batches, + physical_inputs[0].schema(), + None, + ) + .unwrap(); + exec + })) + } else { + Ok(None) + } + } +} + +#[derive(Debug, Default)] +struct CacheNodeQueryPlanner { + cache_manager: Arc>, +} + +#[async_trait] +impl QueryPlanner for CacheNodeQueryPlanner { + async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + session_state: &SessionState, + ) -> Result> { + let physical_planner = + DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new( + CacheNodePlanner { + cache_manager: Arc::clone(&self.cache_manager), + }, + )]); + physical_planner + .create_physical_plan(logical_plan, session_state) + .await + } +} + +// This naive implementation only includes put, but for real production use cases cache eviction and drop should also be implemented. 
+#[derive(Debug, Default)] +struct CacheManager { + cache: HashMap>>, +} + +impl CacheManager { + pub fn put(&mut self, k: LogicalPlan, v: Vec>) { + self.cache.insert(k, v); + } + + pub fn get(&self, k: &LogicalPlan) -> Option<&Vec>> { + self.cache.get(k) + } +} diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe/dataframe.rs similarity index 90% rename from datafusion-examples/examples/dataframe.rs rename to datafusion-examples/examples/dataframe/dataframe.rs index a5ee571a14764..94653e80c8695 100644 --- a/datafusion-examples/examples/dataframe.rs +++ b/datafusion-examples/examples/dataframe/dataframe.rs @@ -15,22 +15,23 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray, StringViewArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::catalog::MemTable; +use datafusion::common::ScalarValue; use datafusion::common::config::CsvOptions; use datafusion::common::parsers::CompressionTypeVariant; -use datafusion::common::DataFusionError; -use datafusion::common::ScalarValue; use datafusion::dataframe::DataFrameWriteOptions; use datafusion::error::Result; use datafusion::functions_aggregate::average::avg; use datafusion::functions_aggregate::min_max::max; use datafusion::prelude::*; -use std::fs::File; +use std::fs::{File, create_dir_all}; use std::io::Write; use std::sync::Arc; -use tempfile::tempdir; +use tempfile::{TempDir, tempdir}; /// This example demonstrates using DataFusion's DataFrame API /// @@ -39,6 +40,7 @@ use tempfile::tempdir; /// * [read_parquet]: execute queries against parquet files /// * [read_csv]: execute queries against csv files /// * [read_memory]: execute queries against in-memory arrow data +/// * [read_memory_macro]: execute queries against in-memory arrow data using macro /// /// # Writing out to local storage /// @@ -53,12 +55,7 @@ use tempfile::tempdir; /// * [where_scalar_subquery]: execute a scalar subquery /// * [where_in_subquery]: execute a subquery with an IN clause /// * [where_exist_subquery]: execute a subquery with an EXISTS clause -/// -/// # Querying data -/// -/// * [query_to_date]: execute queries against parquet files -#[tokio::main] -async fn main() -> Result<()> { +pub async fn dataframe_example() -> Result<()> { env_logger::init(); // The SessionContext is the main high level API for interacting with DataFusion let ctx = SessionContext::new(); @@ -199,7 +196,7 @@ async fn read_memory_macro() -> Result<()> { /// 2. Write out a DataFrame to a parquet file /// 3. Write out a DataFrame to a csv file /// 4. Write out a DataFrame to a json file -async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionError> { +async fn write_out(ctx: &SessionContext) -> Result<()> { let array = StringViewArray::from(vec!["a", "b", "c"]); let schema = Arc::new(Schema::new(vec![Field::new( "tablecol1", @@ -211,15 +208,26 @@ async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionEr ctx.register_table("initial_data", Arc::new(mem_table))?; let df = ctx.table("initial_data").await?; - ctx.sql( - "create external table - test(tablecol1 varchar) - stored as parquet - location './datafusion-examples/test_table/'", - ) - .await? 
- .collect() - .await?; + // Create a single temp root with subdirectories + let tmp_root = TempDir::new()?; + let examples_root = tmp_root.path().join("datafusion-examples"); + create_dir_all(&examples_root)?; + let table_dir = examples_root.join("test_table"); + let parquet_dir = examples_root.join("test_parquet"); + let csv_dir = examples_root.join("test_csv"); + let json_dir = examples_root.join("test_json"); + create_dir_all(&table_dir)?; + create_dir_all(&parquet_dir)?; + create_dir_all(&csv_dir)?; + create_dir_all(&json_dir)?; + + let create_sql = format!( + "CREATE EXTERNAL TABLE test(tablecol1 varchar) + STORED AS parquet + LOCATION '{}'", + table_dir.display() + ); + ctx.sql(&create_sql).await?.collect().await?; // This is equivalent to INSERT INTO test VALUES ('a'), ('b'), ('c'). // The behavior of write_table depends on the TableProvider's implementation @@ -230,7 +238,7 @@ async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionEr df.clone() .write_parquet( - "./datafusion-examples/test_parquet/", + parquet_dir.to_str().unwrap(), DataFrameWriteOptions::new(), None, ) @@ -238,7 +246,7 @@ async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionEr df.clone() .write_csv( - "./datafusion-examples/test_csv/", + csv_dir.to_str().unwrap(), // DataFrameWriteOptions contains options which control how data is written // such as compression codec DataFrameWriteOptions::new(), @@ -248,7 +256,7 @@ async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionEr df.clone() .write_json( - "./datafusion-examples/test_json/", + json_dir.to_str().unwrap(), DataFrameWriteOptions::new(), None, ) diff --git a/datafusion-examples/examples/deserialize_to_struct.rs b/datafusion-examples/examples/dataframe/deserialize_to_struct.rs similarity index 98% rename from datafusion-examples/examples/deserialize_to_struct.rs rename to datafusion-examples/examples/dataframe/deserialize_to_struct.rs index d6655b3b654f9..e19d45554131a 100644 --- a/datafusion-examples/examples/deserialize_to_struct.rs +++ b/datafusion-examples/examples/dataframe/deserialize_to_struct.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::array::{AsArray, PrimitiveArray}; use arrow::datatypes::{Float64Type, Int32Type}; use datafusion::common::assert_batches_eq; @@ -29,8 +31,7 @@ use futures::StreamExt; /// as [ArrayRef] /// /// [ArrayRef]: arrow::array::ArrayRef -#[tokio::main] -async fn main() -> Result<()> { +pub async fn deserialize_to_struct() -> Result<()> { // Run a query that returns two columns of data let ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); diff --git a/datafusion-examples/examples/dataframe/main.rs b/datafusion-examples/examples/dataframe/main.rs new file mode 100644 index 0000000000000..9a2604e97136d --- /dev/null +++ b/datafusion-examples/examples/dataframe/main.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Examples of core DataFrame API usage +//! +//! These examples demonstrate core DataFrame API usage. +//! +//! ## Usage +//! ```bash +//! cargo run --example dataframe -- [all|dataframe|deserialize_to_struct|cache_factory] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! - `dataframe` — run queries using the DataFrame API against parquet files, csv files, and in-memory data, including multiple subqueries +//! - `deserialize_to_struct` — convert query results (Arrow ArrayRefs) into Rust structs +//! - `cache_factory` — use a custom `CacheFactory` and query planner to implement lazy DataFrame caching + +mod cache_factory; +mod dataframe; +mod deserialize_to_struct; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + Dataframe, + DeserializeToStruct, + CacheFactory, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "dataframe"; + + fn runnable() -> impl Iterator<Item = Self> { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::Dataframe => { + dataframe::dataframe_example().await?; + } + ExampleKind::DeserializeToStruct => { + deserialize_to_struct::deserialize_to_struct().await?; + } + ExampleKind::CacheFactory => { + cache_factory::cache_dataframe_with_custom_logic().await?; + } + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .ok_or_else(|| DataFusionError::Execution(format!("Missing argument. {usage}")))? + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/execution_monitoring/main.rs b/datafusion-examples/examples/execution_monitoring/main.rs new file mode 100644 index 0000000000000..3043a80363086 --- /dev/null +++ b/datafusion-examples/examples/execution_monitoring/main.rs @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied.
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # These examples of memory and performance management +//! +//! These examples demonstrate memory and performance management. +//! +//! ## Usage +//! ```bash +//! cargo run --example execution_monitoring -- [all|mem_pool_exec_plan|mem_pool_tracking|tracing] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! - `mem_pool_exec_plan` — shows how to implement memory-aware ExecutionPlan with memory reservation and spilling +//! - `mem_pool_tracking` — demonstrates TrackConsumersPool for memory tracking and debugging with enhanced error messages +//! - `tracing` — demonstrates the tracing injection feature for the DataFusion runtime + +mod memory_pool_execution_plan; +mod memory_pool_tracking; +mod tracing; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + MemPoolExecPlan, + MemPoolTracking, + Tracing, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "execution_monitoring"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::MemPoolExecPlan => { + memory_pool_execution_plan::memory_pool_execution_plan().await? + } + ExampleKind::MemPoolTracking => { + memory_pool_tracking::mem_pool_tracking().await? + } + ExampleKind::Tracing => tracing::tracing().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .ok_or_else(|| DataFusionError::Execution(format!("Missing argument. {usage}")))? + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/memory_pool_execution_plan.rs b/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs similarity index 97% rename from datafusion-examples/examples/memory_pool_execution_plan.rs rename to datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs index 3258cde17625f..48475acbb1542 100644 --- a/datafusion-examples/examples/memory_pool_execution_plan.rs +++ b/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! //! This example demonstrates how to implement custom ExecutionPlans that properly //! use memory tracking through TrackConsumersPool. //! 
@@ -28,7 +30,7 @@ use arrow::record_batch::RecordBatch; use arrow_schema::SchemaRef; use datafusion::common::record_batch; use datafusion::common::{exec_datafusion_err, internal_err}; -use datafusion::datasource::{memory::MemTable, DefaultTableSource}; +use datafusion::datasource::{DefaultTableSource, memory::MemTable}; use datafusion::error::Result; use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion::execution::runtime_env::RuntimeEnvBuilder; @@ -44,8 +46,8 @@ use std::any::Any; use std::fmt; use std::sync::Arc; -#[tokio::main] -async fn main() -> Result<(), Box> { +/// Shows how to implement memory-aware ExecutionPlan with memory reservation and spilling +pub async fn memory_pool_execution_plan() -> Result<()> { println!("=== DataFusion ExecutionPlan Memory Tracking Example ===\n"); // Set up a runtime with memory tracking @@ -140,6 +142,7 @@ impl ExternalBatchBufferer { } } + #[expect(clippy::needless_pass_by_value)] fn add_batch(&mut self, batch_data: Vec) -> Result<()> { let additional_memory = batch_data.len(); diff --git a/datafusion-examples/examples/memory_pool_tracking.rs b/datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs similarity index 95% rename from datafusion-examples/examples/memory_pool_tracking.rs rename to datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs index d5823b1173ab3..8d6e5dd7e444d 100644 --- a/datafusion-examples/examples/memory_pool_tracking.rs +++ b/datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! //! This example demonstrates how to use TrackConsumersPool for memory tracking and debugging. //! //! The TrackConsumersPool provides enhanced error messages that show the top memory consumers @@ -24,11 +26,12 @@ //! //! * [`automatic_usage_example`]: Shows how to use RuntimeEnvBuilder to automatically enable memory tracking +use datafusion::error::Result; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::prelude::*; -#[tokio::main] -async fn main() -> Result<(), Box> { +/// Demonstrates TrackConsumersPool for memory tracking and debugging with enhanced error messages +pub async fn mem_pool_tracking() -> Result<()> { println!("=== DataFusion Memory Pool Tracking Example ===\n"); // Example 1: Automatic Usage with RuntimeEnvBuilder @@ -41,7 +44,7 @@ async fn main() -> Result<(), Box> { /// /// This shows the recommended way to use TrackConsumersPool through RuntimeEnvBuilder, /// which automatically creates a TrackConsumersPool with sensible defaults. -async fn automatic_usage_example() -> datafusion::error::Result<()> { +async fn automatic_usage_example() -> Result<()> { println!("Example 1: Automatic Usage with RuntimeEnvBuilder"); println!("------------------------------------------------"); diff --git a/datafusion-examples/examples/tracing.rs b/datafusion-examples/examples/execution_monitoring/tracing.rs similarity index 92% rename from datafusion-examples/examples/tracing.rs rename to datafusion-examples/examples/execution_monitoring/tracing.rs index 334ee0f4e5686..5fa759f2d541d 100644 --- a/datafusion-examples/examples/tracing.rs +++ b/datafusion-examples/examples/execution_monitoring/tracing.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! //! 
This example demonstrates the tracing injection feature for the DataFusion runtime. //! Tasks spawned on new threads behave differently depending on whether a tracer is injected. //! The log output clearly distinguishes the two cases. @@ -49,20 +51,20 @@ //! 10:29:40.809 INFO main ThreadId(01) tracing: ***** WITH tracer: Non-main tasks DID inherit the `run_instrumented_query` span ***** //! ``` -use datafusion::common::runtime::{set_join_set_tracer, JoinSetTracer}; +use datafusion::common::runtime::{JoinSetTracer, set_join_set_tracer}; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::ListingOptions; use datafusion::error::Result; use datafusion::prelude::*; use datafusion::test_util::parquet_test_data; -use futures::future::BoxFuture; use futures::FutureExt; +use futures::future::BoxFuture; use std::any::Any; use std::sync::Arc; -use tracing::{info, instrument, Instrument, Level, Span}; +use tracing::{Instrument, Level, Span, info, instrument}; -#[tokio::main] -async fn main() -> Result<()> { +/// Demonstrates the tracing injection feature for the DataFusion runtime +pub async fn tracing() -> Result<()> { // Initialize tracing subscriber with thread info. tracing_subscriber::fmt() .with_thread_ids(true) @@ -73,7 +75,9 @@ async fn main() -> Result<()> { // Run query WITHOUT tracer injection. info!("***** RUNNING WITHOUT INJECTED TRACER *****"); run_instrumented_query().await?; - info!("***** WITHOUT tracer: `tokio-runtime-worker` tasks did NOT inherit the `run_instrumented_query` span *****"); + info!( + "***** WITHOUT tracer: `tokio-runtime-worker` tasks did NOT inherit the `run_instrumented_query` span *****" + ); // Inject custom tracer so tasks run in the current span. info!("Injecting custom tracer..."); @@ -82,7 +86,9 @@ async fn main() -> Result<()> { // Run query WITH tracer injection. info!("***** RUNNING WITH INJECTED TRACER *****"); run_instrumented_query().await?; - info!("***** WITH tracer: `tokio-runtime-worker` tasks DID inherit the `run_instrumented_query` span *****"); + info!( + "***** WITH tracer: `tokio-runtime-worker` tasks DID inherit the `run_instrumented_query` span *****" + ); Ok(()) } diff --git a/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs b/datafusion-examples/examples/external_dependency/dataframe_to_s3.rs similarity index 87% rename from datafusion-examples/examples/external_dependency/dataframe-to-s3.rs rename to datafusion-examples/examples/external_dependency/dataframe_to_s3.rs index e75ba5dd5328a..fdb8a3c9c051a 100644 --- a/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs +++ b/datafusion-examples/examples/external_dependency/dataframe_to_s3.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
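The tracing example above revolves around one fact: a span is not inherited across `tokio::spawn` unless the spawned future is explicitly instrumented, which is what the injected tracer does for DataFusion's internal tasks. A stand-alone sketch of that behaviour using only `tokio` and `tracing`, without DataFusion's tracer injection:

```rust
use tracing::{info, instrument, Instrument, Span};

#[instrument]
async fn run_query() {
    // Spawned without instrumentation: the task runs outside `run_query`'s
    // span, so its log line carries no span name.
    tokio::spawn(async { info!("outside the span") }).await.unwrap();

    // Explicitly attaching the current span restores the context; DataFusion's
    // injected JoinSetTracer achieves the same effect for its spawned tasks.
    tokio::spawn(async { info!("inside the span") }.instrument(Span::current()))
        .await
        .unwrap();
}

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt().with_thread_ids(true).init();
    run_query().await;
}
```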
+ use std::env; use std::sync::Arc; use datafusion::dataframe::DataFrameWriteOptions; -use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::file_format::FileFormat; +use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::ListingOptions; use datafusion::error::Result; use datafusion::prelude::*; @@ -28,14 +30,18 @@ use datafusion::prelude::*; use object_store::aws::AmazonS3Builder; use url::Url; -/// This example demonstrates querying data from AmazonS3 and writing -/// the result of a query back to AmazonS3 -#[tokio::main] -async fn main() -> Result<()> { +/// This example demonstrates querying data from Amazon S3 and writing +/// the result of a query back to Amazon S3. +/// +/// The following environment variables must be defined: +/// +/// - AWS_ACCESS_KEY_ID +/// - AWS_SECRET_ACCESS_KEY +pub async fn dataframe_to_s3() -> Result<()> { // create local execution context let ctx = SessionContext::new(); - //enter region and bucket to which your credentials have GET and PUT access + // enter region and bucket to which your credentials have GET and PUT access let region = ""; let bucket_name = ""; @@ -66,13 +72,13 @@ async fn main() -> Result<()> { .write_parquet(&out_path, DataFrameWriteOptions::new(), None) .await?; - //write as JSON to s3 + // write as JSON to s3 let json_out = format!("s3://{bucket_name}/json_out"); df.clone() .write_json(&json_out, DataFrameWriteOptions::new(), None) .await?; - //write as csv to s3 + // write as csv to s3 let csv_out = format!("s3://{bucket_name}/csv_out"); df.write_csv(&csv_out, DataFrameWriteOptions::new(), None) .await?; diff --git a/datafusion-examples/examples/external_dependency/main.rs b/datafusion-examples/examples/external_dependency/main.rs new file mode 100644 index 0000000000000..abcba61421bdb --- /dev/null +++ b/datafusion-examples/examples/external_dependency/main.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Amazon S3 examples +//! +//! These examples demonstrate how to work with data from Amazon S3. +//! +//! ## Usage +//! ```bash +//! cargo run --example external_dependency -- [all|dataframe_to_s3|query_aws_s3] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! - `dataframe_to_s3` — run a DataFrame query against a parquet file in AWS S3 and write the results back to AWS S3 +//!
- `query_aws_s3` — configure `object_store` and run a query against files stored in AWS S3 + +mod dataframe_to_s3; +mod query_aws_s3; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + DataframeToS3, + QueryAwsS3, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "external_dependency"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::DataframeToS3 => dataframe_to_s3::dataframe_to_s3().await?, + ExampleKind::QueryAwsS3 => query_aws_s3::query_aws_s3().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .ok_or_else(|| DataFusionError::Execution(format!("Missing argument. {usage}")))? + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/external_dependency/query-aws-s3.rs b/datafusion-examples/examples/external_dependency/query_aws_s3.rs similarity index 90% rename from datafusion-examples/examples/external_dependency/query-aws-s3.rs rename to datafusion-examples/examples/external_dependency/query_aws_s3.rs index cd0b4562d5f2d..63507bb3eed11 100644 --- a/datafusion-examples/examples/external_dependency/query-aws-s3.rs +++ b/datafusion-examples/examples/external_dependency/query_aws_s3.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use datafusion::error::Result; use datafusion::prelude::*; use object_store::aws::AmazonS3Builder; @@ -22,14 +24,13 @@ use std::env; use std::sync::Arc; use url::Url; -/// This example demonstrates querying data in an S3 bucket. +/// This example demonstrates querying data in a public S3 bucket +/// (the NYC TLC open dataset: `s3://nyc-tlc`). 
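Both S3 examples in this module follow the same wiring: build an `object_store` S3 client, register it for the bucket's URL, then treat `s3://...` paths like any other table location. A hedged sketch of that wiring; the bucket name, region and table path are placeholders, and credentials come from the usual AWS environment variables:

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::prelude::*;
use object_store::aws::AmazonS3Builder;
use url::Url;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Placeholder bucket; from_env() picks up AWS_ACCESS_KEY_ID and
    // AWS_SECRET_ACCESS_KEY from the environment.
    let bucket = "my-bucket";
    let s3 = AmazonS3Builder::from_env()
        .with_bucket_name(bucket)
        .with_region("us-east-1")
        .build()?;

    // Tell the context how to resolve s3://my-bucket/... URLs.
    let s3_url = Url::parse(&format!("s3://{bucket}")).expect("valid bucket URL");
    ctx.register_object_store(&s3_url, Arc::new(s3));

    // Once registered, S3 paths behave like local ones.
    ctx.register_parquet(
        "trips",
        &format!("s3://{bucket}/data/"),
        ParquetReadOptions::default(),
    )
    .await?;
    ctx.sql("SELECT count(*) FROM trips").await?.show().await?;
    Ok(())
}
```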
/// /// The following environment variables must be defined: -/// -/// - AWS_ACCESS_KEY_ID -/// - AWS_SECRET_ACCESS_KEY -#[tokio::main] -async fn main() -> Result<()> { +/// - `AWS_ACCESS_KEY_ID` +/// - `AWS_SECRET_ACCESS_KEY` +pub async fn query_aws_s3() -> Result<()> { let ctx = SessionContext::new(); // the region must be set to the region where the bucket exists until the following diff --git a/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs b/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs index a83f15926f054..eb217ef9e4832 100644 --- a/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs +++ b/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs @@ -21,6 +21,7 @@ use abi_stable::{export_root_module, prefix_type::PrefixTypeTrait}; use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::{common::record_batch, datasource::MemTable}; +use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; use ffi_module_interface::{TableProviderModule, TableProviderModuleRef}; @@ -34,7 +35,9 @@ fn create_record_batch(start_value: i32, num_values: usize) -> RecordBatch { /// Here we only wish to create a simple table provider as an example. /// We create an in-memory table and convert it to it's FFI counterpart. -extern "C" fn construct_simple_table_provider() -> FFI_TableProvider { +extern "C" fn construct_simple_table_provider( + codec: FFI_LogicalExtensionCodec, +) -> FFI_TableProvider { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Float64, true), @@ -50,7 +53,7 @@ extern "C" fn construct_simple_table_provider() -> FFI_TableProvider { let table_provider = MemTable::try_new(schema, vec![batches]).unwrap(); - FFI_TableProvider::new(Arc::new(table_provider), true, None) + FFI_TableProvider::new_with_ffi_codec(Arc::new(table_provider), true, None, codec) } #[export_root_module] diff --git a/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml b/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml index 612a219324763..f393b2971e454 100644 --- a/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml +++ b/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "ffi_module_interface" version = "0.1.0" -edition = "2021" +edition = "2024" publish = false [dependencies] diff --git a/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs b/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs index 88690e9297135..3b2b9e1871dae 100644 --- a/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs +++ b/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs @@ -16,12 +16,12 @@ // under the License. use abi_stable::{ - declare_root_module_statics, + StableAbi, declare_root_module_statics, library::{LibraryError, RootModule}, package_version_strings, sabi_types::VersionStrings, - StableAbi, }; +use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; #[repr(C)] @@ -34,7 +34,8 @@ use datafusion_ffi::table_provider::FFI_TableProvider; /// how a user may wish to separate these concerns. 
pub struct TableProviderModule { /// Constructs the table provider - pub create_table: extern "C" fn() -> FFI_TableProvider, + pub create_table: + extern "C" fn(codec: FFI_LogicalExtensionCodec) -> FFI_TableProvider, } impl RootModule for TableProviderModuleRef { diff --git a/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml b/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml index 028a366aab1c0..823c9afddee2a 100644 --- a/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml +++ b/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "ffi_module_loader" version = "0.1.0" -edition = "2021" +edition = "2024" publish = false [dependencies] diff --git a/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs b/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs index 6e376ca866e8f..8ce5b156df3b1 100644 --- a/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs +++ b/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs @@ -22,8 +22,10 @@ use datafusion::{ prelude::SessionContext, }; -use abi_stable::library::{development_utils::compute_library_path, RootModule}; -use datafusion_ffi::table_provider::ForeignTableProvider; +use abi_stable::library::{RootModule, development_utils::compute_library_path}; +use datafusion::datasource::TableProvider; +use datafusion::execution::TaskContextProvider; +use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use ffi_module_interface::TableProviderModuleRef; #[tokio::main] @@ -39,6 +41,11 @@ async fn main() -> Result<()> { TableProviderModuleRef::load_from_directory(&library_path) .map_err(|e| DataFusionError::External(Box::new(e)))?; + let ctx = Arc::new(SessionContext::new()); + let codec = FFI_LogicalExtensionCodec::new_default( + &(Arc::clone(&ctx) as Arc), + ); + // By calling the code below, the table provided will be created within // the module's code. let ffi_table_provider = @@ -46,16 +53,14 @@ async fn main() -> Result<()> { .create_table() .ok_or(DataFusionError::NotImplemented( "External table provider failed to implement create_table".to_string(), - ))?(); + ))?(codec); // In order to access the table provider within this executable, we need to - // turn it into a `ForeignTableProvider`. - let foreign_table_provider: ForeignTableProvider = (&ffi_table_provider).into(); - - let ctx = SessionContext::new(); + // turn it into a `TableProvider`. + let foreign_table_provider: Arc = (&ffi_table_provider).into(); // Display the data to show the full cycle works. - ctx.register_table("external_table", Arc::new(foreign_table_provider))?; + ctx.register_table("external_table", foreign_table_provider)?; let df = ctx.table("external_table").await?; df.show().await?; diff --git a/datafusion-examples/examples/flight/client.rs b/datafusion-examples/examples/flight/client.rs index 031beea47d57a..484576975a6f2 100644 --- a/datafusion-examples/examples/flight/client.rs +++ b/datafusion-examples/examples/flight/client.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use std::collections::HashMap; use std::sync::Arc; use tonic::transport::Endpoint; diff --git a/datafusion-examples/examples/flight/main.rs b/datafusion-examples/examples/flight/main.rs index a448789b353b9..25965a3011c60 100644 --- a/datafusion-examples/examples/flight/main.rs +++ b/datafusion-examples/examples/flight/main.rs @@ -19,7 +19,16 @@ //! //! 
These examples demonstrate Arrow Flight usage. //! +//! ## Usage +//! ```bash +//! cargo run --example flight -- [all|client|server|sql_server] +//! ``` +//! //! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! Note: The Flight server must be started in a separate process +//! before running the `client` example. Therefore, running `all` will +//! not produce a full server+client workflow automatically. //! - `client` — run DataFusion as a standalone process and execute SQL queries from a client using the Flight protocol //! - `server` — run DataFusion as a standalone process and execute SQL queries from a client using the Flight protocol //! - `sql_server` — run DataFusion as a standalone process and execute SQL queries from JDBC clients @@ -28,46 +37,43 @@ mod client; mod server; mod sql_server; -use std::str::FromStr; - use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; +/// The `all` option cannot run all examples end-to-end because the +/// `server` example must run in a separate process before the `client` +/// example can connect. +/// Therefore, `all` only iterates over individually runnable examples. +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] enum ExampleKind { + All, Client, Server, SqlServer, } -impl AsRef for ExampleKind { - fn as_ref(&self) -> &str { - match self { - Self::Client => "client", - Self::Server => "server", - Self::SqlServer => "sql_server", - } - } -} - -impl FromStr for ExampleKind { - type Err = DataFusionError; - - fn from_str(s: &str) -> Result { - match s { - "client" => Ok(Self::Client), - "server" => Ok(Self::Server), - "sql_server" => Ok(Self::SqlServer), - _ => Err(DataFusionError::Execution(format!("Unknown example: {s}"))), - } - } -} - impl ExampleKind { - const ALL: [Self; 3] = [Self::Client, Self::Server, Self::SqlServer]; - const EXAMPLE_NAME: &str = "flight"; - fn variants() -> Vec<&'static str> { - Self::ALL.iter().map(|x| x.as_ref()).collect() + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<(), Box> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::Client => client::client().await?, + ExampleKind::Server => server::server().await?, + ExampleKind::SqlServer => sql_server::sql_server().await?, + } + Ok(()) } } @@ -76,19 +82,14 @@ async fn main() -> Result<(), Box> { let usage = format!( "Usage: cargo run --example {} -- [{}]", ExampleKind::EXAMPLE_NAME, - ExampleKind::variants().join("|") + ExampleKind::VARIANTS.join("|") ); - let arg = std::env::args().nth(1).ok_or_else(|| { - eprintln!("{usage}"); - DataFusionError::Execution("Missing argument".to_string()) - })?; - - match arg.parse::()? { - ExampleKind::Client => client::client().await?, - ExampleKind::Server => server::server().await?, - ExampleKind::SqlServer => sql_server::sql_server().await?, - } + let example: ExampleKind = std::env::args() + .nth(1) + .ok_or_else(|| DataFusionError::Execution(format!("Missing argument. {usage}")))? + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. 
{usage}")))?; - Ok(()) + example.run().await } diff --git a/datafusion-examples/examples/flight/server.rs b/datafusion-examples/examples/flight/server.rs index dc75287cf2e2b..aad82e28b15ef 100644 --- a/datafusion-examples/examples/flight/server.rs +++ b/datafusion-examples/examples/flight/server.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator}; use std::sync::Arc; @@ -29,9 +31,9 @@ use tonic::{Request, Response, Status, Streaming}; use datafusion::prelude::*; use arrow_flight::{ - flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, + flight_service_server::FlightService, flight_service_server::FlightServiceServer, }; #[derive(Clone)] @@ -187,6 +189,7 @@ impl FlightService for FlightServiceImpl { } } +#[expect(clippy::needless_pass_by_value)] fn to_tonic_err(e: datafusion::error::DataFusionError) -> Status { Status::internal(format!("{e:?}")) } diff --git a/datafusion-examples/examples/flight/sql_server.rs b/datafusion-examples/examples/flight/sql_server.rs index d86860f9d4364..435e05ffc0cec 100644 --- a/datafusion-examples/examples/flight/sql_server.rs +++ b/datafusion-examples/examples/flight/sql_server.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::array::{ArrayRef, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::ipc::writer::IpcWriteOptions; @@ -414,7 +416,9 @@ impl FlightSqlService for FlightSqlServiceImpl { ) -> Result<(), Status> { let handle = std::str::from_utf8(&handle.prepared_statement_handle); if let Ok(handle) = handle { - info!("do_action_close_prepared_statement: removing plan and results for {handle}"); + info!( + "do_action_close_prepared_statement: removing plan and results for {handle}" + ); let _ = self.remove_plan(handle); let _ = self.remove_result(handle); } diff --git a/datafusion-examples/examples/composed_extension_codec.rs b/datafusion-examples/examples/proto/composed_extension_codec.rs similarity index 95% rename from datafusion-examples/examples/composed_extension_codec.rs rename to datafusion-examples/examples/proto/composed_extension_codec.rs index 57f2c370413aa..f3910d461b6a8 100644 --- a/datafusion-examples/examples/composed_extension_codec.rs +++ b/datafusion-examples/examples/proto/composed_extension_codec.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! //! This example demonstrates how to compose multiple PhysicalExtensionCodecs //! //! 
This can be helpful when an Execution plan tree has different nodes from different crates @@ -34,8 +36,8 @@ use std::any::Any; use std::fmt::Debug; use std::sync::Arc; -use datafusion::common::internal_err; use datafusion::common::Result; +use datafusion::common::internal_err; use datafusion::execution::TaskContext; use datafusion::physical_plan::{DisplayAs, ExecutionPlan}; use datafusion::prelude::SessionContext; @@ -44,8 +46,8 @@ use datafusion_proto::physical_plan::{ }; use datafusion_proto::protobuf; -#[tokio::main] -async fn main() { +/// Example of using multiple extension codecs for serialization / deserialization +pub async fn composed_extension_codec() -> Result<()> { // build execution plan that has both types of nodes // // Note each node requires a different `PhysicalExtensionCodec` to decode @@ -66,16 +68,16 @@ async fn main() { protobuf::PhysicalPlanNode::try_from_physical_plan( exec_plan.clone(), &composed_codec, - ) - .expect("to proto"); + )?; // deserialize proto back to execution plan - let result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx.task_ctx(), &composed_codec) - .expect("from proto"); + let result_exec_plan: Arc = + proto.try_into_physical_plan(&ctx.task_ctx(), &composed_codec)?; // assert that the original and deserialized execution plans are equal assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); + + Ok(()) } /// This example has two types of nodes: `ParentExec` and `ChildExec` which can only diff --git a/datafusion-examples/examples/proto/main.rs b/datafusion-examples/examples/proto/main.rs new file mode 100644 index 0000000000000..9e4ae728206c4 --- /dev/null +++ b/datafusion-examples/examples/proto/main.rs @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Examples demonstrating DataFusion's plan serialization via the `datafusion-proto` crate +//! +//! These examples show how to use multiple extension codecs for serialization / deserialization. +//! +//! ## Usage +//! ```bash +//! cargo run --example proto -- [all|composed_extension_codec] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! 
- `composed_extension_codec` — example of using multiple extension codecs for serialization / deserialization + +mod composed_extension_codec; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + ComposedExtensionCodec, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "proto"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::ComposedExtensionCodec => { + composed_extension_codec::composed_extension_codec().await? + } + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .ok_or_else(|| DataFusionError::Execution(format!("Missing argument. {usage}")))? + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/analyzer_rule.rs b/datafusion-examples/examples/query_planning/analyzer_rule.rs similarity index 97% rename from datafusion-examples/examples/analyzer_rule.rs rename to datafusion-examples/examples/query_planning/analyzer_rule.rs index cb81cd167a88b..a86f5cdd2a5e3 100644 --- a/datafusion-examples/examples/analyzer_rule.rs +++ b/datafusion-examples/examples/query_planning/analyzer_rule.rs @@ -15,11 +15,13 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; +use datafusion::common::Result; use datafusion::common::config::ConfigOptions; use datafusion::common::tree_node::{Transformed, TreeNode}; -use datafusion::common::Result; -use datafusion::logical_expr::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder, col, lit}; use datafusion::optimizer::analyzer::AnalyzerRule; use datafusion::prelude::SessionContext; use std::sync::{Arc, Mutex}; @@ -35,8 +37,7 @@ use std::sync::{Arc, Mutex}; /// level access control scheme by introducing a filter to the query. /// /// See [optimizer_rule.rs] for an example of a optimizer rule -#[tokio::main] -pub async fn main() -> Result<()> { +pub async fn analyzer_rule() -> Result<()> { // AnalyzerRules run before OptimizerRules. // // DataFusion includes several built in AnalyzerRules for tasks such as type diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/query_planning/expr_api.rs similarity index 97% rename from datafusion-examples/examples/expr_api.rs rename to datafusion-examples/examples/query_planning/expr_api.rs index 56f960870e58a..47de669023f7c 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/query_planning/expr_api.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
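All of the new per-module `main.rs` dispatchers above lean on the same strum derives, which is why the hand-written `AsRef`/`FromStr` boilerplate the flight example previously carried can be deleted. A small stand-alone sketch of what each derive contributes, with the variant set trimmed for brevity:

```rust
use strum::{IntoEnumIterator, VariantNames};
use strum_macros::{Display, EnumIter, EnumString, VariantNames};

#[derive(Debug, PartialEq, EnumIter, EnumString, Display, VariantNames)]
#[strum(serialize_all = "snake_case")]
enum ExampleKind {
    All,
    ComposedExtensionCodec,
}

fn main() {
    // VariantNames: the snake_case names joined into the usage string.
    assert_eq!(ExampleKind::VARIANTS, ["all", "composed_extension_codec"]);

    // EnumString: FromStr for parsing the CLI argument.
    let parsed: ExampleKind = "composed_extension_codec".parse().unwrap();
    assert_eq!(parsed, ExampleKind::ComposedExtensionCodec);

    // Display: what `println!("Running example: {example}")` prints.
    assert_eq!(ExampleKind::All.to_string(), "all");

    // EnumIter: iterate every variant, skipping `All`, to build `runnable()`.
    let runnable: Vec<_> = ExampleKind::iter()
        .filter(|v| !matches!(v, ExampleKind::All))
        .collect();
    assert_eq!(runnable, [ExampleKind::ComposedExtensionCodec]);
}
```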
+ use std::collections::HashMap; use std::sync::Arc; -use arrow::array::{BooleanArray, Int32Array, Int8Array}; +use arrow::array::{BooleanArray, Int8Array, Int32Array}; use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; @@ -35,7 +37,7 @@ use datafusion::logical_expr::simplify::SimplifyContext; use datafusion::logical_expr::{ColumnarValue, ExprFunctionExt, ExprSchemable, Operator}; use datafusion::optimizer::analyzer::type_coercion::TypeCoercionRewriter; use datafusion::optimizer::simplify_expressions::ExprSimplifier; -use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries}; +use datafusion::physical_expr::{AnalysisContext, ExprBoundaries, analyze}; use datafusion::prelude::*; /// This example demonstrates the DataFusion [`Expr`] API. @@ -55,8 +57,7 @@ use datafusion::prelude::*; /// 5. Analyze predicates for boundary ranges: [`range_analysis_demo`] /// 6. Get the types of the expressions: [`expression_type_demo`] /// 7. Apply type coercion to expressions: [`type_coercion_demo`] -#[tokio::main] -async fn main() -> Result<()> { +pub async fn expr_api() -> Result<()> { // The easiest way to do create expressions is to use the // "fluent"-style API: let expr = col("a") + lit(5); @@ -302,6 +303,7 @@ fn boundary_analysis_and_selectivity_demo() -> Result<()> { min_value: Precision::Exact(ScalarValue::Int64(Some(1))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, }; // We can then build our expression boundaries from the column statistics @@ -342,9 +344,11 @@ fn boundary_analysis_and_selectivity_demo() -> Result<()> { // // (a' - b' + 1) / (a - b) // (10000 - 5000 + 1) / (10000 - 1) - assert!(analysis - .selectivity - .is_some_and(|selectivity| (0.5..=0.6).contains(&selectivity))); + assert!( + analysis + .selectivity + .is_some_and(|selectivity| (0.5..=0.6).contains(&selectivity)) + ); Ok(()) } @@ -369,6 +373,7 @@ fn boundary_analysis_in_conjunctions_demo() -> Result<()> { min_value: Precision::Exact(ScalarValue::Int64(Some(14))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, }; let initial_boundaries = @@ -414,9 +419,11 @@ fn boundary_analysis_in_conjunctions_demo() -> Result<()> { // // Granted a column such as age will more likely follow a Normal distribution // as such our selectivity estimation will not be as good as it can. - assert!(analysis - .selectivity - .is_some_and(|selectivity| (0.1..=0.2).contains(&selectivity))); + assert!( + analysis + .selectivity + .is_some_and(|selectivity| (0.1..=0.2).contains(&selectivity)) + ); // The above example was a good way to look at how we can derive better // interval and get a lower selectivity during boundary analysis. @@ -532,10 +539,11 @@ fn type_coercion_demo() -> Result<()> { let physical_expr = datafusion::physical_expr::create_physical_expr(&expr, &df_schema, &props)?; let e = physical_expr.evaluate(&batch).unwrap_err(); - assert!(e - .find_root() - .to_string() - .contains("Invalid comparison operation: Int8 > Int32")); + assert!( + e.find_root() + .to_string() + .contains("Invalid comparison operation: Int8 > Int32") + ); // 1. Type coercion with `SessionContext::create_physical_expr` which implicitly applies type coercion before constructing the physical expr. 
let physical_expr = diff --git a/datafusion-examples/examples/query_planning/main.rs b/datafusion-examples/examples/query_planning/main.rs new file mode 100644 index 0000000000000..247f468735359 --- /dev/null +++ b/datafusion-examples/examples/query_planning/main.rs @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Query planning and optimization examples +//! +//! These examples demonstrate internal mechanics of the query planning and optimization layers. +//! +//! ## Usage +//! ```bash +//! cargo run --example query_planning -- [all|analyzer_rule|expr_api|optimizer_rule|parse_sql_expr|plan_to_sql|planner_api|pruning|thread_pools] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! - `analyzer_rule` — use a custom AnalyzerRule to change a query's semantics (row level access control) +//! - `expr_api` — create, execute, simplify, analyze and coerce `Expr`s +//! - `optimizer_rule` — use a custom OptimizerRule to replace certain predicates +//! - `parse_sql_expr` — parse SQL text into DataFusion `Expr` +//! - `plan_to_sql` — generate SQL from DataFusion `Expr` and `LogicalPlan` +//! - `planner_api` — APIs to manipulate logical and physical plans +//! - `pruning` — use PruningPredicate and statistics to prove that files cannot contain matching rows and skip them +//!
- `thread_pools` — use separate thread pools (tokio Runtimes) to run the IO and CPU intensive parts of DataFusion plans + +mod analyzer_rule; +mod expr_api; +mod optimizer_rule; +mod parse_sql_expr; +mod plan_to_sql; +mod planner_api; +mod pruning; +mod thread_pools; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + AnalyzerRule, + ExprApi, + OptimizerRule, + ParseSqlExpr, + PlanToSql, + PlannerApi, + Pruning, + ThreadPools, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "query_planning"; + + fn runnable() -> impl Iterator<Item = Self> { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::AnalyzerRule => analyzer_rule::analyzer_rule().await?, + ExampleKind::ExprApi => expr_api::expr_api().await?, + ExampleKind::OptimizerRule => optimizer_rule::optimizer_rule().await?, + ExampleKind::ParseSqlExpr => parse_sql_expr::parse_sql_expr().await?, + ExampleKind::PlanToSql => plan_to_sql::plan_to_sql_examples().await?, + ExampleKind::PlannerApi => planner_api::planner_api().await?, + ExampleKind::Pruning => pruning::pruning().await?, + ExampleKind::ThreadPools => thread_pools::thread_pools().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .ok_or_else(|| DataFusionError::Execution(format!("Missing argument. {usage}")))? + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/query_planning/optimizer_rule.rs similarity index 98% rename from datafusion-examples/examples/optimizer_rule.rs rename to datafusion-examples/examples/query_planning/optimizer_rule.rs index 9c137b67432c5..de9de7737a6a0 100644 --- a/datafusion-examples/examples/optimizer_rule.rs +++ b/datafusion-examples/examples/query_planning/optimizer_rule.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use arrow::datatypes::DataType; use datafusion::common::tree_node::{Transformed, TreeNode}; -use datafusion::common::{assert_batches_eq, Result, ScalarValue}; +use datafusion::common::{Result, ScalarValue, assert_batches_eq}; use datafusion::logical_expr::{ BinaryExpr, ColumnarValue, Expr, LogicalPlan, Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, @@ -37,8 +39,7 @@ use std::sync::Arc; /// /// See [analyzer_rule.rs] for an example of AnalyzerRules, which are for /// changing plan semantics. -#[tokio::main] -pub async fn main() -> Result<()> { +pub async fn optimizer_rule() -> Result<()> { // DataFusion includes many built in OptimizerRules for tasks such as outer // to inner join conversion and constant folding.
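As a quick way to see the built-in rules mentioned here in action before a custom rule is added, comparing the unoptimized and optimized plans of a query with constant arithmetic shows the simplify/constant-folding pass rewriting the predicate. This is a sketch separate from the example's own code; the table is ad hoc:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)")
        .await?
        .collect()
        .await?;

    // The optimized plan should show the filter folded to `column1 > 2`,
    // while the unoptimized plan still contains `1 + 1`.
    let df = ctx
        .sql("SELECT column1 FROM t WHERE column1 > 1 + 1")
        .await?;
    println!("unoptimized:\n{}", df.logical_plan().display_indent());
    println!(
        "optimized:\n{}",
        df.clone().into_optimized_plan()?.display_indent()
    );
    Ok(())
}
```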
// diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/query_planning/parse_sql_expr.rs similarity index 96% rename from datafusion-examples/examples/parse_sql_expr.rs rename to datafusion-examples/examples/query_planning/parse_sql_expr.rs index 5387e7c4a05dc..376120de9d492 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/query_planning/parse_sql_expr.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::datatypes::{DataType, Field, Schema}; use datafusion::common::DFSchema; use datafusion::logical_expr::{col, lit}; @@ -32,17 +34,15 @@ use datafusion::{ /// The code in this example shows how to: /// /// 1. [`simple_session_context_parse_sql_expr_demo`]: Parse a simple SQL text into a logical -/// expression using a schema at [`SessionContext`]. +/// expression using a schema at [`SessionContext`]. /// /// 2. [`simple_dataframe_parse_sql_expr_demo`]: Parse a simple SQL text into a logical expression -/// using a schema at [`DataFrame`]. +/// using a schema at [`DataFrame`]. /// /// 3. [`query_parquet_demo`]: Query a parquet file using the parsed_sql_expr from a DataFrame. /// /// 4. [`round_trip_parse_sql_expr_demo`]: Parse a SQL text and convert it back to SQL using [`Unparser`]. - -#[tokio::main] -async fn main() -> Result<()> { +pub async fn parse_sql_expr() -> Result<()> { // See how to evaluate expressions simple_session_context_parse_sql_expr_demo()?; simple_dataframe_parse_sql_expr_demo().await?; diff --git a/datafusion-examples/examples/plan_to_sql.rs b/datafusion-examples/examples/query_planning/plan_to_sql.rs similarity index 95% rename from datafusion-examples/examples/plan_to_sql.rs rename to datafusion-examples/examples/query_planning/plan_to_sql.rs index 54483b143a169..756cc80b8f3c7 100644 --- a/datafusion-examples/examples/plan_to_sql.rs +++ b/datafusion-examples/examples/query_planning/plan_to_sql.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use datafusion::common::DFSchemaRef; use datafusion::error::Result; use datafusion::logical_expr::sqlparser::ast::Statement; @@ -32,7 +34,7 @@ use datafusion::sql::unparser::extension_unparser::UserDefinedLogicalNodeUnparse use datafusion::sql::unparser::extension_unparser::{ UnparseToStatementResult, UnparseWithinStatementResult, }; -use datafusion::sql::unparser::{plan_to_sql, Unparser}; +use datafusion::sql::unparser::{Unparser, plan_to_sql}; use std::fmt; use std::sync::Arc; @@ -43,28 +45,26 @@ use std::sync::Arc; /// The code in this example shows how to: /// /// 1. [`simple_expr_to_sql_demo`]: Create a simple expression [`Exprs`] with -/// fluent API and convert to sql suitable for passing to another database +/// fluent API and convert to sql suitable for passing to another database /// /// 2. [`simple_expr_to_pretty_sql_demo`] Create a simple expression -/// [`Exprs`] with fluent API and convert to sql without extra parentheses, -/// suitable for displaying to humans +/// [`Exprs`] with fluent API and convert to sql without extra parentheses, +/// suitable for displaying to humans /// /// 3. [`simple_expr_to_sql_demo_escape_mysql_style`]" Create a simple -/// expression [`Exprs`] with fluent API and convert to sql escaping column -/// names in MySQL style. 
+/// expression [`Exprs`] with fluent API and convert to sql escaping column +/// names in MySQL style. /// /// 4. [`simple_plan_to_sql_demo`]: Create a simple logical plan using the -/// DataFrames API and convert to sql string. +/// DataFrames API and convert to sql string. /// /// 5. [`round_trip_plan_to_sql_demo`]: Create a logical plan from a SQL string, modify it using the -/// DataFrames API and convert it back to a sql string. +/// DataFrames API and convert it back to a sql string. /// /// 6. [`unparse_my_logical_plan_as_statement`]: Create a custom logical plan and unparse it as a statement. /// /// 7. [`unparse_my_logical_plan_as_subquery`]: Create a custom logical plan and unparse it as a subquery. - -#[tokio::main] -async fn main() -> Result<()> { +pub async fn plan_to_sql_examples() -> Result<()> { // See how to evaluate expressions simple_expr_to_sql_demo()?; simple_expr_to_pretty_sql_demo()?; diff --git a/datafusion-examples/examples/planner_api.rs b/datafusion-examples/examples/query_planning/planner_api.rs similarity index 98% rename from datafusion-examples/examples/planner_api.rs rename to datafusion-examples/examples/query_planning/planner_api.rs index 55aec7b0108a4..9b8aa1c2fe649 100644 --- a/datafusion-examples/examples/planner_api.rs +++ b/datafusion-examples/examples/query_planning/planner_api.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use datafusion::error::Result; use datafusion::logical_expr::LogicalPlan; use datafusion::physical_plan::displayable; @@ -32,8 +34,7 @@ use datafusion::prelude::*; /// physical plan: /// - Via the combined `create_physical_plan` API. /// - Utilizing the analyzer, optimizer, and query planner APIs separately. -#[tokio::main] -async fn main() -> Result<()> { +pub async fn planner_api() -> Result<()> { // Set up a DataFusion context and load a Parquet file let ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); diff --git a/datafusion-examples/examples/pruning.rs b/datafusion-examples/examples/query_planning/pruning.rs similarity index 97% rename from datafusion-examples/examples/pruning.rs rename to datafusion-examples/examples/query_planning/pruning.rs index 9a61789662cdd..33f3f8428a77f 100644 --- a/datafusion-examples/examples/pruning.rs +++ b/datafusion-examples/examples/query_planning/pruning.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use std::collections::HashSet; use std::sync::Arc; @@ -22,6 +24,7 @@ use arrow::array::{ArrayRef, BooleanArray, Int32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::common::pruning::PruningStatistics; use datafusion::common::{DFSchema, ScalarValue}; +use datafusion::error::Result; use datafusion::execution::context::ExecutionProps; use datafusion::physical_expr::create_physical_expr; use datafusion::physical_optimizer::pruning::PruningPredicate; @@ -40,8 +43,7 @@ use datafusion::prelude::*; /// one might do as part of a higher level storage engine. See /// `parquet_index.rs` for an example that uses pruning in the context of an /// individual query. 
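Stepping back to the `plan_to_sql` example renamed above, its round trips reduce to two unparser entry points, `expr_to_sql` for expressions and `plan_to_sql` for whole plans. A compact sketch; the table `t` and its columns are ad hoc:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion::sql::unparser::{expr_to_sql, plan_to_sql};

#[tokio::main]
async fn main() -> Result<()> {
    // Expr -> SQL text
    let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8)));
    println!("{}", expr_to_sql(&expr)?);

    // LogicalPlan -> SQL text: build a plan with the DataFrame API, then unparse it
    let ctx = SessionContext::new();
    ctx.sql("CREATE TABLE t AS VALUES (1, 2), (3, 4)")
        .await?
        .collect()
        .await?;
    let plan = ctx
        .table("t")
        .await?
        .filter(col("column1").gt(lit(1)))?
        .into_unoptimized_plan();
    println!("{}", plan_to_sql(&plan)?);
    Ok(())
}
```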
-#[tokio::main] -async fn main() { +pub async fn pruning() -> Result<()> { // In this example, we'll use the PruningPredicate to determine if // the expression `x = 5 AND y = 10` can never be true based on statistics @@ -69,7 +71,7 @@ async fn main() { let predicate = create_pruning_predicate(expr, &my_catalog.schema); // Evaluate the predicate for the three files in the catalog - let prune_results = predicate.prune(&my_catalog).unwrap(); + let prune_results = predicate.prune(&my_catalog)?; println!("Pruning results: {prune_results:?}"); // The result is a `Vec` of bool values, one for each file in the catalog @@ -93,6 +95,8 @@ async fn main() { false ] ); + + Ok(()) } /// A simple model catalog that has information about the three files that store @@ -186,6 +190,7 @@ impl PruningStatistics for MyCatalog { } } +#[expect(clippy::needless_pass_by_value)] fn create_pruning_predicate(expr: Expr, schema: &SchemaRef) -> PruningPredicate { let df_schema = DFSchema::try_from(Arc::clone(schema)).unwrap(); let props = ExecutionProps::new(); diff --git a/datafusion-examples/examples/thread_pools.rs b/datafusion-examples/examples/query_planning/thread_pools.rs similarity index 99% rename from datafusion-examples/examples/thread_pools.rs rename to datafusion-examples/examples/query_planning/thread_pools.rs index 9842cccfbfe83..6fc7d51e91c1f 100644 --- a/datafusion-examples/examples/thread_pools.rs +++ b/datafusion-examples/examples/query_planning/thread_pools.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! //! This example shows how to use separate thread pools (tokio [`Runtime`]))s to //! run the IO and CPU intensive parts of DataFusion plans. //! @@ -64,8 +66,7 @@ use url::Url; /// when using Rust libraries such as `tonic`. Using a separate `Runtime` for /// CPU bound tasks will often be simpler in larger applications, even though it /// makes this example slightly more complex. -#[tokio::main] -async fn main() -> Result<()> { +pub async fn thread_pools() -> Result<()> { // The first two examples read local files. Enabling the URL table feature // lets us treat filenames as tables in SQL. let ctx = SessionContext::new().enable_url_table(); @@ -121,7 +122,7 @@ async fn same_runtime(ctx: &SessionContext, sql: &str) -> Result<()> { // Executing the plan using this pattern intermixes any IO and CPU intensive // work on same Runtime while let Some(batch) = stream.next().await { - println!("{}", pretty_format_batches(&[batch?]).unwrap()); + println!("{}", pretty_format_batches(&[batch?])?); } Ok(()) } diff --git a/datafusion-examples/examples/relation_planner/main.rs b/datafusion-examples/examples/relation_planner/main.rs new file mode 100644 index 0000000000000..d2ba2202d1787 --- /dev/null +++ b/datafusion-examples/examples/relation_planner/main.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Relation Planner Examples +//! +//! These examples demonstrate how to use custom relation planners to extend +//! DataFusion's SQL syntax with custom table operators. +//! +//! ## Usage +//! ```bash +//! cargo run --example relation_planner -- [all|match_recognize|pivot_unpivot|table_sample] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! - `match_recognize` — MATCH_RECOGNIZE pattern matching on event streams +//! - `pivot_unpivot` — PIVOT and UNPIVOT operations for reshaping data +//! - `table_sample` — TABLESAMPLE clause for sampling rows from tables +//! +//! ## Snapshot Testing +//! +//! These examples use [insta](https://insta.rs) for inline snapshot assertions. +//! If query output changes, regenerate the snapshots with: +//! ```bash +//! cargo insta test --example relation_planner --accept +//! ``` + +mod match_recognize; +mod pivot_unpivot; +mod table_sample; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + MatchRecognize, + PivotUnpivot, + TableSample, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "relation_planner"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::MatchRecognize => match_recognize::match_recognize().await?, + ExampleKind::PivotUnpivot => pivot_unpivot::pivot_unpivot().await?, + ExampleKind::TableSample => table_sample::table_sample().await?, + } + + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .ok_or_else(|| DataFusionError::Execution(format!("Missing argument. {usage}")))? + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} + +/// Test wrappers that enable `cargo insta test --example relation_planner --accept` +/// to regenerate inline snapshots. Without these, insta cannot run the examples +/// in test mode since they only have `main()` functions. 
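For readers unfamiliar with the snapshot workflow these wrappers enable, this is roughly how insta's inline snapshots behave (a generic sketch, not tied to these examples): the literal after `@` lives in the source file itself and is rewritten in place by `cargo insta test ... --accept`.

```rust
#[cfg(test)]
mod snapshot_demo {
    use insta::assert_snapshot;

    #[test]
    fn inline_snapshot() {
        let rendered = "Projection: t.a";
        // If `rendered` ever changes, `cargo insta test --accept` updates the
        // string after `@` to the new value instead of failing repeatedly.
        assert_snapshot!(rendered, @"Projection: t.a");
    }
}
```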
+#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_match_recognize() { + match_recognize::match_recognize().await.unwrap(); + } + + #[tokio::test] + async fn test_pivot_unpivot() { + pivot_unpivot::pivot_unpivot().await.unwrap(); + } + + #[tokio::test] + async fn test_table_sample() { + table_sample::table_sample().await.unwrap(); + } +} diff --git a/datafusion-examples/examples/relation_planner/match_recognize.rs b/datafusion-examples/examples/relation_planner/match_recognize.rs new file mode 100644 index 0000000000000..60baf9bd61a62 --- /dev/null +++ b/datafusion-examples/examples/relation_planner/match_recognize.rs @@ -0,0 +1,406 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # MATCH_RECOGNIZE Example +//! +//! This example demonstrates implementing SQL `MATCH_RECOGNIZE` pattern matching +//! using a custom [`RelationPlanner`]. Unlike the [`pivot_unpivot`] example that +//! rewrites SQL to standard operations, this example creates a **custom logical +//! plan node** (`MiniMatchRecognizeNode`) to represent the operation. +//! +//! ## Supported Syntax +//! +//! ```sql +//! SELECT * FROM events +//! MATCH_RECOGNIZE ( +//! PARTITION BY region +//! MEASURES SUM(price) AS total, AVG(price) AS average +//! PATTERN (A B+ C) +//! DEFINE +//! A AS price < 100, +//! B AS price BETWEEN 100 AND 200, +//! C AS price > 200 +//! ) AS matches +//! ``` +//! +//! ## Architecture +//! +//! This example demonstrates **logical planning only**. Physical execution would +//! require implementing an [`ExecutionPlan`] (see the [`table_sample`] example +//! for a complete implementation with physical planning). +//! +//! ```text +//! SQL Query +//! │ +//! ▼ +//! ┌─────────────────────────────────────┐ +//! │ MatchRecognizePlanner │ +//! │ (RelationPlanner trait) │ +//! │ │ +//! │ • Parses MATCH_RECOGNIZE syntax │ +//! │ • Creates MiniMatchRecognizeNode │ +//! │ • Converts SQL exprs to DataFusion │ +//! └─────────────────────────────────────┘ +//! │ +//! ▼ +//! ┌─────────────────────────────────────┐ +//! │ MiniMatchRecognizeNode │ +//! │ (UserDefinedLogicalNode) │ +//! │ │ +//! │ • measures: [(alias, expr), ...] │ +//! │ • definitions: [(symbol, expr), ...]│ +//! └─────────────────────────────────────┘ +//! ``` +//! +//! [`pivot_unpivot`]: super::pivot_unpivot +//! [`table_sample`]: super::table_sample +//! 
[`ExecutionPlan`]: datafusion::physical_plan::ExecutionPlan + +use std::{any::Any, cmp::Ordering, hash::Hasher, sync::Arc}; + +use arrow::array::{ArrayRef, Float64Array, Int32Array, StringArray}; +use arrow::record_batch::RecordBatch; +use datafusion::prelude::*; +use datafusion_common::{DFSchemaRef, Result}; +use datafusion_expr::{ + Expr, UserDefinedLogicalNode, + logical_plan::{Extension, InvariantLevel, LogicalPlan}, + planner::{ + PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning, + }, +}; +use datafusion_sql::sqlparser::ast::TableFactor; +use insta::assert_snapshot; + +// ============================================================================ +// Example Entry Point +// ============================================================================ + +/// Runs the MATCH_RECOGNIZE examples demonstrating pattern matching on event streams. +/// +/// Note: This example demonstrates **logical planning only**. Physical execution +/// would require additional implementation of an [`ExecutionPlan`]. +pub async fn match_recognize() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_relation_planner(Arc::new(MatchRecognizePlanner))?; + register_sample_data(&ctx)?; + + println!("MATCH_RECOGNIZE Example (Logical Planning Only)"); + println!("================================================\n"); + + run_examples(&ctx).await +} + +async fn run_examples(ctx: &SessionContext) -> Result<()> { + // Example 1: Basic MATCH_RECOGNIZE with MEASURES and DEFINE + // Demonstrates: Aggregate measures over matched rows + let plan = run_example( + ctx, + "Example 1: MATCH_RECOGNIZE with aggregations", + r#"SELECT * FROM events + MATCH_RECOGNIZE ( + PARTITION BY 1 + MEASURES SUM(price) AS total_price, AVG(price) AS avg_price + PATTERN (A) + DEFINE A AS price > 10 + ) AS matches"#, + ) + .await?; + assert_snapshot!(plan, @r" + Projection: matches.price + SubqueryAlias: matches + MiniMatchRecognize measures=[total_price := sum(events.price), avg_price := avg(events.price)] define=[a := events.price > Int64(10)] + TableScan: events + "); + + // Example 2: Stock price pattern detection + // Demonstrates: Real-world use case finding prices above threshold + let plan = run_example( + ctx, + "Example 2: Detect high stock prices", + r#"SELECT * FROM stock_prices + MATCH_RECOGNIZE ( + MEASURES + MIN(price) AS min_price, + MAX(price) AS max_price, + AVG(price) AS avg_price + PATTERN (HIGH) + DEFINE HIGH AS price > 151.0 + ) AS trends"#, + ) + .await?; + assert_snapshot!(plan, @r" + Projection: trends.symbol, trends.price + SubqueryAlias: trends + MiniMatchRecognize measures=[min_price := min(stock_prices.price), max_price := max(stock_prices.price), avg_price := avg(stock_prices.price)] define=[high := stock_prices.price > Float64(151)] + TableScan: stock_prices + "); + + Ok(()) +} + +/// Helper to run a single example query and display the logical plan. +async fn run_example(ctx: &SessionContext, title: &str, sql: &str) -> Result { + println!("{title}:\n{sql}\n"); + let plan = ctx.sql(sql).await?.into_unoptimized_plan(); + let plan_str = plan.display_indent().to_string(); + println!("{plan_str}\n"); + Ok(plan_str) +} + +/// Register test data tables. 
+fn register_sample_data(ctx: &SessionContext) -> Result<()> { + // events: simple price series + ctx.register_batch( + "events", + RecordBatch::try_from_iter(vec![( + "price", + Arc::new(Int32Array::from(vec![5, 12, 8, 15, 20])) as ArrayRef, + )])?, + )?; + + // stock_prices: realistic stock data + ctx.register_batch( + "stock_prices", + RecordBatch::try_from_iter(vec![ + ( + "symbol", + Arc::new(StringArray::from(vec!["DDOG", "DDOG", "DDOG", "DDOG"])) + as ArrayRef, + ), + ( + "price", + Arc::new(Float64Array::from(vec![150.0, 155.0, 152.0, 158.0])), + ), + ])?, + )?; + + Ok(()) +} + +// ============================================================================ +// Logical Plan Node: MiniMatchRecognizeNode +// ============================================================================ + +/// A custom logical plan node representing MATCH_RECOGNIZE operations. +/// +/// This is a simplified implementation that captures the essential structure: +/// - `measures`: Aggregate expressions computed over matched rows +/// - `definitions`: Symbol definitions (predicate expressions) +/// +/// A production implementation would also include: +/// - Pattern specification (regex-like pattern) +/// - Partition and order by clauses +/// - Output mode (ONE ROW PER MATCH, ALL ROWS PER MATCH) +/// - After match skip strategy +#[derive(Debug)] +struct MiniMatchRecognizeNode { + input: Arc, + schema: DFSchemaRef, + /// Measures: (alias, aggregate_expr) + measures: Vec<(String, Expr)>, + /// Symbol definitions: (symbol_name, predicate_expr) + definitions: Vec<(String, Expr)>, +} + +impl UserDefinedLogicalNode for MiniMatchRecognizeNode { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "MiniMatchRecognize" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn check_invariants(&self, _check: InvariantLevel) -> Result<()> { + Ok(()) + } + + fn expressions(&self) -> Vec { + self.measures + .iter() + .chain(&self.definitions) + .map(|(_, expr)| expr.clone()) + .collect() + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MiniMatchRecognize")?; + + if !self.measures.is_empty() { + write!(f, " measures=[")?; + for (i, (alias, expr)) in self.measures.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{alias} := {expr}")?; + } + write!(f, "]")?; + } + + if !self.definitions.is_empty() { + write!(f, " define=[")?; + for (i, (symbol, expr)) in self.definitions.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{symbol} := {expr}")?; + } + write!(f, "]")?; + } + + Ok(()) + } + + fn with_exprs_and_inputs( + &self, + exprs: Vec, + inputs: Vec, + ) -> Result> { + let expected_len = self.measures.len() + self.definitions.len(); + if exprs.len() != expected_len { + return Err(datafusion_common::plan_datafusion_err!( + "MiniMatchRecognize: expected {expected_len} expressions, got {}", + exprs.len() + )); + } + + let input = inputs.into_iter().next().ok_or_else(|| { + datafusion_common::plan_datafusion_err!( + "MiniMatchRecognize requires exactly one input" + ) + })?; + + let (measure_exprs, definition_exprs) = exprs.split_at(self.measures.len()); + + let measures = self + .measures + .iter() + .zip(measure_exprs) + .map(|((alias, _), expr)| (alias.clone(), expr.clone())) + .collect(); + + let definitions = self + .definitions + .iter() + .zip(definition_exprs) + .map(|((symbol, _), expr)| (symbol.clone(), expr.clone())) + 
.collect(); + + Ok(Arc::new(Self { + input: Arc::new(input), + schema: Arc::clone(&self.schema), + measures, + definitions, + })) + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + state.write_usize(Arc::as_ptr(&self.input) as usize); + state.write_usize(self.measures.len()); + state.write_usize(self.definitions.len()); + } + + fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { + other.as_any().downcast_ref::().is_some_and(|o| { + Arc::ptr_eq(&self.input, &o.input) + && self.measures == o.measures + && self.definitions == o.definitions + }) + } + + fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option { + if self.dyn_eq(other) { + Some(Ordering::Equal) + } else { + None + } + } +} + +// ============================================================================ +// Relation Planner: MatchRecognizePlanner +// ============================================================================ + +/// Relation planner that creates `MiniMatchRecognizeNode` for MATCH_RECOGNIZE queries. +#[derive(Debug)] +struct MatchRecognizePlanner; + +impl RelationPlanner for MatchRecognizePlanner { + fn plan_relation( + &self, + relation: TableFactor, + ctx: &mut dyn RelationPlannerContext, + ) -> Result { + let TableFactor::MatchRecognize { + table, + measures, + symbols, + alias, + .. + } = relation + else { + return Ok(RelationPlanning::Original(relation)); + }; + + // Plan the input table + let input = ctx.plan(*table)?; + let schema = input.schema().clone(); + + // Convert MEASURES: SQL expressions → DataFusion expressions + let planned_measures: Vec<(String, Expr)> = measures + .iter() + .map(|m| { + let alias = ctx.normalize_ident(m.alias.clone()); + let expr = ctx.sql_to_expr(m.expr.clone(), schema.as_ref())?; + Ok((alias, expr)) + }) + .collect::>()?; + + // Convert DEFINE: symbol definitions → DataFusion expressions + let planned_definitions: Vec<(String, Expr)> = symbols + .iter() + .map(|s| { + let name = ctx.normalize_ident(s.symbol.clone()); + let expr = ctx.sql_to_expr(s.definition.clone(), schema.as_ref())?; + Ok((name, expr)) + }) + .collect::>()?; + + // Create the custom node + let node = MiniMatchRecognizeNode { + input: Arc::new(input), + schema, + measures: planned_measures, + definitions: planned_definitions, + }; + + let plan = LogicalPlan::Extension(Extension { + node: Arc::new(node), + }); + + Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + } +} diff --git a/datafusion-examples/examples/relation_planner/pivot_unpivot.rs b/datafusion-examples/examples/relation_planner/pivot_unpivot.rs new file mode 100644 index 0000000000000..86a6cb955500e --- /dev/null +++ b/datafusion-examples/examples/relation_planner/pivot_unpivot.rs @@ -0,0 +1,567 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
# PIVOT and UNPIVOT Example +//! +//! This example demonstrates implementing SQL `PIVOT` and `UNPIVOT` operations +//! using a custom [`RelationPlanner`]. Unlike the other examples that create +//! custom logical/physical nodes, this example shows how to **rewrite** SQL +//! constructs into equivalent standard SQL operations: +//! +//! ## Supported Syntax +//! +//! ```sql +//! -- PIVOT: Transform rows into columns +//! SELECT * FROM sales +//! PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2', 'Q3', 'Q4')) +//! +//! -- UNPIVOT: Transform columns into rows +//! SELECT * FROM wide_table +//! UNPIVOT (value FOR name IN (col1, col2, col3)) +//! ``` +//! +//! ## Rewrite Strategy +//! +//! **PIVOT** is rewritten to `GROUP BY` with `CASE` expressions: +//! ```sql +//! -- Original: +//! SELECT * FROM sales PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2')) +//! +//! -- Rewritten to: +//! SELECT region, +//! SUM(CASE quarter WHEN 'Q1' THEN amount END) AS Q1, +//! SUM(CASE quarter WHEN 'Q2' THEN amount END) AS Q2 +//! FROM sales +//! GROUP BY region +//! ``` +//! +//! **UNPIVOT** is rewritten to `UNION ALL` of projections: +//! ```sql +//! -- Original: +//! SELECT * FROM wide UNPIVOT (sales FOR quarter IN (q1, q2)) +//! +//! -- Rewritten to: +//! SELECT region, 'q1' AS quarter, q1 AS sales FROM wide +//! UNION ALL +//! SELECT region, 'q2' AS quarter, q2 AS sales FROM wide +//! ``` + +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int64Array, StringArray}; +use arrow::record_batch::RecordBatch; +use datafusion::prelude::*; +use datafusion_common::{Result, ScalarValue, plan_datafusion_err}; +use datafusion_expr::{ + Expr, case, col, lit, + logical_plan::builder::LogicalPlanBuilder, + planner::{ + PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning, + }, +}; +use datafusion_sql::sqlparser::ast::{NullInclusion, PivotValueSource, TableFactor}; +use insta::assert_snapshot; + +// ============================================================================ +// Example Entry Point +// ============================================================================ + +/// Runs the PIVOT/UNPIVOT examples demonstrating data reshaping operations. 
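///
/// A minimal wiring sketch, mirroring the setup performed below (the
/// `quarterly_sales` table is the one created in `register_sample_data`):
///
/// ```rust,ignore
/// use std::sync::Arc;
/// use datafusion::prelude::*;
/// use datafusion_common::Result;
///
/// async fn demo() -> Result<()> {
///     let ctx = SessionContext::new();
///     // The planner must be registered before the SQL is planned.
///     ctx.register_relation_planner(Arc::new(PivotUnpivotPlanner))?;
///     register_sample_data(&ctx)?;
///     let df = ctx
///         .sql(
///             "SELECT * FROM quarterly_sales \
///              PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2')) AS p",
///         )
///         .await?;
///     df.show().await?;
///     Ok(())
/// }
/// ```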
+pub async fn pivot_unpivot() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_relation_planner(Arc::new(PivotUnpivotPlanner))?; + register_sample_data(&ctx)?; + + println!("PIVOT and UNPIVOT Example"); + println!("=========================\n"); + + run_examples(&ctx).await +} + +async fn run_examples(ctx: &SessionContext) -> Result<()> { + // ----- PIVOT Examples ----- + + // Example 1: Basic PIVOT + // Transforms: (region, quarter, amount) → (region, Q1, Q2) + let results = run_example( + ctx, + "Example 1: Basic PIVOT", + r#"SELECT * FROM quarterly_sales + PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2')) AS p + ORDER BY region"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+------+------+ + | region | Q1 | Q2 | + +--------+------+------+ + | North | 1000 | 1500 | + | South | 1200 | 1300 | + +--------+------+------+ + "); + + // Example 2: PIVOT with multiple aggregates + // Creates columns for each (aggregate, value) combination + let results = run_example( + ctx, + "Example 2: PIVOT with multiple aggregates", + r#"SELECT * FROM quarterly_sales + PIVOT (SUM(amount), AVG(amount) FOR quarter IN ('Q1', 'Q2')) AS p + ORDER BY region"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+--------+--------+--------+--------+ + | region | sum_Q1 | sum_Q2 | avg_Q1 | avg_Q2 | + +--------+--------+--------+--------+--------+ + | North | 1000 | 1500 | 1000.0 | 1500.0 | + | South | 1200 | 1300 | 1200.0 | 1300.0 | + +--------+--------+--------+--------+--------+ + "); + + // Example 3: PIVOT with multiple grouping columns + // Non-pivot, non-aggregate columns become GROUP BY columns + let results = run_example( + ctx, + "Example 3: PIVOT with multiple grouping columns", + r#"SELECT * FROM product_sales + PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2')) AS p + ORDER BY region, product"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+----------+-----+-----+ + | region | product | Q1 | Q2 | + +--------+----------+-----+-----+ + | North | ProductA | 500 | | + | North | ProductB | 500 | | + | South | ProductA | | 650 | + +--------+----------+-----+-----+ + "); + + // ----- UNPIVOT Examples ----- + + // Example 4: Basic UNPIVOT + // Transforms: (region, q1, q2) → (region, quarter, sales) + let results = run_example( + ctx, + "Example 4: Basic UNPIVOT", + r#"SELECT * FROM wide_sales + UNPIVOT (sales FOR quarter IN (q1 AS 'Q1', q2 AS 'Q2')) AS u + ORDER BY quarter, region"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+---------+-------+ + | region | quarter | sales | + +--------+---------+-------+ + | North | Q1 | 1000 | + | South | Q1 | 1200 | + | North | Q2 | 1500 | + | South | Q2 | 1300 | + +--------+---------+-------+ + "); + + // Example 5: UNPIVOT with INCLUDE NULLS + // By default, UNPIVOT excludes rows where the value column is NULL. + // INCLUDE NULLS keeps them (same result here since no NULLs in data). 
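    // Under the rewrite used by this example, EXCLUDE NULLS (the default) is
    // implemented as an `IS NOT NULL` filter on the value column inside
    // `plan_unpivot`; INCLUDE NULLS simply skips that filter.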
+ let results = run_example( + ctx, + "Example 5: UNPIVOT INCLUDE NULLS", + r#"SELECT * FROM wide_sales + UNPIVOT INCLUDE NULLS (sales FOR quarter IN (q1 AS 'Q1', q2 AS 'Q2')) AS u + ORDER BY quarter, region"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+---------+-------+ + | region | quarter | sales | + +--------+---------+-------+ + | North | Q1 | 1000 | + | South | Q1 | 1200 | + | North | Q2 | 1500 | + | South | Q2 | 1300 | + +--------+---------+-------+ + "); + + // Example 6: PIVOT with column projection + // Standard SQL operations work seamlessly after PIVOT + let results = run_example( + ctx, + "Example 6: PIVOT with projection", + r#"SELECT region FROM quarterly_sales + PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2')) AS p + ORDER BY region"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+ + | region | + +--------+ + | North | + | South | + +--------+ + "); + + Ok(()) +} + +/// Helper to run a single example query and capture results. +async fn run_example(ctx: &SessionContext, title: &str, sql: &str) -> Result { + println!("{title}:\n{sql}\n"); + let df = ctx.sql(sql).await?; + println!("{}\n", df.logical_plan().display_indent()); + + let batches = df.collect().await?; + let results = arrow::util::pretty::pretty_format_batches(&batches)?.to_string(); + println!("{results}\n"); + + Ok(results) +} + +/// Register test data tables. +fn register_sample_data(ctx: &SessionContext) -> Result<()> { + // quarterly_sales: normalized sales data (region, quarter, amount) + ctx.register_batch( + "quarterly_sales", + RecordBatch::try_from_iter(vec![ + ( + "region", + Arc::new(StringArray::from(vec!["North", "North", "South", "South"])) + as ArrayRef, + ), + ( + "quarter", + Arc::new(StringArray::from(vec!["Q1", "Q2", "Q1", "Q2"])), + ), + ( + "amount", + Arc::new(Int64Array::from(vec![1000, 1500, 1200, 1300])), + ), + ])?, + )?; + + // product_sales: sales with additional grouping dimension + ctx.register_batch( + "product_sales", + RecordBatch::try_from_iter(vec![ + ( + "region", + Arc::new(StringArray::from(vec!["North", "North", "South"])) as ArrayRef, + ), + ( + "quarter", + Arc::new(StringArray::from(vec!["Q1", "Q1", "Q2"])), + ), + ( + "product", + Arc::new(StringArray::from(vec!["ProductA", "ProductB", "ProductA"])), + ), + ("amount", Arc::new(Int64Array::from(vec![500, 500, 650]))), + ])?, + )?; + + // wide_sales: denormalized/wide format (for UNPIVOT) + ctx.register_batch( + "wide_sales", + RecordBatch::try_from_iter(vec![ + ( + "region", + Arc::new(StringArray::from(vec!["North", "South"])) as ArrayRef, + ), + ("q1", Arc::new(Int64Array::from(vec![1000, 1200]))), + ("q2", Arc::new(Int64Array::from(vec![1500, 1300]))), + ])?, + )?; + + Ok(()) +} + +// ============================================================================ +// Relation Planner: PivotUnpivotPlanner +// ============================================================================ + +/// Relation planner that rewrites PIVOT and UNPIVOT into standard SQL. +#[derive(Debug)] +struct PivotUnpivotPlanner; + +impl RelationPlanner for PivotUnpivotPlanner { + fn plan_relation( + &self, + relation: TableFactor, + ctx: &mut dyn RelationPlannerContext, + ) -> Result { + match relation { + TableFactor::Pivot { + table, + aggregate_functions, + value_column, + value_source, + alias, + .. 
+ } => plan_pivot( + ctx, + *table, + &aggregate_functions, + &value_column, + value_source, + alias, + ), + + TableFactor::Unpivot { + table, + value, + name, + columns, + null_inclusion, + alias, + } => plan_unpivot( + ctx, + *table, + &value, + name, + &columns, + null_inclusion.as_ref(), + alias, + ), + + other => Ok(RelationPlanning::Original(other)), + } + } +} + +// ============================================================================ +// PIVOT Implementation +// ============================================================================ + +/// Rewrite PIVOT to GROUP BY with CASE expressions. +fn plan_pivot( + ctx: &mut dyn RelationPlannerContext, + table: TableFactor, + aggregate_functions: &[datafusion_sql::sqlparser::ast::ExprWithAlias], + value_column: &[datafusion_sql::sqlparser::ast::Expr], + value_source: PivotValueSource, + alias: Option, +) -> Result { + // Plan the input table + let input = ctx.plan(table)?; + let schema = input.schema(); + + // Parse aggregate functions + let aggregates: Vec = aggregate_functions + .iter() + .map(|agg| ctx.sql_to_expr(agg.expr.clone(), schema.as_ref())) + .collect::>()?; + + // Get the pivot column (only single-column pivot supported) + if value_column.len() != 1 { + return Err(plan_datafusion_err!( + "Only single-column PIVOT is supported" + )); + } + let pivot_col = ctx.sql_to_expr(value_column[0].clone(), schema.as_ref())?; + let pivot_col_name = extract_column_name(&pivot_col)?; + + // Parse pivot values + let pivot_values = match value_source { + PivotValueSource::List(list) => list + .iter() + .map(|item| { + let alias = item + .alias + .as_ref() + .map(|id| ctx.normalize_ident(id.clone())); + let expr = ctx.sql_to_expr(item.expr.clone(), schema.as_ref())?; + Ok((alias, expr)) + }) + .collect::>>()?, + _ => { + return Err(plan_datafusion_err!( + "Dynamic PIVOT (ANY/Subquery) is not supported" + )); + } + }; + + // Determine GROUP BY columns (non-pivot, non-aggregate columns) + let agg_input_cols: Vec<&str> = aggregates + .iter() + .filter_map(|agg| { + if let Expr::AggregateFunction(f) = agg { + f.params.args.first().and_then(|e| { + if let Expr::Column(c) = e { + Some(c.name.as_str()) + } else { + None + } + }) + } else { + None + } + }) + .collect(); + + let group_by_cols: Vec = schema + .fields() + .iter() + .map(|f| f.name().as_str()) + .filter(|name| *name != pivot_col_name.as_str() && !agg_input_cols.contains(name)) + .map(col) + .collect(); + + // Build CASE expressions for each (aggregate, pivot_value) pair + let mut pivot_exprs = Vec::new(); + for agg in &aggregates { + let Expr::AggregateFunction(agg_fn) = agg else { + continue; + }; + let Some(agg_input) = agg_fn.params.args.first().cloned() else { + continue; + }; + + for (value_alias, pivot_value) in &pivot_values { + // CASE pivot_col WHEN pivot_value THEN agg_input END + let case_expr = case(col(&pivot_col_name)) + .when(pivot_value.clone(), agg_input.clone()) + .end()?; + + // Wrap in aggregate function + let pivoted = agg_fn.func.call(vec![case_expr]); + + // Determine column alias + let value_str = value_alias + .clone() + .unwrap_or_else(|| expr_to_string(pivot_value)); + let col_alias = if aggregates.len() > 1 { + format!("{}_{}", agg_fn.func.name(), value_str) + } else { + value_str + }; + + pivot_exprs.push(pivoted.alias(col_alias)); + } + } + + let plan = LogicalPlanBuilder::from(input) + .aggregate(group_by_cols, pivot_exprs)? 
        .build()?;
+
+    Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias)))
+}
+
+// ============================================================================
+// UNPIVOT Implementation
+// ============================================================================
+
+/// Rewrite UNPIVOT to UNION ALL of projections.
+fn plan_unpivot(
+    ctx: &mut dyn RelationPlannerContext,
+    table: TableFactor,
+    value: &datafusion_sql::sqlparser::ast::Expr,
+    name: datafusion_sql::sqlparser::ast::Ident,
+    columns: &[datafusion_sql::sqlparser::ast::ExprWithAlias],
+    null_inclusion: Option<&NullInclusion>,
+    alias: Option<datafusion_sql::sqlparser::ast::TableAlias>,
+) -> Result<RelationPlanning> {
+    // Plan the input table
+    let input = ctx.plan(table)?;
+    let schema = input.schema();
+
+    // Output column names
+    let value_col_name = value.to_string();
+    let name_col_name = ctx.normalize_ident(name);
+
+    // Parse columns to unpivot: (source_column, label)
+    let unpivot_cols: Vec<(String, String)> = columns
+        .iter()
+        .map(|c| {
+            let label = c
+                .alias
+                .as_ref()
+                .map(|id| ctx.normalize_ident(id.clone()))
+                .unwrap_or_else(|| c.expr.to_string());
+            let expr = ctx.sql_to_expr(c.expr.clone(), schema.as_ref())?;
+            let col_name = extract_column_name(&expr)?;
+            Ok((col_name.to_string(), label))
+        })
+        .collect::<Result<_>>()?;
+
+    // Columns to preserve (not being unpivoted)
+    let keep_cols: Vec<&str> = schema
+        .fields()
+        .iter()
+        .map(|f| f.name().as_str())
+        .filter(|name| !unpivot_cols.iter().any(|(c, _)| c == *name))
+        .collect();
+
+    // Build UNION ALL: one SELECT per unpivot column
+    if unpivot_cols.is_empty() {
+        return Err(plan_datafusion_err!("UNPIVOT requires at least one column"));
+    }
+
+    let mut union_inputs: Vec<_> = unpivot_cols
+        .iter()
+        .map(|(col_name, label)| {
+            let mut projection: Vec<Expr> = keep_cols.iter().map(|c| col(*c)).collect();
+            projection.push(lit(label.clone()).alias(&name_col_name));
+            projection.push(col(col_name).alias(&value_col_name));
+
+            LogicalPlanBuilder::from(input.clone())
+                .project(projection)?
+                .build()
+        })
+        .collect::<Result<_>>()?;
+
+    // Combine with UNION ALL
+    let mut plan = union_inputs.remove(0);
+    for branch in union_inputs {
+        plan = LogicalPlanBuilder::from(plan).union(branch)?.build()?;
+    }
+
+    // Apply EXCLUDE NULLS filter (default behavior)
+    let exclude_nulls = null_inclusion.is_none()
+        || matches!(null_inclusion, Some(&NullInclusion::ExcludeNulls));
+    if exclude_nulls {
+        plan = LogicalPlanBuilder::from(plan)
+            .filter(col(&value_col_name).is_not_null())?
+            .build()?;
+    }
+
+    Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias)))
+}
+
+// ============================================================================
+// Helpers
+// ============================================================================
+
+/// Extract column name from an expression.
+fn extract_column_name(expr: &Expr) -> Result<String> {
+    match expr {
+        Expr::Column(c) => Ok(c.name.clone()),
+        _ => Err(plan_datafusion_err!(
+            "Expected column reference, got {expr}"
+        )),
+    }
+}
+
+/// Convert an expression to a string for use as column alias.
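///
/// A small sketch of the intended behavior (`lit` is the literal helper
/// already imported in this file):
///
/// ```rust,ignore
/// use datafusion_expr::lit;
///
/// // A quoted pivot value such as 'Q1' becomes the bare alias "Q1",
/// assert_eq!(expr_to_string(&lit("Q1")), "Q1");
/// // while other literals fall back to their display form.
/// assert_eq!(expr_to_string(&lit(4_i64)), "4");
/// ```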
+fn expr_to_string(expr: &Expr) -> String { + match expr { + Expr::Literal(ScalarValue::Utf8(Some(s)), _) => s.clone(), + Expr::Literal(v, _) => v.to_string(), + other => other.to_string(), + } +} diff --git a/datafusion-examples/examples/relation_planner/table_sample.rs b/datafusion-examples/examples/relation_planner/table_sample.rs new file mode 100644 index 0000000000000..207fffe1327a3 --- /dev/null +++ b/datafusion-examples/examples/relation_planner/table_sample.rs @@ -0,0 +1,845 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # TABLESAMPLE Example +//! +//! This example demonstrates implementing SQL `TABLESAMPLE` support using +//! DataFusion's extensibility APIs. +//! +//! This is a working `TABLESAMPLE` implementation that can serve as a starting +//! point for your own projects. It also works as a template for adding other +//! custom SQL operators, covering the full pipeline from parsing to execution. +//! +//! It shows how to: +//! +//! 1. **Parse** TABLESAMPLE syntax via a custom [`RelationPlanner`] +//! 2. **Plan** sampling as a custom logical node ([`TableSamplePlanNode`]) +//! 3. **Execute** sampling via a custom physical operator ([`SampleExec`]) +//! +//! ## Supported Syntax +//! +//! ```sql +//! -- Bernoulli sampling (each row has N% chance of selection) +//! SELECT * FROM table TABLESAMPLE BERNOULLI(10 PERCENT) +//! +//! -- Fractional sampling (0.0 to 1.0) +//! SELECT * FROM table TABLESAMPLE (0.1) +//! +//! -- Row count limit +//! SELECT * FROM table TABLESAMPLE (100 ROWS) +//! +//! -- Reproducible sampling with a seed +//! SELECT * FROM table TABLESAMPLE (10 PERCENT) REPEATABLE(42) +//! ``` +//! +//! ## Architecture +//! +//! ```text +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ SQL Query │ +//! │ SELECT * FROM t TABLESAMPLE BERNOULLI(10 PERCENT) REPEATABLE(1)│ +//! └─────────────────────────────────────────────────────────────────┘ +//! │ +//! ▼ +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ TableSamplePlanner │ +//! │ (RelationPlanner: parses TABLESAMPLE, creates logical node) │ +//! └─────────────────────────────────────────────────────────────────┘ +//! │ +//! ▼ +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ TableSamplePlanNode │ +//! │ (UserDefinedLogicalNode: stores sampling params) │ +//! └─────────────────────────────────────────────────────────────────┘ +//! │ +//! ▼ +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ TableSampleExtensionPlanner │ +//! │ (ExtensionPlanner: creates physical execution plan) │ +//! └─────────────────────────────────────────────────────────────────┘ +//! │ +//! ▼ +//! ┌─────────────────────────────────────────────────────────────────┐ +//! 
│ SampleExec │ +//! │ (ExecutionPlan: performs actual row sampling at runtime) │ +//! └─────────────────────────────────────────────────────────────────┘ +//! ``` + +use std::{ + any::Any, + fmt::{self, Debug, Formatter}, + hash::{Hash, Hasher}, + ops::{Add, Div, Mul, Sub}, + pin::Pin, + str::FromStr, + sync::Arc, + task::{Context, Poll}, +}; + +use arrow::{ + array::{ArrayRef, Int32Array, RecordBatch, StringArray, UInt32Array}, + compute, +}; +use arrow_schema::SchemaRef; +use futures::{ + ready, + stream::{Stream, StreamExt}, +}; +use rand::{Rng, SeedableRng, rngs::StdRng}; +use tonic::async_trait; + +use datafusion::{ + execution::{ + RecordBatchStream, SendableRecordBatchStream, SessionState, SessionStateBuilder, + TaskContext, context::QueryPlanner, + }, + physical_expr::EquivalenceProperties, + physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput}, + }, + physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}, + prelude::*, +}; +use datafusion_common::{ + DFSchemaRef, DataFusionError, Result, Statistics, internal_err, not_impl_err, + plan_datafusion_err, plan_err, +}; +use datafusion_expr::{ + UserDefinedLogicalNode, UserDefinedLogicalNodeCore, + logical_plan::{Extension, LogicalPlan, LogicalPlanBuilder}, + planner::{ + PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning, + }, +}; +use datafusion_sql::sqlparser::ast::{ + self, TableFactor, TableSampleMethod, TableSampleUnit, +}; +use insta::assert_snapshot; + +// ============================================================================ +// Example Entry Point +// ============================================================================ + +/// Runs the TABLESAMPLE examples demonstrating various sampling techniques. 
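///
/// A minimal wiring sketch of the setup performed below. Both halves are
/// needed: the relation planner produces the logical node, the query planner
/// turns it into `SampleExec`. Assumes `sample_data` is registered as in
/// `register_sample_data`:
///
/// ```rust,ignore
/// use std::sync::Arc;
/// use datafusion::execution::SessionStateBuilder;
/// use datafusion::prelude::*;
/// use datafusion_common::Result;
///
/// async fn demo() -> Result<()> {
///     let state = SessionStateBuilder::new()
///         .with_default_features()
///         // physical planning: TableSamplePlanNode -> SampleExec
///         .with_query_planner(Arc::new(TableSampleQueryPlanner))
///         .build();
///     let ctx = SessionContext::new_with_state(state);
///     // logical planning: TABLESAMPLE syntax -> TableSamplePlanNode
///     ctx.register_relation_planner(Arc::new(TableSamplePlanner))?;
///     register_sample_data(&ctx)?;
///     ctx.sql("SELECT * FROM sample_data TABLESAMPLE (3 ROWS)")
///         .await?
///         .show()
///         .await?;
///     Ok(())
/// }
/// ```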
+pub async fn table_sample() -> Result<()> { + // Build session with custom query planner for physical planning + let state = SessionStateBuilder::new() + .with_default_features() + .with_query_planner(Arc::new(TableSampleQueryPlanner)) + .build(); + + let ctx = SessionContext::new_with_state(state); + + // Register custom relation planner for logical planning + ctx.register_relation_planner(Arc::new(TableSamplePlanner))?; + register_sample_data(&ctx)?; + + println!("TABLESAMPLE Example"); + println!("===================\n"); + + run_examples(&ctx).await +} + +async fn run_examples(ctx: &SessionContext) -> Result<()> { + // Example 1: Baseline - full table scan + let results = run_example( + ctx, + "Example 1: Full table (baseline)", + "SELECT * FROM sample_data", + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+ + | column1 | column2 | + +---------+---------+ + | 1 | row_1 | + | 2 | row_2 | + | 3 | row_3 | + | 4 | row_4 | + | 5 | row_5 | + | 6 | row_6 | + | 7 | row_7 | + | 8 | row_8 | + | 9 | row_9 | + | 10 | row_10 | + +---------+---------+ + "); + + // Example 2: Percentage-based Bernoulli sampling + // REPEATABLE(seed) ensures deterministic results for snapshot testing + let results = run_example( + ctx, + "Example 2: BERNOULLI percentage sampling", + "SELECT * FROM sample_data TABLESAMPLE BERNOULLI(30 PERCENT) REPEATABLE(123)", + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+ + | column1 | column2 | + +---------+---------+ + | 1 | row_1 | + | 2 | row_2 | + | 7 | row_7 | + | 8 | row_8 | + +---------+---------+ + "); + + // Example 3: Fractional sampling (0.0 to 1.0) + // REPEATABLE(seed) ensures deterministic results for snapshot testing + let results = run_example( + ctx, + "Example 3: Fractional sampling", + "SELECT * FROM sample_data TABLESAMPLE (0.5) REPEATABLE(456)", + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+ + | column1 | column2 | + +---------+---------+ + | 2 | row_2 | + | 4 | row_4 | + | 8 | row_8 | + +---------+---------+ + "); + + // Example 4: Row count limit (deterministic, no seed needed) + let results = run_example( + ctx, + "Example 4: Row count limit", + "SELECT * FROM sample_data TABLESAMPLE (3 ROWS)", + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+ + | column1 | column2 | + +---------+---------+ + | 1 | row_1 | + | 2 | row_2 | + | 3 | row_3 | + +---------+---------+ + "); + + // Example 5: Sampling combined with filtering + let results = run_example( + ctx, + "Example 5: Sampling with WHERE clause", + "SELECT * FROM sample_data TABLESAMPLE (5 ROWS) WHERE column1 > 2", + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+ + | column1 | column2 | + +---------+---------+ + | 3 | row_3 | + | 4 | row_4 | + | 5 | row_5 | + +---------+---------+ + "); + + // Example 6: Sampling in JOIN queries + // REPEATABLE(seed) ensures deterministic results for snapshot testing + let results = run_example( + ctx, + "Example 6: Sampling in JOINs", + r#"SELECT t1.column1, t2.column1, t1.column2, t2.column2 + FROM sample_data t1 TABLESAMPLE (0.7) REPEATABLE(789) + JOIN sample_data t2 TABLESAMPLE (0.7) REPEATABLE(123) + ON t1.column1 = t2.column1"#, + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+---------+---------+ + | column1 | column1 | column2 | column2 | + +---------+---------+---------+---------+ + | 2 | 2 | row_2 | row_2 | + | 5 | 5 | row_5 | row_5 | + | 7 | 7 | row_7 | row_7 | + | 8 | 8 | row_8 | row_8 | + | 10 | 10 | row_10 | row_10 | + 
+---------+---------+---------+---------+ + "); + + Ok(()) +} + +/// Helper to run a single example query and capture results. +async fn run_example(ctx: &SessionContext, title: &str, sql: &str) -> Result { + println!("{title}:\n{sql}\n"); + let df = ctx.sql(sql).await?; + println!("{}\n", df.logical_plan().display_indent()); + + let batches = df.collect().await?; + let results = arrow::util::pretty::pretty_format_batches(&batches)?.to_string(); + println!("{results}\n"); + + Ok(results) +} + +/// Register test data: 10 rows with column1=1..10 and column2="row_1".."row_10" +fn register_sample_data(ctx: &SessionContext) -> Result<()> { + let column1: ArrayRef = Arc::new(Int32Array::from((1..=10).collect::>())); + let column2: ArrayRef = Arc::new(StringArray::from( + (1..=10).map(|i| format!("row_{i}")).collect::>(), + )); + let batch = + RecordBatch::try_from_iter(vec![("column1", column1), ("column2", column2)])?; + ctx.register_batch("sample_data", batch)?; + Ok(()) +} + +// ============================================================================ +// Logical Planning: TableSamplePlanner + TableSamplePlanNode +// ============================================================================ + +/// Relation planner that intercepts `TABLESAMPLE` clauses in SQL and creates +/// [`TableSamplePlanNode`] logical nodes. +#[derive(Debug)] +struct TableSamplePlanner; + +impl RelationPlanner for TableSamplePlanner { + fn plan_relation( + &self, + relation: TableFactor, + context: &mut dyn RelationPlannerContext, + ) -> Result { + // Only handle Table relations with TABLESAMPLE clause + let TableFactor::Table { + sample: Some(sample), + alias, + name, + args, + with_hints, + version, + with_ordinality, + partitions, + json_path, + index_hints, + } = relation + else { + return Ok(RelationPlanning::Original(relation)); + }; + + // Extract sample spec (handles both before/after alias positions) + let sample = match sample { + ast::TableSampleKind::BeforeTableAlias(s) + | ast::TableSampleKind::AfterTableAlias(s) => s, + }; + + // Validate sampling method + if let Some(method) = &sample.name + && *method != TableSampleMethod::Bernoulli + && *method != TableSampleMethod::Row + { + return not_impl_err!( + "Sampling method {} is not supported (only BERNOULLI and ROW)", + method + ); + } + + // Offset sampling (ClickHouse-style) not supported + if sample.offset.is_some() { + return not_impl_err!( + "TABLESAMPLE with OFFSET is not supported (requires total row count)" + ); + } + + // Parse optional REPEATABLE seed + let seed = sample + .seed + .map(|s| { + s.value.to_string().parse::().map_err(|_| { + plan_datafusion_err!("REPEATABLE seed must be an integer") + }) + }) + .transpose()?; + + // Plan the underlying table without the sample clause + let base_relation = TableFactor::Table { + sample: None, + alias: alias.clone(), + name, + args, + with_hints, + version, + with_ordinality, + partitions, + json_path, + index_hints, + }; + let input = context.plan(base_relation)?; + + // Handle bucket sampling (Hive-style: TABLESAMPLE(BUCKET x OUT OF y)) + if let Some(bucket) = sample.bucket { + if bucket.on.is_some() { + return not_impl_err!( + "TABLESAMPLE BUCKET with ON clause requires CLUSTERED BY table" + ); + } + let bucket_num: u64 = + bucket.bucket.to_string().parse().map_err(|_| { + plan_datafusion_err!("bucket number must be an integer") + })?; + let total: u64 = + bucket.total.to_string().parse().map_err(|_| { + plan_datafusion_err!("bucket total must be an integer") + })?; + + let fraction = bucket_num 
as f64 / total as f64; + let plan = TableSamplePlanNode::new(input, fraction, seed).into_plan(); + return Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))); + } + + // Handle quantity-based sampling + let Some(quantity) = sample.quantity else { + return plan_err!( + "TABLESAMPLE requires a quantity (percentage, fraction, or row count)" + ); + }; + + match quantity.unit { + // TABLESAMPLE (N ROWS) - exact row limit + Some(TableSampleUnit::Rows) => { + let rows = parse_quantity::(&quantity.value)?; + if rows < 0 { + return plan_err!("row count must be non-negative, got {}", rows); + } + let plan = LogicalPlanBuilder::from(input) + .limit(0, Some(rows as usize))? + .build()?; + Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + } + + // TABLESAMPLE (N PERCENT) - percentage sampling + Some(TableSampleUnit::Percent) => { + let percent = parse_quantity::(&quantity.value)?; + let fraction = percent / 100.0; + let plan = TableSamplePlanNode::new(input, fraction, seed).into_plan(); + Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + } + + // TABLESAMPLE (N) - fraction if <1.0, row limit if >=1.0 + None => { + let value = parse_quantity::(&quantity.value)?; + if value < 0.0 { + return plan_err!("sample value must be non-negative, got {}", value); + } + let plan = if value >= 1.0 { + // Interpret as row limit + LogicalPlanBuilder::from(input) + .limit(0, Some(value as usize))? + .build()? + } else { + // Interpret as fraction + TableSamplePlanNode::new(input, value, seed).into_plan() + }; + Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + } + } + } +} + +/// Parse a SQL expression as a numeric value (supports basic arithmetic). +fn parse_quantity(expr: &ast::Expr) -> Result +where + T: FromStr + Add + Sub + Mul + Div, +{ + eval_numeric_expr(expr) + .ok_or_else(|| plan_datafusion_err!("invalid numeric expression: {:?}", expr)) +} + +/// Recursively evaluate numeric SQL expressions. +fn eval_numeric_expr(expr: &ast::Expr) -> Option +where + T: FromStr + Add + Sub + Mul + Div, +{ + match expr { + ast::Expr::Value(v) => match &v.value { + ast::Value::Number(n, _) => n.to_string().parse().ok(), + _ => None, + }, + ast::Expr::BinaryOp { left, op, right } => { + let l = eval_numeric_expr::(left)?; + let r = eval_numeric_expr::(right)?; + match op { + ast::BinaryOperator::Plus => Some(l + r), + ast::BinaryOperator::Minus => Some(l - r), + ast::BinaryOperator::Multiply => Some(l * r), + ast::BinaryOperator::Divide => Some(l / r), + _ => None, + } + } + _ => None, + } +} + +/// Custom logical plan node representing a TABLESAMPLE operation. +/// +/// Stores sampling parameters (bounds, seed) and wraps the input plan. +/// Gets converted to [`SampleExec`] during physical planning. +#[derive(Debug, Clone, Hash, Eq, PartialEq, PartialOrd)] +struct TableSamplePlanNode { + input: LogicalPlan, + lower_bound: HashableF64, + upper_bound: HashableF64, + seed: u64, +} + +impl TableSamplePlanNode { + /// Create a new sampling node with the given fraction (0.0 to 1.0). + fn new(input: LogicalPlan, fraction: f64, seed: Option) -> Self { + Self { + input, + lower_bound: HashableF64(0.0), + upper_bound: HashableF64(fraction), + seed: seed.unwrap_or_else(rand::random), + } + } + + /// Wrap this node in a LogicalPlan::Extension. 
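///
/// Usage sketch, as in the planner above (the fraction and seed are
/// illustrative values):
///
/// ```rust,ignore
/// use datafusion_expr::logical_plan::LogicalPlan;
///
/// // Wrap an already-planned input in a 10% Bernoulli sample, seeded for
/// // reproducible results.
/// fn sample_ten_percent(input: LogicalPlan) -> LogicalPlan {
///     TableSamplePlanNode::new(input, 0.10, Some(42)).into_plan()
/// }
/// ```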
+ fn into_plan(self) -> LogicalPlan { + LogicalPlan::Extension(Extension { + node: Arc::new(self), + }) + } +} + +impl UserDefinedLogicalNodeCore for TableSamplePlanNode { + fn name(&self) -> &str { + "TableSample" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> fmt::Result { + write!( + f, + "Sample: bounds=[{}, {}], seed={}", + self.lower_bound.0, self.upper_bound.0, self.seed + ) + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + mut inputs: Vec, + ) -> Result { + Ok(Self { + input: inputs.swap_remove(0), + lower_bound: self.lower_bound, + upper_bound: self.upper_bound, + seed: self.seed, + }) + } +} + +/// Wrapper for f64 that implements Hash and Eq (required for LogicalPlan). +#[derive(Debug, Clone, Copy, PartialOrd)] +struct HashableF64(f64); + +impl PartialEq for HashableF64 { + fn eq(&self, other: &Self) -> bool { + self.0.to_bits() == other.0.to_bits() + } +} + +impl Eq for HashableF64 {} + +impl Hash for HashableF64 { + fn hash(&self, state: &mut H) { + self.0.to_bits().hash(state); + } +} + +// ============================================================================ +// Physical Planning: TableSampleQueryPlanner + TableSampleExtensionPlanner +// ============================================================================ + +/// Custom query planner that registers [`TableSampleExtensionPlanner`] to +/// convert [`TableSamplePlanNode`] into [`SampleExec`]. +#[derive(Debug)] +struct TableSampleQueryPlanner; + +#[async_trait] +impl QueryPlanner for TableSampleQueryPlanner { + async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + session_state: &SessionState, + ) -> Result> { + let planner = DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new( + TableSampleExtensionPlanner, + )]); + planner + .create_physical_plan(logical_plan, session_state) + .await + } +} + +/// Extension planner that converts [`TableSamplePlanNode`] to [`SampleExec`]. +struct TableSampleExtensionPlanner; + +#[async_trait] +impl ExtensionPlanner for TableSampleExtensionPlanner { + async fn plan_extension( + &self, + _planner: &dyn PhysicalPlanner, + node: &dyn UserDefinedLogicalNode, + _logical_inputs: &[&LogicalPlan], + physical_inputs: &[Arc], + _session_state: &SessionState, + ) -> Result>> { + let Some(sample_node) = node.as_any().downcast_ref::() + else { + return Ok(None); + }; + + let exec = SampleExec::try_new( + Arc::clone(&physical_inputs[0]), + sample_node.lower_bound.0, + sample_node.upper_bound.0, + sample_node.seed, + )?; + Ok(Some(Arc::new(exec))) + } +} + +// ============================================================================ +// Physical Execution: SampleExec + BernoulliSampler +// ============================================================================ + +/// Physical execution plan that samples rows from its input using Bernoulli sampling. +/// +/// Each row is independently selected with probability `(upper_bound - lower_bound)` +/// and appears at most once. +#[derive(Debug, Clone)] +pub struct SampleExec { + input: Arc, + lower_bound: f64, + upper_bound: f64, + seed: u64, + metrics: ExecutionPlanMetricsSet, + cache: PlanProperties, +} + +impl SampleExec { + /// Create a new SampleExec with Bernoulli sampling (without replacement). 
+ /// + /// # Arguments + /// * `input` - The input execution plan + /// * `lower_bound` - Lower bound of sampling range (typically 0.0) + /// * `upper_bound` - Upper bound of sampling range (0.0 to 1.0) + /// * `seed` - Random seed for reproducible sampling + pub fn try_new( + input: Arc, + lower_bound: f64, + upper_bound: f64, + seed: u64, + ) -> Result { + if lower_bound < 0.0 || upper_bound > 1.0 || lower_bound > upper_bound { + return internal_err!( + "Sampling bounds must satisfy 0.0 <= lower <= upper <= 1.0, got [{}, {}]", + lower_bound, + upper_bound + ); + } + + let cache = PlanProperties::new( + EquivalenceProperties::new(input.schema()), + input.properties().partitioning.clone(), + input.properties().emission_type, + input.properties().boundedness, + ); + + Ok(Self { + input, + lower_bound, + upper_bound, + seed, + metrics: ExecutionPlanMetricsSet::new(), + cache, + }) + } + + /// Create a sampler for the given partition. + fn create_sampler(&self, partition: usize) -> BernoulliSampler { + let seed = self.seed.wrapping_add(partition as u64); + BernoulliSampler::new(self.lower_bound, self.upper_bound, seed) + } +} + +impl DisplayAs for SampleExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { + write!( + f, + "SampleExec: bounds=[{}, {}], seed={}", + self.lower_bound, self.upper_bound, self.seed + ) + } +} + +impl ExecutionPlan for SampleExec { + fn name(&self) -> &'static str { + "SampleExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn maintains_input_order(&self) -> Vec { + // Sampling preserves row order (rows are filtered, not reordered) + vec![true] + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + Ok(Arc::new(Self::try_new( + children.swap_remove(0), + self.lower_bound, + self.upper_bound, + self.seed, + )?)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + Ok(Box::pin(SampleStream { + input: self.input.execute(partition, context)?, + sampler: self.create_sampler(partition), + metrics: BaselineMetrics::new(&self.metrics, partition), + })) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn partition_statistics(&self, partition: Option) -> Result { + let mut stats = self.input.partition_statistics(partition)?; + let ratio = self.upper_bound - self.lower_bound; + + // Scale statistics by sampling ratio (inexact due to randomness) + stats.num_rows = stats + .num_rows + .map(|n| (n as f64 * ratio) as usize) + .to_inexact(); + stats.total_byte_size = stats + .total_byte_size + .map(|n| (n as f64 * ratio) as usize) + .to_inexact(); + + Ok(stats) + } +} + +/// Bernoulli sampler: includes each row with probability `(upper - lower)`. +/// This is sampling **without replacement** - each row appears at most once. 
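///
/// A minimal standalone sketch (the 0.5 upper bound and seed 7 are
/// illustrative; the single-column batch mirrors `register_sample_data`):
///
/// ```rust,ignore
/// use std::sync::Arc;
/// use arrow::array::{ArrayRef, Int32Array, RecordBatch};
/// use datafusion_common::Result;
///
/// fn demo() -> Result<()> {
///     let batch = RecordBatch::try_from_iter(vec![(
///         "v",
///         Arc::new(Int32Array::from((0..100).collect::<Vec<i32>>())) as ArrayRef,
///     )])?;
///     // Keep each row independently with probability ~0.5, reproducibly.
///     let mut sampler = BernoulliSampler::new(0.0, 0.5, 7);
///     let sampled = sampler.sample(&batch)?;
///     assert!(sampled.num_rows() <= batch.num_rows());
///     Ok(())
/// }
/// ```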
+struct BernoulliSampler { + lower_bound: f64, + upper_bound: f64, + rng: StdRng, +} + +impl BernoulliSampler { + fn new(lower_bound: f64, upper_bound: f64, seed: u64) -> Self { + Self { + lower_bound, + upper_bound, + rng: StdRng::seed_from_u64(seed), + } + } + + fn sample(&mut self, batch: &RecordBatch) -> Result { + let range = self.upper_bound - self.lower_bound; + if range <= 0.0 { + return Ok(RecordBatch::new_empty(batch.schema())); + } + + // Select rows where random value falls in [lower, upper) + let indices: Vec = (0..batch.num_rows()) + .filter(|_| { + let r: f64 = self.rng.random(); + r >= self.lower_bound && r < self.upper_bound + }) + .map(|i| i as u32) + .collect(); + + if indices.is_empty() { + return Ok(RecordBatch::new_empty(batch.schema())); + } + + compute::take_record_batch(batch, &UInt32Array::from(indices)) + .map_err(DataFusionError::from) + } +} + +/// Stream adapter that applies sampling to each batch. +struct SampleStream { + input: SendableRecordBatchStream, + sampler: BernoulliSampler, + metrics: BaselineMetrics, +} + +impl Stream for SampleStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match ready!(self.input.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + let elapsed = self.metrics.elapsed_compute().clone(); + let _timer = elapsed.timer(); + let result = self.sampler.sample(&batch); + Poll::Ready(Some(result.record_output(&self.metrics))) + } + Some(Err(e)) => Poll::Ready(Some(Err(e))), + None => Poll::Ready(None), + } + } +} + +impl RecordBatchStream for SampleStream { + fn schema(&self) -> SchemaRef { + self.input.schema() + } +} diff --git a/datafusion-examples/examples/sql_dialect.rs b/datafusion-examples/examples/sql_dialect.rs deleted file mode 100644 index 20b515506f3b4..0000000000000 --- a/datafusion-examples/examples/sql_dialect.rs +++ /dev/null @@ -1,134 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::fmt::Display; - -use datafusion::error::{DataFusionError, Result}; -use datafusion::sql::{ - parser::{CopyToSource, CopyToStatement, DFParser, DFParserBuilder, Statement}, - sqlparser::{keywords::Keyword, tokenizer::Token}, -}; - -/// This example demonstrates how to use the DFParser to parse a statement in a custom way -/// -/// This technique can be used to implement a custom SQL dialect, for example. 
-#[tokio::main] -async fn main() -> Result<()> { - let mut my_parser = - MyParser::new("COPY source_table TO 'file.fasta' STORED AS FASTA")?; - - let my_statement = my_parser.parse_statement()?; - - match my_statement { - MyStatement::DFStatement(s) => println!("df: {s}"), - MyStatement::MyCopyTo(s) => println!("my_copy: {s}"), - } - - Ok(()) -} - -/// Here we define a Parser for our new SQL dialect that wraps the existing `DFParser` -struct MyParser<'a> { - df_parser: DFParser<'a>, -} - -impl<'a> MyParser<'a> { - fn new(sql: &'a str) -> Result { - let df_parser = DFParserBuilder::new(sql).build()?; - Ok(Self { df_parser }) - } - - /// Returns true if the next token is `COPY` keyword, false otherwise - fn is_copy(&self) -> bool { - matches!( - self.df_parser.parser.peek_token().token, - Token::Word(w) if w.keyword == Keyword::COPY - ) - } - - /// This is the entry point to our parser -- it handles `COPY` statements specially - /// but otherwise delegates to the existing DataFusion parser. - pub fn parse_statement(&mut self) -> Result { - if self.is_copy() { - self.df_parser.parser.next_token(); // COPY - let df_statement = self.df_parser.parse_copy()?; - - if let Statement::CopyTo(s) = df_statement { - Ok(MyStatement::from(s)) - } else { - Ok(MyStatement::DFStatement(Box::from(df_statement))) - } - } else { - let df_statement = self.df_parser.parse_statement()?; - Ok(MyStatement::from(df_statement)) - } - } -} - -enum MyStatement { - DFStatement(Box), - MyCopyTo(MyCopyToStatement), -} - -impl Display for MyStatement { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - MyStatement::DFStatement(s) => write!(f, "{s}"), - MyStatement::MyCopyTo(s) => write!(f, "{s}"), - } - } -} - -impl From for MyStatement { - fn from(s: Statement) -> Self { - Self::DFStatement(Box::from(s)) - } -} - -impl From for MyStatement { - fn from(s: CopyToStatement) -> Self { - if s.stored_as == Some("FASTA".to_string()) { - Self::MyCopyTo(MyCopyToStatement::from(s)) - } else { - Self::DFStatement(Box::from(Statement::CopyTo(s))) - } - } -} - -struct MyCopyToStatement { - pub source: CopyToSource, - pub target: String, -} - -impl From for MyCopyToStatement { - fn from(s: CopyToStatement) -> Self { - Self { - source: s.source, - target: s.target, - } - } -} - -impl Display for MyCopyToStatement { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "COPY {} TO '{}' STORED AS FASTA", - self.source, self.target - ) - } -} diff --git a/datafusion-examples/examples/sql_analysis.rs b/datafusion-examples/examples/sql_ops/analysis.rs similarity index 98% rename from datafusion-examples/examples/sql_analysis.rs rename to datafusion-examples/examples/sql_ops/analysis.rs index 4ff669faf1d0c..4243a2927865b 100644 --- a/datafusion-examples/examples/sql_analysis.rs +++ b/datafusion-examples/examples/sql_ops/analysis.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! //! This example shows how to use the structures that DataFusion provides to perform //! Analysis on SQL queries and their plans. //! 
@@ -23,8 +25,8 @@ use std::sync::Arc; -use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion::common::Result; +use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion::logical_expr::LogicalPlan; use datafusion::{ datasource::MemTable, @@ -32,141 +34,9 @@ use datafusion::{ }; use test_utils::tpcds::tpcds_schemas; -/// Counts the total number of joins in a plan -fn total_join_count(plan: &LogicalPlan) -> usize { - let mut total = 0; - - // We can use the TreeNode API to walk over a LogicalPlan. - plan.apply(|node| { - // if we encounter a join we update the running count - if matches!(node, LogicalPlan::Join(_)) { - total += 1; - } - Ok(TreeNodeRecursion::Continue) - }) - .unwrap(); - - total -} - -/// Counts the total number of joins in a plan and collects every join tree in -/// the plan with their respective join count. -/// -/// Join Tree Definition: the largest subtree consisting entirely of joins -/// -/// For example, this plan: -/// -/// ```text -/// JOIN -/// / \ -/// A JOIN -/// / \ -/// B C -/// ``` -/// -/// has a single join tree `(A-B-C)` which will result in `(2, [2])` -/// -/// This plan: -/// -/// ```text -/// JOIN -/// / \ -/// A GROUP -/// | -/// JOIN -/// / \ -/// B C -/// ``` -/// -/// Has two join trees `(A-, B-C)` which will result in `(2, [1, 1])` -fn count_trees(plan: &LogicalPlan) -> (usize, Vec) { - // this works the same way as `total_count`, but now when we encounter a Join - // we try to collect it's entire tree - let mut to_visit = vec![plan]; - let mut total = 0; - let mut groups = vec![]; - - while let Some(node) = to_visit.pop() { - // if we encounter a join, we know were at the root of the tree - // count this tree and recurse on it's inputs - if matches!(node, LogicalPlan::Join(_)) { - let (group_count, inputs) = count_tree(node); - total += group_count; - groups.push(group_count); - to_visit.extend(inputs); - } else { - to_visit.extend(node.inputs()); - } - } - - (total, groups) -} - -/// Count the entire join tree and return its inputs using TreeNode API -/// -/// For example, if this function receives following plan: -/// -/// ```text -/// JOIN -/// / \ -/// A GROUP -/// | -/// JOIN -/// / \ -/// B C -/// ``` -/// -/// It will return `(1, [A, GROUP])` -fn count_tree(join: &LogicalPlan) -> (usize, Vec<&LogicalPlan>) { - let mut inputs = Vec::new(); - let mut total = 0; - - join.apply(|node| { - // Some extra knowledge: - // - // optimized plans have their projections pushed down as far as - // possible, which sometimes results in a projection going in between 2 - // subsequent joins giving the illusion these joins are not "related", - // when in fact they are. - // - // This plan: - // JOIN - // / \ - // A PROJECTION - // | - // JOIN - // / \ - // B C - // - // is the same as: - // - // JOIN - // / \ - // A JOIN - // / \ - // B C - // we can continue the recursion in this case - if let LogicalPlan::Projection(_) = node { - return Ok(TreeNodeRecursion::Continue); - } - - // any join we count - if matches!(node, LogicalPlan::Join(_)) { - total += 1; - Ok(TreeNodeRecursion::Continue) - } else { - inputs.push(node); - // skip children of input node - Ok(TreeNodeRecursion::Jump) - } - }) - .unwrap(); - - (total, inputs) -} - -#[tokio::main] -async fn main() -> Result<()> { +/// Demonstrates how to analyze a SQL query by counting JOINs and identifying +/// join-trees using DataFusion’s `LogicalPlan` and `TreeNode` API. 
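///
/// The core walking pattern, reduced to a minimal sketch (`total_join_count`
/// below is the full version used by this example):
///
/// ```rust,ignore
/// use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
/// use datafusion::logical_expr::LogicalPlan;
///
/// fn count_joins(plan: &LogicalPlan) -> usize {
///     let mut joins = 0;
///     // `apply` visits every node of the plan tree; we only inspect it.
///     plan.apply(|node| {
///         if matches!(node, LogicalPlan::Join(_)) {
///             joins += 1;
///         }
///         Ok(TreeNodeRecursion::Continue)
///     })
///     .expect("the closure never returns an error");
///     joins
/// }
/// ```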
+pub async fn analysis() -> Result<()> { // To show how we can count the joins in a sql query we'll be using query 88 // from the TPC-DS benchmark. // @@ -310,3 +180,136 @@ from Ok(()) } + +/// Counts the total number of joins in a plan +fn total_join_count(plan: &LogicalPlan) -> usize { + let mut total = 0; + + // We can use the TreeNode API to walk over a LogicalPlan. + plan.apply(|node| { + // if we encounter a join we update the running count + if matches!(node, LogicalPlan::Join(_)) { + total += 1; + } + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + + total +} + +/// Counts the total number of joins in a plan and collects every join tree in +/// the plan with their respective join count. +/// +/// Join Tree Definition: the largest subtree consisting entirely of joins +/// +/// For example, this plan: +/// +/// ```text +/// JOIN +/// / \ +/// A JOIN +/// / \ +/// B C +/// ``` +/// +/// has a single join tree `(A-B-C)` which will result in `(2, [2])` +/// +/// This plan: +/// +/// ```text +/// JOIN +/// / \ +/// A GROUP +/// | +/// JOIN +/// / \ +/// B C +/// ``` +/// +/// Has two join trees `(A-, B-C)` which will result in `(2, [1, 1])` +fn count_trees(plan: &LogicalPlan) -> (usize, Vec) { + // this works the same way as `total_count`, but now when we encounter a Join + // we try to collect it's entire tree + let mut to_visit = vec![plan]; + let mut total = 0; + let mut groups = vec![]; + + while let Some(node) = to_visit.pop() { + // if we encounter a join, we know were at the root of the tree + // count this tree and recurse on it's inputs + if matches!(node, LogicalPlan::Join(_)) { + let (group_count, inputs) = count_tree(node); + total += group_count; + groups.push(group_count); + to_visit.extend(inputs); + } else { + to_visit.extend(node.inputs()); + } + } + + (total, groups) +} + +/// Count the entire join tree and return its inputs using TreeNode API +/// +/// For example, if this function receives following plan: +/// +/// ```text +/// JOIN +/// / \ +/// A GROUP +/// | +/// JOIN +/// / \ +/// B C +/// ``` +/// +/// It will return `(1, [A, GROUP])` +fn count_tree(join: &LogicalPlan) -> (usize, Vec<&LogicalPlan>) { + let mut inputs = Vec::new(); + let mut total = 0; + + join.apply(|node| { + // Some extra knowledge: + // + // optimized plans have their projections pushed down as far as + // possible, which sometimes results in a projection going in between 2 + // subsequent joins giving the illusion these joins are not "related", + // when in fact they are. + // + // This plan: + // JOIN + // / \ + // A PROJECTION + // | + // JOIN + // / \ + // B C + // + // is the same as: + // + // JOIN + // / \ + // A JOIN + // / \ + // B C + // we can continue the recursion in this case + if let LogicalPlan::Projection(_) = node { + return Ok(TreeNodeRecursion::Continue); + } + + // any join we count + if matches!(node, LogicalPlan::Join(_)) { + total += 1; + Ok(TreeNodeRecursion::Continue) + } else { + inputs.push(node); + // skip children of input node + Ok(TreeNodeRecursion::Jump) + } + }) + .unwrap(); + + (total, inputs) +} diff --git a/datafusion-examples/examples/sql_ops/custom_sql_parser.rs b/datafusion-examples/examples/sql_ops/custom_sql_parser.rs new file mode 100644 index 0000000000000..308a0de62a242 --- /dev/null +++ b/datafusion-examples/examples/sql_ops/custom_sql_parser.rs @@ -0,0 +1,420 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This example demonstrates extending the DataFusion SQL parser to support +//! custom DDL statements, specifically `CREATE EXTERNAL CATALOG`. +//! +//! ### Custom Syntax +//! ```sql +//! CREATE EXTERNAL CATALOG my_catalog +//! STORED AS ICEBERG +//! LOCATION 's3://my-bucket/warehouse/' +//! OPTIONS ( +//! 'region' = 'us-west-2' +//! ); +//! ``` +//! +//! Note: For the purpose of this example, we use `local://workspace/` to +//! automatically discover and register files from the project's test data. + +use std::collections::HashMap; +use std::fmt::Display; +use std::sync::Arc; + +use datafusion::catalog::{ + CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, + TableProviderFactory, +}; +use datafusion::datasource::listing_table_factory::ListingTableFactory; +use datafusion::error::{DataFusionError, Result}; +use datafusion::prelude::SessionContext; +use datafusion::sql::{ + parser::{DFParser, DFParserBuilder, Statement}, + sqlparser::{ + ast::{ObjectName, Value}, + keywords::Keyword, + tokenizer::Token, + }, +}; +use datafusion_common::{DFSchema, TableReference, plan_datafusion_err, plan_err}; +use datafusion_expr::CreateExternalTable; +use futures::StreamExt; +use insta::assert_snapshot; +use object_store::ObjectStore; +use object_store::local::LocalFileSystem; + +/// Entry point for the example. +pub async fn custom_sql_parser() -> Result<()> { + // Use standard Parquet testing data as our "external" source. 
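    // `parquet_test_data()` resolves to the repository's `parquet-testing/data`
    // directory; below it is made relative to the workspace root so the files
    // can later be addressed through the `local://workspace` object store
    // registered in `handle_create_external_catalog`.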
+ let base_path = datafusion::common::test_util::parquet_test_data(); + let base_path = std::path::Path::new(&base_path).canonicalize()?; + + // Make the path relative to the workspace root + let workspace_root = workspace_root(); + let location = base_path + .strip_prefix(&workspace_root) + .map(|p| p.to_string_lossy().to_string()) + .unwrap_or_else(|_| base_path.to_string_lossy().to_string()); + + let create_catalog_sql = format!( + "CREATE EXTERNAL CATALOG parquet_testing + STORED AS parquet + LOCATION 'local://workspace/{location}' + OPTIONS ( + 'schema_name' = 'staged_data', + 'format.pruning' = 'true' + )" + ); + + // ========================================================================= + // Part 1: Standard DataFusion parser rejects the custom DDL + // ========================================================================= + println!("=== Part 1: Standard DataFusion Parser ===\n"); + println!("Parsing: {}\n", create_catalog_sql.trim()); + + let ctx_standard = SessionContext::new(); + let err = ctx_standard + .sql(&create_catalog_sql) + .await + .expect_err("Expected the standard parser to reject CREATE EXTERNAL CATALOG (custom DDL syntax)"); + + println!("Error: {err}\n"); + assert_snapshot!(err.to_string(), @r#"SQL error: ParserError("Expected: TABLE, found: CATALOG at Line: 1, Column: 17")"#); + + // ========================================================================= + // Part 2: Custom parser handles the statement + // ========================================================================= + println!("=== Part 2: Custom Parser ===\n"); + println!("Parsing: {}\n", create_catalog_sql.trim()); + + let ctx = SessionContext::new(); + + let mut parser = CustomParser::new(&create_catalog_sql)?; + let statement = parser.parse_statement()?; + match statement { + CustomStatement::CreateExternalCatalog(stmt) => { + handle_create_external_catalog(&ctx, stmt).await?; + } + CustomStatement::DFStatement(_) => { + panic!("Expected CreateExternalCatalog statement"); + } + } + + // Query a table from the registered catalog + let query_sql = "SELECT id, bool_col, tinyint_col FROM parquet_testing.staged_data.alltypes_plain LIMIT 5"; + println!("Executing: {query_sql}\n"); + + let results = execute_sql(&ctx, query_sql).await?; + println!("{results}"); + assert_snapshot!(results, @r" + +----+----------+-------------+ + | id | bool_col | tinyint_col | + +----+----------+-------------+ + | 4 | true | 0 | + | 5 | false | 1 | + | 6 | true | 0 | + | 7 | false | 1 | + | 2 | true | 0 | + +----+----------+-------------+ + "); + + Ok(()) +} + +/// Execute SQL and return formatted results. +async fn execute_sql(ctx: &SessionContext, sql: &str) -> Result { + let batches = ctx.sql(sql).await?.collect().await?; + Ok(arrow::util::pretty::pretty_format_batches(&batches)?.to_string()) +} + +/// Custom handler for the `CREATE EXTERNAL CATALOG` statement. 
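///
/// Invocation sketch, mirroring Part 2 above (the DDL text is illustrative and
/// assumes the data files live under the registered `local://workspace` store):
///
/// ```rust,ignore
/// use datafusion::prelude::SessionContext;
/// use datafusion_common::Result;
///
/// async fn demo(ctx: &SessionContext) -> Result<()> {
///     let sql = "CREATE EXTERNAL CATALOG demo \
///                STORED AS parquet \
///                LOCATION 'local://workspace/parquet-testing/data'";
///     let mut parser = CustomParser::new(sql)?;
///     if let CustomStatement::CreateExternalCatalog(stmt) = parser.parse_statement()? {
///         handle_create_external_catalog(ctx, stmt).await?;
///     }
///     Ok(())
/// }
/// ```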
+async fn handle_create_external_catalog( + ctx: &SessionContext, + stmt: CreateExternalCatalog, +) -> Result<()> { + let factory = ListingTableFactory::new(); + let catalog = Arc::new(MemoryCatalogProvider::new()); + let schema = Arc::new(MemorySchemaProvider::new()); + + // Extract options + let mut schema_name = "public".to_string(); + let mut table_options = HashMap::new(); + + for (k, v) in stmt.options { + let val_str = match v { + Value::SingleQuotedString(ref s) | Value::DoubleQuotedString(ref s) => { + s.to_string() + } + Value::Number(ref n, _) => n.to_string(), + Value::Boolean(b) => b.to_string(), + _ => v.to_string(), + }; + + if k == "schema_name" { + schema_name = val_str; + } else { + table_options.insert(k, val_str); + } + } + + println!(" Target Catalog: {}", stmt.name); + println!(" Data Location: {}", stmt.location); + println!(" Resolved Schema: {schema_name}"); + + // Register a local object store rooted at the workspace root. + // We use a specific authority 'workspace' to ensure consistent resolution. + let store = Arc::new(LocalFileSystem::new_with_prefix(workspace_root())?); + let store_url = url::Url::parse("local://workspace").unwrap(); + ctx.register_object_store(&store_url, Arc::clone(&store) as _); + + let target_ext = format!(".{}", stmt.catalog_type.to_lowercase()); + + // For 'local://workspace/parquet-testing/data', the path is 'parquet-testing/data'. + let path_str = stmt + .location + .strip_prefix("local://workspace/") + .unwrap_or(&stmt.location); + let prefix = object_store::path::Path::from(path_str); + + // Discover data files using the ObjectStore API + let mut table_count = 0; + let mut list_stream = store.list(Some(&prefix)); + + while let Some(meta) = list_stream.next().await { + let meta = meta?; + let path = &meta.location; + + if path.as_ref().ends_with(&target_ext) { + let name = std::path::Path::new(path.as_ref()) + .file_stem() + .unwrap() + .to_string_lossy() + .to_string(); + + let table_url = format!("local://workspace/{path}"); + + let cmd = CreateExternalTable::builder( + TableReference::bare(name.clone()), + table_url, + stmt.catalog_type.clone(), + Arc::new(DFSchema::empty()), + ) + .with_options(table_options.clone()) + .build(); + + match factory.create(&ctx.state(), &cmd).await { + Ok(table) => { + schema.register_table(name, table)?; + table_count += 1; + } + Err(e) => { + eprintln!("Failed to create table {name}: {e}"); + } + } + } + } + println!(" Registered {table_count} tables into schema: {schema_name}"); + + catalog.register_schema(&schema_name, schema)?; + ctx.register_catalog(stmt.name.to_string(), catalog); + + Ok(()) +} + +/// Possible statements returned by our custom parser. +#[derive(Debug, Clone)] +pub enum CustomStatement { + /// Standard DataFusion statement + DFStatement(Box), + /// Custom `CREATE EXTERNAL CATALOG` statement + CreateExternalCatalog(CreateExternalCatalog), +} + +/// Data structure for `CREATE EXTERNAL CATALOG`. 
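+///
+/// Field mapping to the DDL: `name` is the identifier after
+/// `CREATE EXTERNAL CATALOG`, `catalog_type` comes from `STORED AS`,
+/// `location` from `LOCATION`, and `options` from the `OPTIONS (...)` list.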
+#[derive(Debug, Clone)] +pub struct CreateExternalCatalog { + pub name: ObjectName, + pub catalog_type: String, + pub location: String, + pub options: Vec<(String, Value)>, +} + +impl Display for CustomStatement { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::DFStatement(s) => write!(f, "{s}"), + Self::CreateExternalCatalog(s) => write!(f, "{s}"), + } + } +} + +impl Display for CreateExternalCatalog { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "CREATE EXTERNAL CATALOG {} STORED AS {} LOCATION '{}'", + self.name, self.catalog_type, self.location + )?; + if !self.options.is_empty() { + write!(f, " OPTIONS (")?; + for (i, (k, v)) in self.options.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "'{k}' = '{v}'")?; + } + write!(f, ")")?; + } + Ok(()) + } +} + +/// A parser that extends `DFParser` with custom syntax. +struct CustomParser<'a> { + df_parser: DFParser<'a>, +} + +impl<'a> CustomParser<'a> { + fn new(sql: &'a str) -> Result { + Ok(Self { + df_parser: DFParserBuilder::new(sql).build()?, + }) + } + + pub fn parse_statement(&mut self) -> Result { + if self.is_create_external_catalog() { + return self.parse_create_external_catalog(); + } + Ok(CustomStatement::DFStatement(Box::new( + self.df_parser.parse_statement()?, + ))) + } + + fn is_create_external_catalog(&self) -> bool { + let t1 = &self.df_parser.parser.peek_nth_token(0).token; + let t2 = &self.df_parser.parser.peek_nth_token(1).token; + let t3 = &self.df_parser.parser.peek_nth_token(2).token; + + matches!(t1, Token::Word(w) if w.keyword == Keyword::CREATE) + && matches!(t2, Token::Word(w) if w.keyword == Keyword::EXTERNAL) + && matches!(t3, Token::Word(w) if w.value.to_uppercase() == "CATALOG") + } + + fn parse_create_external_catalog(&mut self) -> Result { + // Consume prefix tokens: CREATE EXTERNAL CATALOG + for _ in 0..3 { + self.df_parser.parser.next_token(); + } + + let name = self + .df_parser + .parser + .parse_object_name(false) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + let mut catalog_type = None; + let mut location = None; + let mut options = vec![]; + + while let Some(keyword) = self.df_parser.parser.parse_one_of_keywords(&[ + Keyword::STORED, + Keyword::LOCATION, + Keyword::OPTIONS, + ]) { + match keyword { + Keyword::STORED => { + if catalog_type.is_some() { + return plan_err!("Duplicate STORED AS"); + } + self.df_parser + .parser + .expect_keyword(Keyword::AS) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + catalog_type = Some( + self.df_parser + .parser + .parse_identifier() + .map_err(|e| DataFusionError::External(Box::new(e)))? + .value, + ); + } + Keyword::LOCATION => { + if location.is_some() { + return plan_err!("Duplicate LOCATION"); + } + location = Some( + self.df_parser + .parser + .parse_literal_string() + .map_err(|e| DataFusionError::External(Box::new(e)))?, + ); + } + Keyword::OPTIONS => { + if !options.is_empty() { + return plan_err!("Duplicate OPTIONS"); + } + options = self.parse_value_options()?; + } + _ => unreachable!(), + } + } + + Ok(CustomStatement::CreateExternalCatalog( + CreateExternalCatalog { + name, + catalog_type: catalog_type + .ok_or_else(|| plan_datafusion_err!("Missing STORED AS"))?, + location: location + .ok_or_else(|| plan_datafusion_err!("Missing LOCATION"))?, + options, + }, + )) + } + + /// Parse options in the form: (key [=] value, key [=] value, ...) 
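+    ///
+    /// The `=` between key and value is optional, so, for example, both
+    /// `('region' = 'us-west-2')` and `('region' 'us-west-2')` are accepted.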
+ fn parse_value_options(&mut self) -> Result> { + let mut options = vec![]; + self.df_parser + .parser + .expect_token(&Token::LParen) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + loop { + let key = self.df_parser.parse_option_key()?; + // Support optional '=' between key and value + let _ = self.df_parser.parser.consume_token(&Token::Eq); + let value = self.df_parser.parse_option_value()?; + options.push((key, value)); + + let comma = self.df_parser.parser.consume_token(&Token::Comma); + if self.df_parser.parser.consume_token(&Token::RParen) { + break; + } else if !comma { + return plan_err!("Expected ',' or ')' in OPTIONS"); + } + } + Ok(options) + } +} + +/// Returns the workspace root directory (parent of datafusion-examples). +fn workspace_root() -> std::path::PathBuf { + std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .expect("CARGO_MANIFEST_DIR should have a parent") + .to_path_buf() +} diff --git a/datafusion-examples/examples/sql_frontend.rs b/datafusion-examples/examples/sql_ops/frontend.rs similarity index 98% rename from datafusion-examples/examples/sql_frontend.rs rename to datafusion-examples/examples/sql_ops/frontend.rs index 1fc9ce24ecbb5..025fe47e75b07 100644 --- a/datafusion-examples/examples/sql_frontend.rs +++ b/datafusion-examples/examples/sql_ops/frontend.rs @@ -15,8 +15,10 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion::common::{plan_err, TableReference}; +use datafusion::common::{TableReference, plan_err}; use datafusion::config::ConfigOptions; use datafusion::error::Result; use datafusion::logical_expr::{ @@ -44,7 +46,7 @@ use std::sync::Arc; /// /// In this example, we demonstrate how to use the lower level APIs directly, /// which only requires the `datafusion-sql` dependency. -pub fn main() -> Result<()> { +pub fn frontend() -> Result<()> { // First, we parse the SQL string. Note that we use the DataFusion // Parser, which wraps the `sqlparser-rs` SQL parser and adds DataFusion // specific syntax such as `CREATE EXTERNAL TABLE` diff --git a/datafusion-examples/examples/sql_ops/main.rs b/datafusion-examples/examples/sql_ops/main.rs new file mode 100644 index 0000000000000..8c3ac056698b7 --- /dev/null +++ b/datafusion-examples/examples/sql_ops/main.rs @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # SQL Examples +//! +//! These examples demonstrate SQL operations in DataFusion. +//! +//! ## Usage +//! ```bash +//! cargo run --example sql_ops -- [all|analysis|custom_sql_parser|frontend|query] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! 
- `all` — run all examples included in this module +//! - `analysis` — analyse SQL queries with DataFusion structures +//! - `custom_sql_parser` — implementing a custom SQL parser to extend DataFusion +//! - `frontend` — create LogicalPlans (only) from sql strings +//! - `query` — query data using SQL (in memory RecordBatches, local Parquet files) + +mod analysis; +mod custom_sql_parser; +mod frontend; +mod query; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + Analysis, + CustomSqlParser, + Frontend, + Query, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "sql_ops"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::Analysis => analysis::analysis().await?, + ExampleKind::CustomSqlParser => { + custom_sql_parser::custom_sql_parser().await? + } + ExampleKind::Frontend => frontend::frontend()?, + ExampleKind::Query => query::query().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .ok_or_else(|| DataFusionError::Execution(format!("Missing argument. {usage}")))? + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/sql_query.rs b/datafusion-examples/examples/sql_ops/query.rs similarity index 97% rename from datafusion-examples/examples/sql_query.rs rename to datafusion-examples/examples/sql_ops/query.rs index 4da07d33d03d4..90d0c3ca34a00 100644 --- a/datafusion-examples/examples/sql_query.rs +++ b/datafusion-examples/examples/sql_ops/query.rs @@ -15,13 +15,15 @@ // specific language governing permissions and limitations // under the License. -use datafusion::arrow::array::{UInt64Array, UInt8Array}; +//! See `main.rs` for how to run it. 
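+//!
+//! Per the `sql_ops` usage in `main.rs`, this file corresponds to the `query`
+//! subcommand: `cargo run --example sql_ops -- query`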
+ +use datafusion::arrow::array::{UInt8Array, UInt64Array}; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::{assert_batches_eq, exec_datafusion_err}; +use datafusion::datasource::MemTable; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::ListingOptions; -use datafusion::datasource::MemTable; use datafusion::error::{DataFusionError, Result}; use datafusion::prelude::*; use object_store::local::LocalFileSystem; @@ -32,8 +34,7 @@ use std::sync::Arc; /// /// [`query_memtable`]: a simple query against a [`MemTable`] /// [`query_parquet`]: a simple query against a directory with multiple Parquet files -#[tokio::main] -async fn main() -> Result<()> { +pub async fn query() -> Result<()> { query_memtable().await?; query_parquet().await?; Ok(()) @@ -152,7 +153,8 @@ async fn query_parquet() -> Result<()> { "| 4 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00 |", "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", ], - &results); + &results + ); // Second example were we temporarily move into the test data's parent directory and // simulate a relative path, this requires registering an ObjectStore. @@ -201,7 +203,8 @@ async fn query_parquet() -> Result<()> { "| 4 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00 |", "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", ], - &results); + &results + ); // Reset the current directory std::env::set_current_dir(cur_dir)?; diff --git a/datafusion-examples/examples/udf/advanced_udaf.rs b/datafusion-examples/examples/udf/advanced_udaf.rs index 81e227bfacee4..fbb9e652486ce 100644 --- a/datafusion-examples/examples/udf/advanced_udaf.rs +++ b/datafusion-examples/examples/udf/advanced_udaf.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::datatypes::{Field, Schema}; use datafusion::physical_expr::NullState; use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; @@ -26,13 +28,13 @@ use arrow::array::{ use arrow::datatypes::{ArrowNativeTypeOp, ArrowPrimitiveType, Float64Type, UInt32Type}; use arrow::record_batch::RecordBatch; use arrow_schema::FieldRef; -use datafusion::common::{cast::as_float64_array, ScalarValue}; +use datafusion::common::{ScalarValue, cast::as_float64_array}; use datafusion::error::Result; use datafusion::logical_expr::{ + Accumulator, AggregateUDF, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, expr::AggregateFunction, function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, simplify::SimplifyInfo, - Accumulator, AggregateUDF, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, }; use datafusion::prelude::*; diff --git a/datafusion-examples/examples/udf/advanced_udf.rs b/datafusion-examples/examples/udf/advanced_udf.rs index bb5a68e90cbbe..a00a7e7df434f 100644 --- a/datafusion-examples/examples/udf/advanced_udf.rs +++ b/datafusion-examples/examples/udf/advanced_udf.rs @@ -15,19 +15,21 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
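+//!
+//! Per the `udf` usage in `main.rs`, this file corresponds to the `adv_udf`
+//! subcommand: `cargo run --example udf -- adv_udf`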
+ use std::any::Any; use std::sync::Arc; use arrow::array::{ - new_null_array, Array, ArrayRef, AsArray, Float32Array, Float64Array, + Array, ArrayRef, AsArray, Float32Array, Float64Array, new_null_array, }; use arrow::compute; use arrow::datatypes::{DataType, Float64Type}; use arrow::record_batch::RecordBatch; -use datafusion::common::{exec_err, internal_err, ScalarValue}; +use datafusion::common::{ScalarValue, exec_err, internal_err}; use datafusion::error::Result; -use datafusion::logical_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion::logical_expr::Volatility; +use datafusion::logical_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, }; diff --git a/datafusion-examples/examples/udf/advanced_udwf.rs b/datafusion-examples/examples/udf/advanced_udwf.rs index 86f215e019c78..e8d3a75b29dec 100644 --- a/datafusion-examples/examples/udf/advanced_udwf.rs +++ b/datafusion-examples/examples/udf/advanced_udwf.rs @@ -15,6 +15,10 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + +use std::{any::Any, fs::File, io::Write, sync::Arc}; + use arrow::datatypes::Field; use arrow::{ array::{ArrayRef, AsArray, Float64Array}, @@ -36,8 +40,7 @@ use datafusion::logical_expr::{ use datafusion::physical_expr::PhysicalExpr; use datafusion::prelude::*; use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; -use std::any::Any; -use std::sync::Arc; +use tempfile::tempdir; /// This example shows how to use the full WindowUDFImpl API to implement a user /// defined window function. As in the `simple_udwf.rs` example, this struct implements @@ -227,12 +230,46 @@ async fn create_context() -> Result { // declare a new context. In spark API, this corresponds to a new spark SQL session let ctx = SessionContext::new(); - // declare a table in memory. In spark API, this corresponds to createDataFrame(...). 
- println!("pwd: {}", std::env::current_dir().unwrap().display()); - let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); - let read_options = CsvReadOptions::default().has_header(true); + // content from file 'datafusion/core/tests/data/cars.csv' + let csv_data = r#"car,speed,time +red,20.0,1996-04-12T12:05:03.000000000 +red,20.3,1996-04-12T12:05:04.000000000 +red,21.4,1996-04-12T12:05:05.000000000 +red,21.5,1996-04-12T12:05:06.000000000 +red,19.0,1996-04-12T12:05:07.000000000 +red,18.0,1996-04-12T12:05:08.000000000 +red,17.0,1996-04-12T12:05:09.000000000 +red,7.0,1996-04-12T12:05:10.000000000 +red,7.1,1996-04-12T12:05:11.000000000 +red,7.2,1996-04-12T12:05:12.000000000 +red,3.0,1996-04-12T12:05:13.000000000 +red,1.0,1996-04-12T12:05:14.000000000 +red,0.0,1996-04-12T12:05:15.000000000 +green,10.0,1996-04-12T12:05:03.000000000 +green,10.3,1996-04-12T12:05:04.000000000 +green,10.4,1996-04-12T12:05:05.000000000 +green,10.5,1996-04-12T12:05:06.000000000 +green,11.0,1996-04-12T12:05:07.000000000 +green,12.0,1996-04-12T12:05:08.000000000 +green,14.0,1996-04-12T12:05:09.000000000 +green,15.0,1996-04-12T12:05:10.000000000 +green,15.1,1996-04-12T12:05:11.000000000 +green,15.2,1996-04-12T12:05:12.000000000 +green,8.0,1996-04-12T12:05:13.000000000 +green,2.0,1996-04-12T12:05:14.000000000 +"#; + let dir = tempdir()?; + let file_path = dir.path().join("cars.csv"); + { + let mut file = File::create(&file_path)?; + // write CSV data + file.write_all(csv_data.as_bytes())?; + } // scope closes the file + let file_path = file_path.to_str().unwrap(); + + ctx.register_csv("cars", file_path, CsvReadOptions::new()) + .await?; - ctx.register_csv("cars", &csv_path, read_options).await?; Ok(ctx) } diff --git a/datafusion-examples/examples/udf/async_udf.rs b/datafusion-examples/examples/udf/async_udf.rs index 475775a599f62..c31e8290ccce5 100644 --- a/datafusion-examples/examples/udf/async_udf.rs +++ b/datafusion-examples/examples/udf/async_udf.rs @@ -15,12 +15,16 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! //! This example shows how to create and use "Async UDFs" in DataFusion. //! //! Async UDFs allow you to perform asynchronous operations, such as //! making network requests. This can be used for tasks like fetching //! data from an external API such as a LLM service or an external database. 
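+//!
+//! For illustration only (names taken from the query plan shown later in this
+//! file; the exact statement in the code may differ):
+//!
+//! ```sql
+//! SELECT * FROM animal a WHERE ask_llm(a.name, 'Is this animal furry?')
+//! ```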
+use std::{any::Any, sync::Arc}; + use arrow::array::{ArrayRef, BooleanArray, Int64Array, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use async_trait::async_trait; @@ -35,8 +39,6 @@ use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use std::any::Any; -use std::sync::Arc; /// In this example we register `AskLLM` as an asynchronous user defined function /// and invoke it via the DataFrame API and SQL @@ -91,20 +93,19 @@ pub async fn async_udf() -> Result<()> { assert_batches_eq!( [ - "+---------------+--------------------------------------------------------------------------------------------------------------------------------+", - "| plan_type | plan |", - "+---------------+--------------------------------------------------------------------------------------------------------------------------------+", - "| logical_plan | SubqueryAlias: a |", - "| | Filter: ask_llm(CAST(animal.name AS Utf8View), Utf8View(\"Is this animal furry?\")) |", - "| | TableScan: animal projection=[id, name] |", - "| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |", - "| | FilterExec: __async_fn_0@2, projection=[id@0, name@1] |", - "| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |", - "| | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=ask_llm(CAST(name@1 AS Utf8View), Is this animal furry?))] |", - "| | CoalesceBatchesExec: target_batch_size=8192 |", - "| | DataSourceExec: partitions=1, partition_sizes=[1] |", - "| | |", - "+---------------+--------------------------------------------------------------------------------------------------------------------------------+", + "+---------------+------------------------------------------------------------------------------------------------------------------------------+", + "| plan_type | plan |", + "+---------------+------------------------------------------------------------------------------------------------------------------------------+", + "| logical_plan | SubqueryAlias: a |", + "| | Filter: ask_llm(CAST(animal.name AS Utf8View), Utf8View(\"Is this animal furry?\")) |", + "| | TableScan: animal projection=[id, name] |", + "| physical_plan | FilterExec: __async_fn_0@2, projection=[id@0, name@1] |", + "| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |", + "| | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=ask_llm(CAST(name@1 AS Utf8View), Is this animal furry?))] |", + "| | CoalesceBatchesExec: target_batch_size=8192 |", + "| | DataSourceExec: partitions=1, partition_sizes=[1] |", + "| | |", + "+---------------+------------------------------------------------------------------------------------------------------------------------------+", ], &results ); diff --git a/datafusion-examples/examples/udf/main.rs b/datafusion-examples/examples/udf/main.rs index ba36dbb15c58b..0fb26ff5f74ce 100644 --- a/datafusion-examples/examples/udf/main.rs +++ b/datafusion-examples/examples/udf/main.rs @@ -19,7 +19,13 @@ //! //! These examples demonstrate user-defined functions in DataFusion. //! +//! ## Usage +//! ```bash +//! cargo run --example udf -- [all|adv_udaf|adv_udf|adv_udwf|async_udf|udaf|udf|udtf|udwf] +//! ``` +//! //! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module //! - `adv_udaf` — user defined aggregate function example //! - `adv_udf` — user defined scalar function example //! 
- `adv_udwf` — user defined window function example @@ -38,11 +44,14 @@ mod simple_udf; mod simple_udtf; mod simple_udwf; -use std::str::FromStr; - use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] enum ExampleKind { + All, AdvUdaf, AdvUdf, AdvUdwf, @@ -53,55 +62,32 @@ enum ExampleKind { Udtf, } -impl AsRef for ExampleKind { - fn as_ref(&self) -> &str { - match self { - Self::AdvUdaf => "adv_udaf", - Self::AdvUdf => "adv_udf", - Self::AdvUdwf => "adv_udwf", - Self::AsyncUdf => "async_udf", - Self::Udf => "udf", - Self::Udaf => "udaf", - Self::Udwf => "udwt", - Self::Udtf => "udtf", - } - } -} - -impl FromStr for ExampleKind { - type Err = DataFusionError; +impl ExampleKind { + const EXAMPLE_NAME: &str = "udf"; - fn from_str(s: &str) -> Result { - match s { - "adv_udaf" => Ok(Self::AdvUdaf), - "adv_udf" => Ok(Self::AdvUdf), - "adv_udwf" => Ok(Self::AdvUdwf), - "async_udf" => Ok(Self::AsyncUdf), - "udaf" => Ok(Self::Udaf), - "udf" => Ok(Self::Udf), - "udtf" => Ok(Self::Udtf), - "udwf" => Ok(Self::Udwf), - _ => Err(DataFusionError::Execution(format!("Unknown example: {s}"))), - } + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) } -} -impl ExampleKind { - const ALL: [Self; 8] = [ - Self::AdvUdaf, - Self::AdvUdf, - Self::AdvUdwf, - Self::AsyncUdf, - Self::Udaf, - Self::Udf, - Self::Udtf, - Self::Udwf, - ]; - - const EXAMPLE_NAME: &str = "udf"; + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::AdvUdaf => advanced_udaf::advanced_udaf().await?, + ExampleKind::AdvUdf => advanced_udf::advanced_udf().await?, + ExampleKind::AdvUdwf => advanced_udwf::advanced_udwf().await?, + ExampleKind::AsyncUdf => async_udf::async_udf().await?, + ExampleKind::Udaf => simple_udaf::simple_udaf().await?, + ExampleKind::Udf => simple_udf::simple_udf().await?, + ExampleKind::Udtf => simple_udtf::simple_udtf().await?, + ExampleKind::Udwf => simple_udwf::simple_udwf().await?, + } - fn variants() -> Vec<&'static str> { - Self::ALL.iter().map(|x| x.as_ref()).collect() + Ok(()) } } @@ -110,24 +96,14 @@ async fn main() -> Result<()> { let usage = format!( "Usage: cargo run --example {} -- [{}]", ExampleKind::EXAMPLE_NAME, - ExampleKind::variants().join("|") + ExampleKind::VARIANTS.join("|") ); - let arg = std::env::args().nth(1).ok_or_else(|| { - eprintln!("{usage}"); - DataFusionError::Execution("Missing argument".to_string()) - })?; - - match arg.parse::()? { - ExampleKind::AdvUdaf => advanced_udaf::advanced_udaf().await?, - ExampleKind::AdvUdf => advanced_udf::advanced_udf().await?, - ExampleKind::AdvUdwf => advanced_udwf::advanced_udwf().await?, - ExampleKind::AsyncUdf => async_udf::async_udf().await?, - ExampleKind::Udaf => simple_udaf::simple_udaf().await?, - ExampleKind::Udf => simple_udf::simple_udf().await?, - ExampleKind::Udtf => simple_udtf::simple_udtf().await?, - ExampleKind::Udwf => simple_udwf::simple_udwf().await?, - } + let example: ExampleKind = std::env::args() + .nth(1) + .ok_or_else(|| DataFusionError::Execution(format!("Missing argument. {usage}")))? + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. 
{usage}")))?; - Ok(()) + example.run().await } diff --git a/datafusion-examples/examples/udf/simple_udaf.rs b/datafusion-examples/examples/udf/simple_udaf.rs index e9f905e720997..42ea0054b759f 100644 --- a/datafusion-examples/examples/udf/simple_udaf.rs +++ b/datafusion-examples/examples/udf/simple_udaf.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! /// In this example we will declare a single-type, single return type UDAF that computes the geometric mean. /// The geometric mean is described here: https://en.wikipedia.org/wiki/Geometric_mean use datafusion::arrow::{ diff --git a/datafusion-examples/examples/udf/simple_udf.rs b/datafusion-examples/examples/udf/simple_udf.rs index 7d4f3588e313f..e8d6c9c8173ac 100644 --- a/datafusion-examples/examples/udf/simple_udf.rs +++ b/datafusion-examples/examples/udf/simple_udf.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use datafusion::{ arrow::{ array::{ArrayRef, Float32Array, Float64Array}, diff --git a/datafusion-examples/examples/udf/simple_udtf.rs b/datafusion-examples/examples/udf/simple_udtf.rs index a03b157134aea..087b8ba73af5c 100644 --- a/datafusion-examples/examples/udf/simple_udtf.rs +++ b/datafusion-examples/examples/udf/simple_udtf.rs @@ -15,16 +15,18 @@ // specific language governing permissions and limitations // under the License. -use arrow::csv::reader::Format; +//! See `main.rs` for how to run it. + use arrow::csv::ReaderBuilder; +use arrow::csv::reader::Format; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::record_batch::RecordBatch; use datafusion::catalog::Session; use datafusion::catalog::TableFunctionImpl; -use datafusion::common::{plan_err, ScalarValue}; -use datafusion::datasource::memory::MemorySourceConfig; +use datafusion::common::{ScalarValue, plan_err}; use datafusion::datasource::TableProvider; +use datafusion::datasource::memory::MemorySourceConfig; use datafusion::error::Result; use datafusion::execution::context::ExecutionProps; use datafusion::logical_expr::simplify::SimplifyContext; @@ -132,8 +134,7 @@ struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { fn call(&self, exprs: &[Expr]) -> Result> { - let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)), _)) = exprs.first() - else { + let Some(Expr::Literal(ScalarValue::Utf8(Some(path)), _)) = exprs.first() else { return plan_err!("read_csv requires at least one string argument"); }; diff --git a/datafusion-examples/examples/udf/simple_udwf.rs b/datafusion-examples/examples/udf/simple_udwf.rs index 2cf1df8d8ed86..1842d88b9ba29 100644 --- a/datafusion-examples/examples/udf/simple_udwf.rs +++ b/datafusion-examples/examples/udf/simple_udwf.rs @@ -15,29 +15,65 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; +//! See `main.rs` for how to run it. + +use std::{fs::File, io::Write, sync::Arc}; use arrow::{ array::{ArrayRef, AsArray, Float64Array}, datatypes::{DataType, Float64Type}, }; - use datafusion::common::ScalarValue; use datafusion::error::Result; use datafusion::logical_expr::{PartitionEvaluator, Volatility, WindowFrame}; use datafusion::prelude::*; +use tempfile::tempdir; // create local execution context with `cars.csv` registered as a table named `cars` async fn create_context() -> Result { // declare a new context. 
In spark API, this corresponds to a new spark SQL session let ctx = SessionContext::new(); - // declare a table in memory. In spark API, this corresponds to createDataFrame(...). - println!("pwd: {}", std::env::current_dir().unwrap().display()); - let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); - let read_options = CsvReadOptions::default().has_header(true); + // content from file 'datafusion/core/tests/data/cars.csv' + let csv_data = r#"car,speed,time +red,20.0,1996-04-12T12:05:03.000000000 +red,20.3,1996-04-12T12:05:04.000000000 +red,21.4,1996-04-12T12:05:05.000000000 +red,21.5,1996-04-12T12:05:06.000000000 +red,19.0,1996-04-12T12:05:07.000000000 +red,18.0,1996-04-12T12:05:08.000000000 +red,17.0,1996-04-12T12:05:09.000000000 +red,7.0,1996-04-12T12:05:10.000000000 +red,7.1,1996-04-12T12:05:11.000000000 +red,7.2,1996-04-12T12:05:12.000000000 +red,3.0,1996-04-12T12:05:13.000000000 +red,1.0,1996-04-12T12:05:14.000000000 +red,0.0,1996-04-12T12:05:15.000000000 +green,10.0,1996-04-12T12:05:03.000000000 +green,10.3,1996-04-12T12:05:04.000000000 +green,10.4,1996-04-12T12:05:05.000000000 +green,10.5,1996-04-12T12:05:06.000000000 +green,11.0,1996-04-12T12:05:07.000000000 +green,12.0,1996-04-12T12:05:08.000000000 +green,14.0,1996-04-12T12:05:09.000000000 +green,15.0,1996-04-12T12:05:10.000000000 +green,15.1,1996-04-12T12:05:11.000000000 +green,15.2,1996-04-12T12:05:12.000000000 +green,8.0,1996-04-12T12:05:13.000000000 +green,2.0,1996-04-12T12:05:14.000000000 +"#; + let dir = tempdir()?; + let file_path = dir.path().join("cars.csv"); + { + let mut file = File::create(&file_path)?; + // write CSV data + file.write_all(csv_data.as_bytes())?; + } // scope closes the file + let file_path = file_path.to_str().unwrap(); + + ctx.register_csv("cars", file_path, CsvReadOptions::new()) + .await?; - ctx.register_csv("cars", &csv_path, read_options).await?; Ok(ctx) } diff --git a/datafusion/catalog-listing/Cargo.toml b/datafusion/catalog-listing/Cargo.toml index 4b802c0067e59..be1374b371485 100644 --- a/datafusion/catalog-listing/Cargo.toml +++ b/datafusion/catalog-listing/Cargo.toml @@ -46,7 +46,6 @@ futures = { workspace = true } itertools = { workspace = true } log = { workspace = true } object_store = { workspace = true } -tokio = { workspace = true } [dev-dependencies] datafusion-datasource-parquet = { workspace = true } diff --git a/datafusion/catalog-listing/src/config.rs b/datafusion/catalog-listing/src/config.rs index 3370d2ea75535..ca4d2abfcd737 100644 --- a/datafusion/catalog-listing/src/config.rs +++ b/datafusion/catalog-listing/src/config.rs @@ -19,9 +19,10 @@ use crate::options::ListingOptions; use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_catalog::Session; use datafusion_common::{config_err, internal_err}; +use datafusion_datasource::ListingTableUrl; use datafusion_datasource::file_compression_type::FileCompressionType; +#[expect(deprecated)] use datafusion_datasource::schema_adapter::SchemaAdapterFactory; -use datafusion_datasource::ListingTableUrl; use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; use std::str::FromStr; use std::sync::Arc; @@ -44,15 +45,12 @@ pub enum SchemaSource { /// # Schema Evolution Support /// /// This configuration supports schema evolution through the optional -/// [`SchemaAdapterFactory`]. You might want to override the default factory when you need: +/// [`PhysicalExprAdapterFactory`]. 
You might want to override the default factory when you need: /// /// - **Type coercion requirements**: When you need custom logic for converting between /// different Arrow data types (e.g., Int32 ↔ Int64, Utf8 ↔ LargeUtf8) /// - **Column mapping**: You need to map columns with a legacy name to a new name /// - **Custom handling of missing columns**: By default they are filled in with nulls, but you may e.g. want to fill them in with `0` or `""`. -/// -/// If not specified, a [`datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory`] -/// will be used, which handles basic schema compatibility cases. #[derive(Debug, Clone, Default)] pub struct ListingTableConfig { /// Paths on the `ObjectStore` for creating [`crate::ListingTable`]. @@ -68,8 +66,6 @@ pub struct ListingTableConfig { pub options: Option, /// Tracks the source of the schema information pub(crate) schema_source: SchemaSource, - /// Optional [`SchemaAdapterFactory`] for creating schema adapters - pub(crate) schema_adapter_factory: Option>, /// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters pub(crate) expr_adapter_factory: Option>, } @@ -218,8 +214,7 @@ impl ListingTableConfig { file_schema, options: _, schema_source, - schema_adapter_factory, - expr_adapter_factory: physical_expr_adapter_factory, + expr_adapter_factory, } = self; let (schema, new_schema_source) = match file_schema { @@ -241,8 +236,7 @@ impl ListingTableConfig { file_schema: Some(schema), options: Some(options), schema_source: new_schema_source, - schema_adapter_factory, - expr_adapter_factory: physical_expr_adapter_factory, + expr_adapter_factory, }) } None => internal_err!("No `ListingOptions` set for inferring schema"), @@ -282,7 +276,6 @@ impl ListingTableConfig { file_schema: self.file_schema, options: Some(options), schema_source: self.schema_source, - schema_adapter_factory: self.schema_adapter_factory, expr_adapter_factory: self.expr_adapter_factory, }) } @@ -290,63 +283,11 @@ impl ListingTableConfig { } } - /// Set the [`SchemaAdapterFactory`] for the [`crate::ListingTable`] - /// - /// The schema adapter factory is used to create schema adapters that can - /// handle schema evolution and type conversions when reading files with - /// different schemas than the table schema. - /// - /// If not provided, a default schema adapter factory will be used. 
- /// - /// # Example: Custom Schema Adapter for Type Coercion - /// ```rust - /// # use std::sync::Arc; - /// # use datafusion_catalog_listing::{ListingTableConfig, ListingOptions}; - /// # use datafusion_datasource::schema_adapter::{SchemaAdapterFactory, SchemaAdapter}; - /// # use datafusion_datasource::ListingTableUrl; - /// # use datafusion_datasource_parquet::file_format::ParquetFormat; - /// # use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; - /// # - /// # #[derive(Debug)] - /// # struct MySchemaAdapterFactory; - /// # impl SchemaAdapterFactory for MySchemaAdapterFactory { - /// # fn create(&self, _projected_table_schema: SchemaRef, _file_schema: SchemaRef) -> Box { - /// # unimplemented!() - /// # } - /// # } - /// # let table_paths = ListingTableUrl::parse("file:///path/to/data").unwrap(); - /// # let listing_options = ListingOptions::new(Arc::new(ParquetFormat::default())); - /// # let table_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); - /// let config = ListingTableConfig::new(table_paths) - /// .with_listing_options(listing_options) - /// .with_schema(table_schema) - /// .with_schema_adapter_factory(Arc::new(MySchemaAdapterFactory)); - /// ``` - pub fn with_schema_adapter_factory( - self, - schema_adapter_factory: Arc, - ) -> Self { - Self { - schema_adapter_factory: Some(schema_adapter_factory), - ..self - } - } - - /// Get the [`SchemaAdapterFactory`] for this configuration - pub fn schema_adapter_factory(&self) -> Option<&Arc> { - self.schema_adapter_factory.as_ref() - } - /// Set the [`PhysicalExprAdapterFactory`] for the [`crate::ListingTable`] /// /// The expression adapter factory is used to create physical expression adapters that can /// handle schema evolution and type conversions when evaluating expressions /// with different schemas than the table schema. - /// - /// If not provided, a default physical expression adapter factory will be used unless a custom - /// `SchemaAdapterFactory` is set, in which case only the `SchemaAdapterFactory` will be used. - /// - /// See for details on this transition. pub fn with_expr_adapter_factory( self, expr_adapter_factory: Arc, @@ -356,4 +297,23 @@ impl ListingTableConfig { ..self } } + + /// Deprecated: Set the [`SchemaAdapterFactory`] for the [`crate::ListingTable`] + /// + /// `SchemaAdapterFactory` has been removed. Use [`Self::with_expr_adapter_factory`] + /// and `PhysicalExprAdapterFactory` instead. See `upgrading.md` for more details. + /// + /// This method is a no-op and returns `self` unchanged. + #[deprecated( + since = "52.0.0", + note = "SchemaAdapterFactory has been removed. Use with_expr_adapter_factory and PhysicalExprAdapterFactory instead. See upgrading.md for more details." 
+ )] + #[expect(deprecated)] + pub fn with_schema_adapter_factory( + self, + _schema_adapter_factory: Arc, + ) -> Self { + // No-op - just return self unchanged + self + } } diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 82cc36867939e..ea016015cebd3 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -21,25 +21,23 @@ use std::mem; use std::sync::Arc; use datafusion_catalog::Session; -use datafusion_common::internal_err; -use datafusion_common::{HashMap, Result, ScalarValue}; +use datafusion_common::{HashMap, Result, ScalarValue, assert_or_internal_err}; use datafusion_datasource::ListingTableUrl; use datafusion_datasource::PartitionedFile; -use datafusion_expr::{BinaryExpr, Operator}; +use datafusion_expr::{BinaryExpr, Operator, lit, utils}; use arrow::{ - array::{Array, ArrayRef, AsArray, StringBuilder}, - compute::{and, cast, prep_null_mask_filter}, - datatypes::{DataType, Field, Fields, Schema}, + array::AsArray, + datatypes::{DataType, Field}, record_batch::RecordBatch, }; use datafusion_expr::execution_props::ExecutionProps; use futures::stream::FuturesUnordered; -use futures::{stream::BoxStream, StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt, stream::BoxStream}; use log::{debug, trace}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{Column, DFSchema, DataFusionError}; +use datafusion_common::{Column, DFSchema}; use datafusion_expr::{Expr, Volatility}; use datafusion_physical_expr::create_physical_expr; use object_store::path::Path; @@ -53,7 +51,7 @@ use object_store::{ObjectMeta, ObjectStore}; pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { let mut is_applicable = true; expr.apply(|expr| match expr { - Expr::Column(Column { ref name, .. }) => { + Expr::Column(Column { name, .. 
}) => { is_applicable &= col_names.contains(&name.as_str()); if is_applicable { Ok(TreeNodeRecursion::Jump) @@ -239,105 +237,6 @@ pub async fn list_partitions( Ok(out) } -async fn prune_partitions( - table_path: &ListingTableUrl, - partitions: Vec, - filters: &[Expr], - partition_cols: &[(String, DataType)], -) -> Result> { - if filters.is_empty() { - // prune partitions which don't contain the partition columns - return Ok(partitions - .into_iter() - .filter(|p| { - let cols = partition_cols.iter().map(|x| x.0.as_str()); - !parse_partitions_for_path(table_path, &p.path, cols) - .unwrap_or_default() - .is_empty() - }) - .collect()); - } - - let mut builders: Vec<_> = (0..partition_cols.len()) - .map(|_| StringBuilder::with_capacity(partitions.len(), partitions.len() * 10)) - .collect(); - - for partition in &partitions { - let cols = partition_cols.iter().map(|x| x.0.as_str()); - let parsed = parse_partitions_for_path(table_path, &partition.path, cols) - .unwrap_or_default(); - - let mut builders = builders.iter_mut(); - for (p, b) in parsed.iter().zip(&mut builders) { - b.append_value(p); - } - builders.for_each(|b| b.append_null()); - } - - let arrays = partition_cols - .iter() - .zip(builders) - .map(|((_, d), mut builder)| { - let array = builder.finish(); - cast(&array, d) - }) - .collect::>()?; - - let fields: Fields = partition_cols - .iter() - .map(|(n, d)| Field::new(n, d.clone(), true)) - .collect(); - let schema = Arc::new(Schema::new(fields)); - - let df_schema = DFSchema::from_unqualified_fields( - partition_cols - .iter() - .map(|(n, d)| Field::new(n, d.clone(), true)) - .collect(), - Default::default(), - )?; - - let batch = RecordBatch::try_new(schema, arrays)?; - - // TODO: Plumb this down - let props = ExecutionProps::new(); - - // Applies `filter` to `batch` returning `None` on error - let do_filter = |filter| -> Result { - let expr = create_physical_expr(filter, &df_schema, &props)?; - expr.evaluate(&batch)?.into_array(partitions.len()) - }; - - //.Compute the conjunction of the filters - let mask = filters - .iter() - .map(|f| do_filter(f).map(|a| a.as_boolean().clone())) - .reduce(|a, b| Ok(and(&a?, &b?)?)); - - let mask = match mask { - Some(Ok(mask)) => mask, - Some(Err(err)) => return Err(err), - None => return Ok(partitions), - }; - - // Don't retain partitions that evaluated to null - let prepared = match mask.null_count() { - 0 => mask, - _ => prep_null_mask_filter(&mask), - }; - - // Sanity check - assert_eq!(prepared.len(), partitions.len()); - - let filtered = partitions - .into_iter() - .zip(prepared.values()) - .filter_map(|(p, f)| f.then_some(p)) - .collect(); - - Ok(filtered) -} - #[derive(Debug)] enum PartitionValue { Single(String), @@ -348,16 +247,11 @@ fn populate_partition_values<'a>( partition_values: &mut HashMap<&'a str, PartitionValue>, filter: &'a Expr, ) { - if let Expr::BinaryExpr(BinaryExpr { - ref left, - op, - ref right, - }) = filter - { + if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = filter { match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { - (Expr::Column(Column { ref name, .. }), Expr::Literal(val, _)) - | (Expr::Literal(val, _), Expr::Column(Column { ref name, .. })) => { + (Expr::Column(Column { name, .. }), Expr::Literal(val, _)) + | (Expr::Literal(val, _), Expr::Column(Column { name, .. 
})) => { if partition_values .insert(name, PartitionValue::Single(val.to_string())) .is_some() @@ -412,6 +306,62 @@ pub fn evaluate_partition_prefix<'a>( } } +fn filter_partitions( + pf: PartitionedFile, + filters: &[Expr], + df_schema: &DFSchema, +) -> Result> { + if pf.partition_values.is_empty() && !filters.is_empty() { + return Ok(None); + } else if filters.is_empty() { + return Ok(Some(pf)); + } + + let arrays = pf + .partition_values + .iter() + .map(|v| v.to_array()) + .collect::>()?; + + let batch = RecordBatch::try_new(Arc::clone(df_schema.inner()), arrays)?; + + let filter = utils::conjunction(filters.iter().cloned()).unwrap_or_else(|| lit(true)); + let props = ExecutionProps::new(); + let expr = create_physical_expr(&filter, df_schema, &props)?; + + // Since we're only operating on a single file, our batch and resulting "array" holds only one + // value indicating if the input file matches the provided filters + let matches = expr.evaluate(&batch)?.into_array(1)?; + if matches.as_boolean().value(0) { + return Ok(Some(pf)); + } + + Ok(None) +} + +fn try_into_partitioned_file( + object_meta: ObjectMeta, + partition_cols: &[(String, DataType)], + table_path: &ListingTableUrl, +) -> Result { + let cols = partition_cols.iter().map(|(name, _)| name.as_str()); + let parsed = parse_partitions_for_path(table_path, &object_meta.location, cols); + + let partition_values = parsed + .into_iter() + .flatten() + .zip(partition_cols) + .map(|(parsed, (_, datatype))| { + ScalarValue::try_from_string(parsed.to_string(), datatype) + }) + .collect::>>()?; + + let mut pf: PartitionedFile = object_meta.into(); + pf.partition_values = partition_values; + + Ok(pf) +} + /// Discover the partitions on the given path and prune out files /// that belong to irrelevant partitions using `filters` expressions. /// `filters` should only contain expressions that can be evaluated @@ -424,80 +374,46 @@ pub async fn pruned_partition_list<'a>( file_extension: &'a str, partition_cols: &'a [(String, DataType)], ) -> Result>> { - // if no partition col => simply list all the files - if partition_cols.is_empty() { - if !filters.is_empty() { - return internal_err!( - "Got partition filters for unpartitioned table {}", - table_path - ); - } - return Ok(Box::pin( - table_path - .list_all_files(ctx, store, file_extension) - .await? - .try_filter(|object_meta| futures::future::ready(object_meta.size > 0)) - .map_ok(|object_meta| object_meta.into()), - )); - } - - let partition_prefix = evaluate_partition_prefix(partition_cols, filters); - - let partitions = - list_partitions(store, table_path, partition_cols.len(), partition_prefix) - .await?; - debug!("Listed {} partitions", partitions.len()); + let prefix = if !partition_cols.is_empty() { + evaluate_partition_prefix(partition_cols, filters) + } else { + None + }; - let pruned = - prune_partitions(table_path, partitions, filters, partition_cols).await?; + let objects = table_path + .list_prefixed_files(ctx, store, prefix, file_extension) + .await? 
+ .try_filter(|object_meta| futures::future::ready(object_meta.size > 0)); - debug!("Pruning yielded {} partitions", pruned.len()); + if partition_cols.is_empty() { + assert_or_internal_err!( + filters.is_empty(), + "Got partition filters for unpartitioned table {}", + table_path + ); - let stream = futures::stream::iter(pruned) - .map(move |partition: Partition| async move { - let cols = partition_cols.iter().map(|x| x.0.as_str()); - let parsed = parse_partitions_for_path(table_path, &partition.path, cols); + // if no partition col => simply list all the files + Ok(objects.map_ok(|object_meta| object_meta.into()).boxed()) + } else { + let df_schema = DFSchema::from_unqualified_fields( + partition_cols + .iter() + .map(|(n, d)| Field::new(n, d.clone(), true)) + .collect(), + Default::default(), + )?; - let partition_values = parsed - .into_iter() - .flatten() - .zip(partition_cols) - .map(|(parsed, (_, datatype))| { - ScalarValue::try_from_string(parsed.to_string(), datatype) - }) - .collect::>>()?; - - let files = match partition.files { - Some(files) => files, - None => { - trace!("Recursively listing partition {}", partition.path); - store.list(Some(&partition.path)).try_collect().await? - } - }; - let files = files.into_iter().filter(move |o| { - let extension_match = o.location.as_ref().ends_with(file_extension); - // here need to scan subdirectories(`listing_table_ignore_subdirectory` = false) - let glob_match = table_path.contains(&o.location, false); - extension_match && glob_match - }); - - let stream = futures::stream::iter(files.map(move |object_meta| { - Ok(PartitionedFile { - object_meta, - partition_values: partition_values.clone(), - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }) - })); - - Ok::<_, DataFusionError>(stream) - }) - .buffer_unordered(CONCURRENCY_LIMIT) - .try_flatten() - .boxed(); - Ok(stream) + Ok(objects + .map_ok(|object_meta| { + try_into_partitioned_file(object_meta, partition_cols, table_path) + }) + .try_filter_map(move |pf| { + futures::future::ready( + pf.and_then(|pf| filter_partitions(pf, filters, &df_schema)), + ) + }) + .boxed()) + } } /// Extract the partition values for the given `file_path` (in the given `table_path`) @@ -541,22 +457,11 @@ pub fn describe_partition(partition: &Partition) -> (&str, usize, Vec<&str>) { #[cfg(test)] mod tests { - use async_trait::async_trait; - use datafusion_common::config::TableOptions; use datafusion_datasource::file_groups::FileGroup; - use datafusion_execution::config::SessionConfig; - use datafusion_execution::runtime_env::RuntimeEnv; - use futures::FutureExt; - use object_store::memory::InMemory; - use std::any::Any; use std::ops::Not; use super::*; - use datafusion_expr::{ - case, col, lit, AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF, - }; - use datafusion_physical_expr_common::physical_expr::PhysicalExpr; - use datafusion_physical_plan::ExecutionPlan; + use datafusion_expr::{Expr, case, col, lit}; #[test] fn test_split_files() { @@ -599,209 +504,6 @@ mod tests { assert_eq!(0, chunks.len()); } - #[tokio::test] - async fn test_pruned_partition_list_empty() { - let (store, state) = make_test_store_and_state(&[ - ("tablepath/mypartition=val1/notparquetfile", 100), - ("tablepath/mypartition=val1/ignoresemptyfile.parquet", 0), - ("tablepath/file.parquet", 100), - ("tablepath/notapartition/file.parquet", 100), - ("tablepath/notmypartition=val1/file.parquet", 100), - ]); - let filter = Expr::eq(col("mypartition"), lit("val1")); - let pruned = pruned_partition_list( 
- state.as_ref(), - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - &[filter], - ".parquet", - &[(String::from("mypartition"), DataType::Utf8)], - ) - .await - .expect("partition pruning failed") - .collect::>() - .await; - - assert_eq!(pruned.len(), 0); - } - - #[tokio::test] - async fn test_pruned_partition_list() { - let (store, state) = make_test_store_and_state(&[ - ("tablepath/mypartition=val1/file.parquet", 100), - ("tablepath/mypartition=val2/file.parquet", 100), - ("tablepath/mypartition=val1/ignoresemptyfile.parquet", 0), - ("tablepath/mypartition=val1/other=val3/file.parquet", 100), - ("tablepath/notapartition/file.parquet", 100), - ("tablepath/notmypartition=val1/file.parquet", 100), - ]); - let filter = Expr::eq(col("mypartition"), lit("val1")); - let pruned = pruned_partition_list( - state.as_ref(), - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - &[filter], - ".parquet", - &[(String::from("mypartition"), DataType::Utf8)], - ) - .await - .expect("partition pruning failed") - .try_collect::>() - .await - .unwrap(); - - assert_eq!(pruned.len(), 2); - let f1 = &pruned[0]; - assert_eq!( - f1.object_meta.location.as_ref(), - "tablepath/mypartition=val1/file.parquet" - ); - assert_eq!(&f1.partition_values, &[ScalarValue::from("val1")]); - let f2 = &pruned[1]; - assert_eq!( - f2.object_meta.location.as_ref(), - "tablepath/mypartition=val1/other=val3/file.parquet" - ); - assert_eq!(f2.partition_values, &[ScalarValue::from("val1"),]); - } - - #[tokio::test] - async fn test_pruned_partition_list_multi() { - let (store, state) = make_test_store_and_state(&[ - ("tablepath/part1=p1v1/file.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v1/file1.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v1/file2.parquet", 100), - ("tablepath/part1=p1v3/part2=p2v1/file2.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v2/file2.parquet", 100), - ]); - let filter1 = Expr::eq(col("part1"), lit("p1v2")); - let filter2 = Expr::eq(col("part2"), lit("p2v1")); - let pruned = pruned_partition_list( - state.as_ref(), - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - &[filter1, filter2], - ".parquet", - &[ - (String::from("part1"), DataType::Utf8), - (String::from("part2"), DataType::Utf8), - ], - ) - .await - .expect("partition pruning failed") - .try_collect::>() - .await - .unwrap(); - - assert_eq!(pruned.len(), 2); - let f1 = &pruned[0]; - assert_eq!( - f1.object_meta.location.as_ref(), - "tablepath/part1=p1v2/part2=p2v1/file1.parquet" - ); - assert_eq!( - &f1.partition_values, - &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1"),] - ); - let f2 = &pruned[1]; - assert_eq!( - f2.object_meta.location.as_ref(), - "tablepath/part1=p1v2/part2=p2v1/file2.parquet" - ); - assert_eq!( - &f2.partition_values, - &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1")] - ); - } - - #[tokio::test] - async fn test_list_partition() { - let (store, _) = make_test_store_and_state(&[ - ("tablepath/part1=p1v1/file.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v1/file1.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v1/file2.parquet", 100), - ("tablepath/part1=p1v3/part2=p2v1/file3.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v2/file4.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v2/empty.parquet", 0), - ]); - - let partitions = list_partitions( - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - 0, - None, - ) - .await - .expect("listing partitions failed"); - - assert_eq!( - &partitions - .iter() - 
.map(describe_partition) - .collect::>(), - &vec![ - ("tablepath", 0, vec![]), - ("tablepath/part1=p1v1", 1, vec![]), - ("tablepath/part1=p1v2", 1, vec![]), - ("tablepath/part1=p1v3", 1, vec![]), - ] - ); - - let partitions = list_partitions( - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - 1, - None, - ) - .await - .expect("listing partitions failed"); - - assert_eq!( - &partitions - .iter() - .map(describe_partition) - .collect::>(), - &vec![ - ("tablepath", 0, vec![]), - ("tablepath/part1=p1v1", 1, vec!["file.parquet"]), - ("tablepath/part1=p1v2", 1, vec![]), - ("tablepath/part1=p1v2/part2=p2v1", 2, vec![]), - ("tablepath/part1=p1v2/part2=p2v2", 2, vec![]), - ("tablepath/part1=p1v3", 1, vec![]), - ("tablepath/part1=p1v3/part2=p2v1", 2, vec![]), - ] - ); - - let partitions = list_partitions( - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - 2, - None, - ) - .await - .expect("listing partitions failed"); - - assert_eq!( - &partitions - .iter() - .map(describe_partition) - .collect::>(), - &vec![ - ("tablepath", 0, vec![]), - ("tablepath/part1=p1v1", 1, vec!["file.parquet"]), - ("tablepath/part1=p1v2", 1, vec![]), - ("tablepath/part1=p1v3", 1, vec![]), - ( - "tablepath/part1=p1v2/part2=p2v1", - 2, - vec!["file1.parquet", "file2.parquet"] - ), - ("tablepath/part1=p1v2/part2=p2v2", 2, vec!["file4.parquet"]), - ("tablepath/part1=p1v3/part2=p2v1", 2, vec!["file3.parquet"]), - ] - ); - } - #[test] fn test_parse_partitions_for_path() { assert_eq!( @@ -1016,86 +718,4 @@ mod tests { Some(Path::from("a=1970-01-05")), ); } - - pub fn make_test_store_and_state( - files: &[(&str, u64)], - ) -> (Arc, Arc) { - let memory = InMemory::new(); - - for (name, size) in files { - memory - .put(&Path::from(*name), vec![0; *size as usize].into()) - .now_or_never() - .unwrap() - .unwrap(); - } - - (Arc::new(memory), Arc::new(MockSession {})) - } - - struct MockSession {} - - #[async_trait] - impl Session for MockSession { - fn session_id(&self) -> &str { - unimplemented!() - } - - fn config(&self) -> &SessionConfig { - unimplemented!() - } - - async fn create_physical_plan( - &self, - _logical_plan: &LogicalPlan, - ) -> Result> { - unimplemented!() - } - - fn create_physical_expr( - &self, - _expr: Expr, - _df_schema: &DFSchema, - ) -> Result> { - unimplemented!() - } - - fn scalar_functions(&self) -> &std::collections::HashMap> { - unimplemented!() - } - - fn aggregate_functions( - &self, - ) -> &std::collections::HashMap> { - unimplemented!() - } - - fn window_functions(&self) -> &std::collections::HashMap> { - unimplemented!() - } - - fn runtime_env(&self) -> &Arc { - unimplemented!() - } - - fn execution_props(&self) -> &ExecutionProps { - unimplemented!() - } - - fn as_any(&self) -> &dyn Any { - unimplemented!() - } - - fn table_options(&self) -> &TableOptions { - unimplemented!() - } - - fn table_options_mut(&mut self) -> &mut TableOptions { - unimplemented!() - } - - fn task_ctx(&self) -> Arc { - unimplemented!() - } - } } diff --git a/datafusion/catalog-listing/src/mod.rs b/datafusion/catalog-listing/src/mod.rs index 90d04b46b8067..28bd880ea01fb 100644 --- a/datafusion/catalog-listing/src/mod.rs +++ b/datafusion/catalog-listing/src/mod.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. 
+#![deny(clippy::allow_attributes)] +#![cfg_attr(test, allow(clippy::needless_pass_by_value))] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" @@ -31,4 +33,4 @@ mod table; pub use config::{ListingTableConfig, SchemaSource}; pub use options::ListingOptions; -pub use table::ListingTable; +pub use table::{ListFilesResult, ListingTable}; diff --git a/datafusion/catalog-listing/src/options.rs b/datafusion/catalog-listing/src/options.rs index 7da8005f90ec2..146f98d62335e 100644 --- a/datafusion/catalog-listing/src/options.rs +++ b/datafusion/catalog-listing/src/options.rs @@ -18,12 +18,12 @@ use arrow::datatypes::{DataType, SchemaRef}; use datafusion_catalog::Session; use datafusion_common::plan_err; -use datafusion_datasource::file_format::FileFormat; use datafusion_datasource::ListingTableUrl; +use datafusion_datasource::file_format::FileFormat; use datafusion_execution::config::SessionConfig; use datafusion_expr::SortExpr; use futures::StreamExt; -use futures::{future, TryStreamExt}; +use futures::{TryStreamExt, future}; use itertools::Itertools; use std::sync::Arc; diff --git a/datafusion/catalog-listing/src/table.rs b/datafusion/catalog-listing/src/table.rs index 95f9523d4401c..9fb2dd2dce29c 100644 --- a/datafusion/catalog-listing/src/table.rs +++ b/datafusion/catalog-listing/src/table.rs @@ -23,18 +23,16 @@ use async_trait::async_trait; use datafusion_catalog::{ScanArgs, ScanResult, Session, TableProvider}; use datafusion_common::stats::Precision; use datafusion_common::{ - internal_datafusion_err, plan_err, project_schema, Constraints, DataFusionError, - SchemaExt, Statistics, + Constraints, SchemaExt, Statistics, internal_datafusion_err, plan_err, project_schema, }; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; use datafusion_datasource::file_sink_config::FileSinkConfig; -use datafusion_datasource::schema_adapter::{ - DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, -}; +#[expect(deprecated)] +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; use datafusion_datasource::{ - compute_all_files_statistics, ListingTableUrl, PartitionedFile, + ListingTableUrl, PartitionedFile, TableSchema, compute_all_files_statistics, }; use datafusion_execution::cache::cache_manager::FileStatisticsCache; use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; @@ -44,14 +42,25 @@ use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType}; use datafusion_physical_expr::create_lex_ordering; use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::ExecutionPlan; -use futures::{future, stream, Stream, StreamExt, TryStreamExt}; +use datafusion_physical_plan::empty::EmptyExec; +use futures::{Stream, StreamExt, TryStreamExt, future, stream}; use object_store::ObjectStore; use std::any::Any; use std::collections::HashMap; use std::sync::Arc; +/// Result of a file listing operation from [`ListingTable::list_files_for_scan`]. 
+#[derive(Debug)] +pub struct ListFilesResult { + /// File groups organized by the partitioning strategy. + pub file_groups: Vec, + /// Aggregated statistics for all files. + pub statistics: Statistics, + /// Whether files are grouped by partition values (enables Hash partitioning). + pub grouped_by_partition: bool, +} + /// Built in [`TableProvider`] that reads data from one or more files as a single table. /// /// The files are read using an [`ObjectStore`] instance, for example from @@ -178,13 +187,11 @@ pub struct ListingTable { /// The SQL definition for this table, if any definition: Option, /// Cache for collected file statistics - collected_statistics: FileStatisticsCache, + collected_statistics: Arc, /// Constraints applied to this table constraints: Constraints, /// Column default expressions for columns that are not physically present in the data files column_defaults: HashMap, - /// Optional [`SchemaAdapterFactory`] for creating schema adapters - schema_adapter_factory: Option>, /// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters expr_adapter_factory: Option>, } @@ -227,7 +234,6 @@ impl ListingTable { collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), constraints: Constraints::default(), column_defaults: HashMap::new(), - schema_adapter_factory: config.schema_adapter_factory, expr_adapter_factory: config.expr_adapter_factory, }; @@ -255,7 +261,7 @@ impl ListingTable { /// multiple times in the same session. /// /// If `None`, creates a new [`DefaultFileStatisticsCache`] scoped to this query. - pub fn with_cache(mut self, cache: Option) -> Self { + pub fn with_cache(mut self, cache: Option>) -> Self { self.collected_statistics = cache.unwrap_or_else(|| Arc::new(DefaultFileStatisticsCache::default())); self @@ -282,71 +288,52 @@ impl ListingTable { self.schema_source } - /// Set the [`SchemaAdapterFactory`] for this [`ListingTable`] + /// Deprecated: Set the [`SchemaAdapterFactory`] for this [`ListingTable`] /// - /// The schema adapter factory is used to create schema adapters that can - /// handle schema evolution and type conversions when reading files with - /// different schemas than the table schema. + /// `SchemaAdapterFactory` has been removed. Use [`ListingTableConfig::with_expr_adapter_factory`] + /// and `PhysicalExprAdapterFactory` instead. See `upgrading.md` for more details. /// - /// # Example: Adding Schema Evolution Support - /// ```rust - /// # use std::sync::Arc; - /// # use datafusion_catalog_listing::{ListingTable, ListingTableConfig, ListingOptions}; - /// # use datafusion_datasource::ListingTableUrl; - /// # use datafusion_datasource::schema_adapter::{DefaultSchemaAdapterFactory, SchemaAdapter}; - /// # use datafusion_datasource_parquet::file_format::ParquetFormat; - /// # use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; - /// # let table_path = ListingTableUrl::parse("file:///path/to/data").unwrap(); - /// # let options = ListingOptions::new(Arc::new(ParquetFormat::default())); - /// # let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); - /// # let config = ListingTableConfig::new(table_path).with_listing_options(options).with_schema(schema); - /// # let table = ListingTable::try_new(config).unwrap(); - /// let table_with_evolution = table - /// .with_schema_adapter_factory(Arc::new(DefaultSchemaAdapterFactory)); - /// ``` - /// See [`ListingTableConfig::with_schema_adapter_factory`] for an example of custom SchemaAdapterFactory. 
+ /// This method is a no-op and returns `self` unchanged. + #[deprecated( + since = "52.0.0", + note = "SchemaAdapterFactory has been removed. Use ListingTableConfig::with_expr_adapter_factory and PhysicalExprAdapterFactory instead. See upgrading.md for more details." + )] + #[expect(deprecated)] pub fn with_schema_adapter_factory( self, - schema_adapter_factory: Arc, + _schema_adapter_factory: Arc, ) -> Self { - Self { - schema_adapter_factory: Some(schema_adapter_factory), - ..self - } - } - - /// Get the [`SchemaAdapterFactory`] for this table - pub fn schema_adapter_factory(&self) -> Option<&Arc> { - self.schema_adapter_factory.as_ref() + // No-op - just return self unchanged + self } - /// Creates a schema adapter for mapping between file and table schemas + /// Deprecated: Returns the [`SchemaAdapterFactory`] used by this [`ListingTable`]. /// - /// Uses the configured schema adapter factory if available, otherwise falls back - /// to the default implementation. - fn create_schema_adapter(&self) -> Box { - let table_schema = self.schema(); - match &self.schema_adapter_factory { - Some(factory) => { - factory.create_with_projected_schema(Arc::clone(&table_schema)) - } - None => DefaultSchemaAdapterFactory::from_schema(Arc::clone(&table_schema)), - } + /// `SchemaAdapterFactory` has been removed. Use `PhysicalExprAdapterFactory` instead. + /// See `upgrading.md` for more details. + /// + /// Always returns `None`. + #[deprecated( + since = "52.0.0", + note = "SchemaAdapterFactory has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." + )] + #[expect(deprecated)] + pub fn schema_adapter_factory(&self) -> Option> { + None } - /// Creates a file source and applies schema adapter factory if available - fn create_file_source_with_schema_adapter( - &self, - ) -> datafusion_common::Result> { - let mut source = self.options.format.file_source(); - // Apply schema adapter to source if available - // - // The source will use this SchemaAdapter to adapt data batches as they flow up the plan. - // Note: ListingTable also creates a SchemaAdapter in `scan()` but that is only used to adapt collected statistics. - if let Some(factory) = &self.schema_adapter_factory { - source = source.with_schema_adapter_factory(Arc::clone(factory))?; - } - Ok(source) + /// Creates a file source for this table + fn create_file_source(&self) -> Arc { + let table_schema = TableSchema::new( + Arc::clone(&self.file_schema), + self.options + .table_partition_cols + .iter() + .map(|(col, field)| Arc::new(Field::new(col, field.clone(), false))) + .collect(), + ); + + self.options.format.file_source(table_schema) } /// If file_sort_order is specified, creates the appropriate physical expressions @@ -418,7 +405,7 @@ impl TableProvider for ListingTable { .options .table_partition_cols .iter() - .map(|col| Ok(self.table_schema.field_with_name(&col.0)?.clone())) + .map(|col| Ok(Arc::new(self.table_schema.field_with_name(&col.0)?.clone()))) .collect::>>()?; let table_partition_col_names = table_partition_cols @@ -437,7 +424,11 @@ impl TableProvider for ListingTable { // at the same time. This is because the limit should be applied after the filters are applied. 
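Since `with_schema_adapter_factory` is now a deprecated no-op, callers that previously attached a `SchemaAdapterFactory` need to move to the expression-adapter path named in the deprecation note. A hedged sketch of the new wiring, reusing the builder calls from the removed doc example; `MyExprAdapterFactory` is a hypothetical user type implementing `PhysicalExprAdapterFactory`, and the exact signature of `with_expr_adapter_factory` is assumed from the note above:

```rust
// Sketch only: schema-evolution handling moves from the removed
// SchemaAdapterFactory hook to a PhysicalExprAdapterFactory supplied
// through ListingTableConfig.
let table_path = ListingTableUrl::parse("file:///path/to/data")?;
let options = ListingOptions::new(Arc::new(ParquetFormat::default()));
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
let config = ListingTableConfig::new(table_path)
    .with_listing_options(options)
    .with_schema(schema)
    // replaces the deprecated ListingTable::with_schema_adapter_factory(...)
    .with_expr_adapter_factory(Arc::new(MyExprAdapterFactory::default()));
let table = ListingTable::try_new(config)?;
```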
let statistic_file_limit = if filters.is_empty() { limit } else { None }; - let (mut partitioned_file_lists, statistics) = self + let ListFilesResult { + file_groups: mut partitioned_file_lists, + statistics, + grouped_by_partition: partitioned_by_file_group, + } = self .list_files_for_scan(state, &partition_filters, statistic_file_limit) .await?; @@ -469,7 +460,9 @@ impl TableProvider for ListingTable { if new_groups.len() <= self.options.target_partitions { partitioned_file_lists = new_groups; } else { - log::debug!("attempted to split file groups by statistics, but there were more file groups than target_partitions; falling back to unordered") + log::debug!( + "attempted to split file groups by statistics, but there were more file groups than target_partitions; falling back to unordered" + ) } } None => {} // no ordering required @@ -483,7 +476,7 @@ impl TableProvider for ListingTable { ))))); }; - let file_source = self.create_file_source_with_schema_adapter()?; + let file_source = self.create_file_source(); // create the execution plan let plan = self @@ -491,20 +484,16 @@ impl TableProvider for ListingTable { .format .create_physical_plan( state, - FileScanConfigBuilder::new( - object_store_url, - Arc::clone(&self.file_schema), - file_source, - ) - .with_file_groups(partitioned_file_lists) - .with_constraints(self.constraints.clone()) - .with_statistics(statistics) - .with_projection_indices(projection) - .with_limit(limit) - .with_output_ordering(output_ordering) - .with_table_partition_cols(table_partition_cols) - .with_expr_adapter(self.expr_adapter_factory.clone()) - .build(), + FileScanConfigBuilder::new(object_store_url, file_source) + .with_file_groups(partitioned_file_lists) + .with_constraints(self.constraints.clone()) + .with_statistics(statistics) + .with_projection_indices(projection)? + .with_limit(limit) + .with_output_ordering(output_ordering) + .with_expr_adapter(self.expr_adapter_factory.clone()) + .with_partitioned_by_file_group(partitioned_by_file_group) + .build(), ) .await?; @@ -574,6 +563,11 @@ impl TableProvider for ListingTable { let keep_partition_by_columns = state.config_options().execution.keep_partition_by_columns; + // Invalidate cache entries for this table if they exist + if let Some(lfc) = state.runtime_env().cache_manager.get_list_files_cache() { + let _ = lfc.remove(table_path.prefix()); + } + // Sink related option, apart from format let config = FileSinkConfig { original_url: String::default(), @@ -611,11 +605,15 @@ impl ListingTable { ctx: &'a dyn Session, filters: &'a [Expr], limit: Option, - ) -> datafusion_common::Result<(Vec, Statistics)> { + ) -> datafusion_common::Result { let store = if let Some(url) = self.table_paths.first() { ctx.runtime_env().object_store(url)? 
} else { - return Ok((vec![], Statistics::new_unknown(&self.file_schema))); + return Ok(ListFilesResult { + file_groups: vec![], + statistics: Statistics::new_unknown(&self.file_schema), + grouped_by_partition: false, + }); }; // list files (with partitions) let file_list = future::try_join_all(self.table_paths.iter().map(|table_path| { @@ -649,27 +647,51 @@ impl ListingTable { let (file_group, inexact_stats) = get_files_with_limit(files, limit, self.options.collect_stat).await?; - let file_groups = file_group.split_files(self.options.target_partitions); - let (mut file_groups, mut stats) = compute_all_files_statistics( + // Threshold: 0 = disabled, N > 0 = enabled when distinct_keys >= N + // + // When enabled, files are grouped by their Hive partition column values, allowing + // FileScanConfig to declare Hash partitioning. This enables the optimizer to skip + // hash repartitioning for aggregates and joins on partition columns. + let threshold = ctx.config_options().optimizer.preserve_file_partitions; + + let (file_groups, grouped_by_partition) = if threshold > 0 + && !self.options.table_partition_cols.is_empty() + { + let grouped = + file_group.group_by_partition_values(self.options.target_partitions); + if grouped.len() >= threshold { + (grouped, true) + } else { + let all_files: Vec<_> = + grouped.into_iter().flat_map(|g| g.into_inner()).collect(); + ( + FileGroup::new(all_files).split_files(self.options.target_partitions), + false, + ) + } + } else { + ( + file_group.split_files(self.options.target_partitions), + false, + ) + }; + + let (file_groups, stats) = compute_all_files_statistics( file_groups, self.schema(), self.options.collect_stat, inexact_stats, )?; - let schema_adapter = self.create_schema_adapter(); - let (schema_mapper, _) = schema_adapter.map_schema(self.file_schema.as_ref())?; - - stats.column_statistics = - schema_mapper.map_column_statistics(&stats.column_statistics)?; - file_groups.iter_mut().try_for_each(|file_group| { - if let Some(stat) = file_group.statistics_mut() { - stat.column_statistics = - schema_mapper.map_column_statistics(&stat.column_statistics)?; - } - Ok::<_, DataFusionError>(()) - })?; - Ok((file_groups, stats)) + // Note: Statistics already include both file columns and partition columns. + // PartitionedFile::with_statistics automatically appends exact partition column + // statistics (min=max=partition_value, null_count=0, distinct_count=1) computed + // from partition_values. + Ok(ListFilesResult { + file_groups, + statistics: stats, + grouped_by_partition, + }) } /// Collects statistics for a given partitioned file. 
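To make the new control flow in `list_files_for_scan` easier to follow: listing now returns a `ListFilesResult` whose `grouped_by_partition` flag records whether files were grouped by their Hive partition values, and that decision is gated by `optimizer.preserve_file_partitions`. A self-contained sketch of just the threshold gate (names are illustrative; the real code operates on `FileGroup`s):

```rust
/// Illustrative only: mirrors the `preserve_file_partitions` gate above.
/// `threshold` is the config value, `distinct_partitions` the number of
/// groups produced by `group_by_partition_values`.
fn keep_partition_grouping(
    threshold: usize,
    has_partition_cols: bool,
    distinct_partitions: usize,
) -> bool {
    // 0 disables the feature entirely; otherwise require at least `threshold`
    // distinct partition values before declaring Hash partitioning.
    threshold > 0 && has_partition_cols && distinct_partitions >= threshold
}

fn main() {
    assert!(!keep_partition_grouping(0, true, 10)); // disabled
    assert!(keep_partition_grouping(1, true, 10));  // always grouped once enabled
    assert!(!keep_partition_grouping(4, true, 3));  // too few partitions:
                                                    // fall back to split_files
    assert!(!keep_partition_grouping(4, false, 8)); // no Hive partition columns
}
```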
@@ -756,28 +778,25 @@ async fn get_files_with_limit( let file = file_result?; // Update file statistics regardless of state - if collect_stats { - if let Some(file_stats) = &file.statistics { - num_rows = if file_group.is_empty() { - // For the first file, just take its row count - file_stats.num_rows - } else { - // For subsequent files, accumulate the counts - num_rows.add(&file_stats.num_rows) - }; - } + if collect_stats && let Some(file_stats) = &file.statistics { + num_rows = if file_group.is_empty() { + // For the first file, just take its row count + file_stats.num_rows + } else { + // For subsequent files, accumulate the counts + num_rows.add(&file_stats.num_rows) + }; } // Always add the file to our group file_group.push(file); // Check if we've hit the limit (if one was specified) - if let Some(limit) = limit { - if let Precision::Exact(row_count) = num_rows { - if row_count > limit { - state = ProcessingState::ReachedLimit; - } - } + if let Some(limit) = limit + && let Precision::Exact(row_count) = num_rows + && row_count > limit + { + state = ProcessingState::ReachedLimit; } } // If we still have files in the stream, it means that the limit kicked diff --git a/datafusion/catalog/src/async.rs b/datafusion/catalog/src/async.rs index 1c830c976d8b8..1b8039d828fdb 100644 --- a/datafusion/catalog/src/async.rs +++ b/datafusion/catalog/src/async.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use async_trait::async_trait; -use datafusion_common::{error::Result, not_impl_err, HashMap, TableReference}; +use datafusion_common::{HashMap, TableReference, error::Result, not_impl_err}; use datafusion_execution::config::SessionConfig; use crate::{CatalogProvider, CatalogProviderList, SchemaProvider, TableProvider}; @@ -60,7 +60,9 @@ impl SchemaProvider for ResolvedSchemaProvider { } fn deregister_table(&self, name: &str) -> Result>> { - not_impl_err!("Attempt to deregister table '{name}' with ResolvedSchemaProvider which is not supported") + not_impl_err!( + "Attempt to deregister table '{name}' with ResolvedSchemaProvider which is not supported" + ) } fn table_exist(&self, name: &str) -> bool { @@ -193,7 +195,7 @@ impl CatalogProviderList for ResolvedCatalogProviderList { /// /// See the [remote_catalog.rs] for an end to end example /// -/// [remote_catalog.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/remote_catalog.rs +/// [remote_catalog.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/data_io/remote_catalog.rs #[async_trait] pub trait AsyncSchemaProvider: Send + Sync { /// Lookup a table in the schema provider @@ -425,14 +427,14 @@ mod tests { use std::{ any::Any, sync::{ - atomic::{AtomicU32, Ordering}, Arc, + atomic::{AtomicU32, Ordering}, }, }; use arrow::datatypes::SchemaRef; use async_trait::async_trait; - use datafusion_common::{error::Result, Statistics, TableReference}; + use datafusion_common::{Statistics, TableReference, error::Result}; use datafusion_execution::config::SessionConfig; use datafusion_expr::{Expr, TableType}; use datafusion_physical_plan::ExecutionPlan; diff --git a/datafusion/catalog/src/catalog.rs b/datafusion/catalog/src/catalog.rs index 71b9eccf9d657..bb9e89eba2fef 100644 --- a/datafusion/catalog/src/catalog.rs +++ b/datafusion/catalog/src/catalog.rs @@ -20,8 +20,8 @@ use std::fmt::Debug; use std::sync::Arc; pub use crate::schema::SchemaProvider; -use datafusion_common::not_impl_err; use datafusion_common::Result; +use datafusion_common::not_impl_err; /// Represents a catalog, comprising a number of 
named schemas. /// @@ -61,7 +61,7 @@ use datafusion_common::Result; /// schemas and tables exist. /// /// [Delta Lake]: https://delta.io/ -/// [`remote_catalog`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/remote_catalog.rs +/// [`remote_catalog`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/data_io/remote_catalog.rs /// /// The [`CatalogProvider`] can support this use case, but it takes some care. /// The planning APIs in DataFusion are not `async` and thus network IO can not @@ -100,7 +100,7 @@ use datafusion_common::Result; /// /// [`datafusion-cli`]: https://datafusion.apache.org/user-guide/cli/index.html /// [`DynamicFileCatalogProvider`]: https://github.com/apache/datafusion/blob/31b9b48b08592b7d293f46e75707aad7dadd7cbc/datafusion-cli/src/catalog.rs#L75 -/// [`catalog.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/catalog.rs +/// [`catalog.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/data_io/catalog.rs /// [delta-rs]: https://github.com/delta-io/delta-rs /// [`UnityCatalogProvider`]: https://github.com/delta-io/delta-rs/blob/951436ecec476ce65b5ed3b58b50fb0846ca7b91/crates/deltalake-core/src/data_catalog/unity/datafusion.rs#L111-L123 /// diff --git a/datafusion/catalog/src/cte_worktable.rs b/datafusion/catalog/src/cte_worktable.rs index d6b2a453118c9..9565dcc60141e 100644 --- a/datafusion/catalog/src/cte_worktable.rs +++ b/datafusion/catalog/src/cte_worktable.rs @@ -17,20 +17,18 @@ //! CteWorkTable implementation used for recursive queries +use std::any::Any; +use std::borrow::Cow; use std::sync::Arc; -use std::{any::Any, borrow::Cow}; -use crate::Session; use arrow::datatypes::SchemaRef; use async_trait::async_trait; -use datafusion_physical_plan::work_table::WorkTableExec; - -use datafusion_physical_plan::ExecutionPlan; - use datafusion_common::error::Result; use datafusion_expr::{Expr, LogicalPlan, TableProviderFilterPushDown, TableType}; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::work_table::WorkTableExec; -use crate::TableProvider; +use crate::{ScanArgs, ScanResult, Session, TableProvider}; /// The temporary working table where the previous iteration of a recursive query is stored /// Naming is based on PostgreSQL's implementation. 
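A recurring mechanical change in this diff (see `get_files_with_limit` above and the `information_schema` hunks below) is collapsing nested `if`/`if let` blocks into let-chains. A small self-contained illustration of the before/after shape, assuming a toolchain where let-chains are stable (Rust 2024 edition):

```rust
fn over_limit(limit: Option<usize>, row_count: Option<usize>) -> bool {
    // Before: nested conditionals, as the old code was written
    // if let Some(limit) = limit {
    //     if let Some(rows) = row_count {
    //         if rows > limit {
    //             return true;
    //         }
    //     }
    // }

    // After: a single let-chain, matching the style used in the diff
    if let Some(limit) = limit
        && let Some(rows) = row_count
        && rows > limit
    {
        return true;
    }
    false
}

fn main() {
    assert!(over_limit(Some(10), Some(11)));
    assert!(!over_limit(None, Some(11)));
}
```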
@@ -85,16 +83,28 @@ impl TableProvider for CteWorkTable { async fn scan( &self, - _state: &dyn Session, - _projection: Option<&Vec>, - _filters: &[Expr], - _limit: Option, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, ) -> Result> { - // TODO: pushdown filters and limits - Ok(Arc::new(WorkTableExec::new( + let options = ScanArgs::default() + .with_projection(projection.map(|p| p.as_slice())) + .with_filters(Some(filters)) + .with_limit(limit); + Ok(self.scan_with_args(state, options).await?.into_inner()) + } + + async fn scan_with_args<'a>( + &self, + _state: &dyn Session, + args: ScanArgs<'a>, + ) -> Result { + Ok(ScanResult::new(Arc::new(WorkTableExec::new( self.name.clone(), Arc::clone(&self.table_schema), - ))) + args.projection().map(|p| p.to_vec()), + )?))) } fn supports_filters_pushdown( diff --git a/datafusion/catalog/src/default_table_source.rs b/datafusion/catalog/src/default_table_source.rs index 11963c06c88f5..fb6531ba0b2ee 100644 --- a/datafusion/catalog/src/default_table_source.rs +++ b/datafusion/catalog/src/default_table_source.rs @@ -23,7 +23,7 @@ use std::{any::Any, borrow::Cow}; use crate::TableProvider; use arrow::datatypes::SchemaRef; -use datafusion_common::{internal_err, Constraints}; +use datafusion_common::{Constraints, internal_err}; use datafusion_expr::{Expr, TableProviderFilterPushDown, TableSource, TableType}; /// Implements [`TableSource`] for a [`TableProvider`] diff --git a/datafusion/catalog/src/information_schema.rs b/datafusion/catalog/src/information_schema.rs index d733551f44051..52bfeca3d4282 100644 --- a/datafusion/catalog/src/information_schema.rs +++ b/datafusion/catalog/src/information_schema.rs @@ -28,16 +28,17 @@ use arrow::{ record_batch::RecordBatch, }; use async_trait::async_trait; +use datafusion_common::DataFusionError; use datafusion_common::config::{ConfigEntry, ConfigOptions}; use datafusion_common::error::Result; use datafusion_common::types::NativeType; -use datafusion_common::DataFusionError; use datafusion_execution::TaskContext; +use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; use datafusion_expr::{TableType, Volatility}; +use datafusion_physical_plan::SendableRecordBatchStream; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::streaming::PartitionStream; -use datafusion_physical_plan::SendableRecordBatchStream; use std::collections::{BTreeSet, HashMap, HashSet}; use std::fmt::Debug; use std::{any::Any, sync::Arc}; @@ -137,11 +138,11 @@ impl InformationSchemaConfig { let catalog = self.catalog_list.catalog(&catalog_name).unwrap(); for schema_name in catalog.schema_names() { - if schema_name != INFORMATION_SCHEMA { - if let Some(schema) = catalog.schema(&schema_name) { - let schema_owner = schema.owner_name(); - builder.add_schemata(&catalog_name, &schema_name, schema_owner); - } + if schema_name != INFORMATION_SCHEMA + && let Some(schema) = catalog.schema(&schema_name) + { + let schema_owner = schema.owner_name(); + builder.add_schemata(&catalog_name, &schema_name, schema_owner); } } } @@ -215,11 +216,16 @@ impl InformationSchemaConfig { fn make_df_settings( &self, config_options: &ConfigOptions, + runtime_env: &Arc, builder: &mut InformationSchemaDfSettingsBuilder, ) { for entry in config_options.entries() { builder.add_setting(entry); } + // Add runtime configuration entries + for entry in runtime_env.config_entries() { + builder.add_setting(entry); + } } fn 
make_routines( @@ -245,7 +251,7 @@ impl InformationSchemaConfig { name, "FUNCTION", Self::is_deterministic(udf.signature()), - return_type, + return_type.as_ref(), "SCALAR", udf.documentation().map(|d| d.description.to_string()), udf.documentation().map(|d| d.syntax_example.to_string()), @@ -265,7 +271,7 @@ impl InformationSchemaConfig { name, "FUNCTION", Self::is_deterministic(udaf.signature()), - return_type, + return_type.as_ref(), "AGGREGATE", udaf.documentation().map(|d| d.description.to_string()), udaf.documentation().map(|d| d.syntax_example.to_string()), @@ -285,7 +291,7 @@ impl InformationSchemaConfig { name, "FUNCTION", Self::is_deterministic(udwf.signature()), - return_type, + return_type.as_ref(), "WINDOW", udwf.documentation().map(|d| d.description.to_string()), udwf.documentation().map(|d| d.syntax_example.to_string()), @@ -418,11 +424,11 @@ fn get_udf_args_and_return_types( // only handle the function which implemented [`ScalarUDFImpl::return_type`] method let return_type = udf .return_type(&arg_types) - .map(|t| remove_native_type_prefix(NativeType::from(t))) + .map(|t| remove_native_type_prefix(&NativeType::from(t))) .ok(); let arg_types = arg_types .into_iter() - .map(|t| remove_native_type_prefix(NativeType::from(t))) + .map(|t| remove_native_type_prefix(&NativeType::from(t))) .collect::>(); (arg_types, return_type) }) @@ -445,10 +451,10 @@ fn get_udaf_args_and_return_types( let return_type = udaf .return_type(&arg_types) .ok() - .map(|t| remove_native_type_prefix(NativeType::from(t))); + .map(|t| remove_native_type_prefix(&NativeType::from(t))); let arg_types = arg_types .into_iter() - .map(|t| remove_native_type_prefix(NativeType::from(t))) + .map(|t| remove_native_type_prefix(&NativeType::from(t))) .collect::>(); (arg_types, return_type) }) @@ -470,7 +476,7 @@ fn get_udwf_args_and_return_types( // only handle the function which implemented [`ScalarUDFImpl::return_type`] method let arg_types = arg_types .into_iter() - .map(|t| remove_native_type_prefix(NativeType::from(t))) + .map(|t| remove_native_type_prefix(&NativeType::from(t))) .collect::>(); (arg_types, None) }) @@ -479,7 +485,7 @@ fn get_udwf_args_and_return_types( } #[inline] -fn remove_native_type_prefix(native_type: NativeType) -> String { +fn remove_native_type_prefix(native_type: &NativeType) -> String { format!("{native_type}") } @@ -679,7 +685,7 @@ impl InformationSchemaViewBuilder { catalog_name: impl AsRef, schema_name: impl AsRef, table_name: impl AsRef, - definition: Option>, + definition: Option<&(impl AsRef + ?Sized)>, ) { // Note: append_value is actually infallible. 
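Returning to the `CteWorkTable` hunk above: the legacy `scan` entry point is now a thin wrapper over `scan_with_args`. A hedged caller-side sketch of the same builder API, with argument types inferred from that hunk (`with_projection` appears to take `Option<&[usize]>`, `with_filters` `Option<&[Expr]>`):

```rust
use std::sync::Arc;

use datafusion_catalog::{ScanArgs, Session, TableProvider};
use datafusion_common::Result;
use datafusion_expr::Expr;
use datafusion_physical_plan::ExecutionPlan;

// Sketch only: drive any TableProvider through the ScanArgs path used above.
async fn scan_first_rows(
    provider: &dyn TableProvider,
    state: &dyn Session,
    filters: &[Expr],
) -> Result<Arc<dyn ExecutionPlan>> {
    let args = ScanArgs::default()
        .with_projection(Some(&[0, 2])) // keep columns 0 and 2
        .with_filters(Some(filters))
        .with_limit(Some(10));
    // ScanResult::into_inner() yields the physical plan, as in the hunk above.
    Ok(provider.scan_with_args(state, args).await?.into_inner())
}
```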
self.catalog_names.append_value(catalog_name.as_ref()); @@ -1060,7 +1066,12 @@ impl PartitionStream for InformationSchemaDfSettings { // TODO: Stream this futures::stream::once(async move { // create a mem table with the names of tables - config.make_df_settings(ctx.session_config().options(), &mut builder); + let runtime_env = ctx.runtime_env(); + config.make_df_settings( + ctx.session_config().options(), + &runtime_env, + &mut builder, + ); Ok(builder.finish()) }), )) @@ -1156,7 +1167,7 @@ struct InformationSchemaRoutinesBuilder { } impl InformationSchemaRoutinesBuilder { - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] fn add_routine( &mut self, catalog_name: impl AsRef, @@ -1164,7 +1175,7 @@ impl InformationSchemaRoutinesBuilder { routine_name: impl AsRef, routine_type: impl AsRef, is_deterministic: bool, - data_type: Option>, + data_type: Option<&impl AsRef>, function_type: impl AsRef, description: Option>, syntax_example: Option>, @@ -1290,7 +1301,7 @@ struct InformationSchemaParametersBuilder { } impl InformationSchemaParametersBuilder { - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] fn add_parameter( &mut self, specific_catalog: impl AsRef, @@ -1298,7 +1309,7 @@ impl InformationSchemaParametersBuilder { specific_name: impl AsRef, ordinal_position: u64, parameter_mode: impl AsRef, - parameter_name: Option>, + parameter_name: Option<&(impl AsRef + ?Sized)>, data_type: impl AsRef, parameter_default: Option>, is_variadic: bool, @@ -1397,7 +1408,9 @@ mod tests { // InformationSchemaConfig::make_tables used this before `table_type` // existed but should not, as it may be expensive. async fn table(&self, _: &str) -> Result>> { - panic!("InformationSchemaConfig::make_tables called SchemaProvider::table instead of table_type") + panic!( + "InformationSchemaConfig::make_tables called SchemaProvider::table instead of table_type" + ) } fn as_any(&self) -> &dyn Any { diff --git a/datafusion/catalog/src/lib.rs b/datafusion/catalog/src/lib.rs index 1c5e38438724e..d1cd3998fecf1 100644 --- a/datafusion/catalog/src/lib.rs +++ b/datafusion/catalog/src/lib.rs @@ -23,6 +23,8 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] +#![cfg_attr(test, allow(clippy::needless_pass_by_value))] +#![deny(clippy::allow_attributes)] //! Interfaces and default implementations of catalogs and schemas. //! 
@@ -46,13 +48,13 @@ mod dynamic_file; mod schema; mod table; +pub use r#async::*; pub use catalog::*; pub use datafusion_session::Session; pub use dynamic_file::catalog::*; pub use memory::{ MemTable, MemoryCatalogProvider, MemoryCatalogProviderList, MemorySchemaProvider, }; -pub use r#async::*; pub use schema::*; pub use table::*; diff --git a/datafusion/catalog/src/listing_schema.rs b/datafusion/catalog/src/listing_schema.rs index af96cfc15fc82..77fbea8577089 100644 --- a/datafusion/catalog/src/listing_schema.rs +++ b/datafusion/catalog/src/listing_schema.rs @@ -26,7 +26,7 @@ use crate::{SchemaProvider, TableProvider, TableProviderFactory}; use crate::Session; use datafusion_common::{ - internal_datafusion_err, DFSchema, DataFusionError, HashMap, TableReference, + DFSchema, DataFusionError, HashMap, TableReference, internal_datafusion_err, }; use datafusion_expr::CreateExternalTable; @@ -127,22 +127,13 @@ impl ListingSchemaProvider { .factory .create( state, - &CreateExternalTable { - schema: Arc::new(DFSchema::empty()), + &CreateExternalTable::builder( name, - location: table_url, - file_type: self.format.clone(), - table_partition_cols: vec![], - if_not_exists: false, - or_replace: false, - temporary: false, - definition: None, - order_exprs: vec![], - unbounded: false, - options: Default::default(), - constraints: Default::default(), - column_defaults: Default::default(), - }, + table_url, + self.format.clone(), + Arc::new(DFSchema::empty()), + ) + .build(), ) .await?; let _ = diff --git a/datafusion/catalog/src/memory/schema.rs b/datafusion/catalog/src/memory/schema.rs index f1b3628f7affc..97a579b021617 100644 --- a/datafusion/catalog/src/memory/schema.rs +++ b/datafusion/catalog/src/memory/schema.rs @@ -20,7 +20,7 @@ use crate::{SchemaProvider, TableProvider}; use async_trait::async_trait; use dashmap::DashMap; -use datafusion_common::{exec_err, DataFusionError}; +use datafusion_common::{DataFusionError, exec_err}; use std::any::Any; use std::sync::Arc; diff --git a/datafusion/catalog/src/memory/table.rs b/datafusion/catalog/src/memory/table.rs index 90224f6a37bc3..47f773fe9befd 100644 --- a/datafusion/catalog/src/memory/table.rs +++ b/datafusion/catalog/src/memory/table.rs @@ -27,17 +27,17 @@ use crate::TableProvider; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::error::Result; -use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt}; +use datafusion_common::{Constraints, DFSchema, SchemaExt, not_impl_err, plan_err}; use datafusion_common_runtime::JoinSet; use datafusion_datasource::memory::{MemSink, MemorySourceConfig}; use datafusion_datasource::sink::DataSinkExec; use datafusion_datasource::source::DataSourceExec; use datafusion_expr::dml::InsertOp; use datafusion_expr::{Expr, SortExpr, TableType}; -use datafusion_physical_expr::{create_physical_sort_exprs, LexOrdering}; +use datafusion_physical_expr::{LexOrdering, create_physical_sort_exprs}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::{ - common, ExecutionPlan, ExecutionPlanProperties, Partitioning, + ExecutionPlan, ExecutionPlanProperties, Partitioning, common, }; use datafusion_session::Session; diff --git a/datafusion/catalog/src/schema.rs b/datafusion/catalog/src/schema.rs index 9ba55256f1824..c6299582813b4 100644 --- a/datafusion/catalog/src/schema.rs +++ b/datafusion/catalog/src/schema.rs @@ -19,7 +19,7 @@ //! representing collections of named tables. 
use async_trait::async_trait; -use datafusion_common::{exec_err, DataFusionError}; +use datafusion_common::{DataFusionError, exec_err}; use std::any::Any; use std::fmt::Debug; use std::sync::Arc; @@ -68,7 +68,7 @@ pub trait SchemaProvider: Debug + Sync + Send { /// /// If a table of the same name was already registered, returns "Table /// already exists" error. - #[allow(unused_variables)] + #[expect(unused_variables)] fn register_table( &self, name: String, @@ -81,7 +81,7 @@ pub trait SchemaProvider: Debug + Sync + Send { /// schema and returns the previously registered [`TableProvider`], if any. /// /// If no `name` table exists, returns Ok(None). - #[allow(unused_variables)] + #[expect(unused_variables)] fn deregister_table(&self, name: &str) -> Result>> { exec_err!("schema provider does not support deregistering tables") } diff --git a/datafusion/catalog/src/stream.rs b/datafusion/catalog/src/stream.rs index f4a2338b8eecb..bdd72a1b1d70b 100644 --- a/datafusion/catalog/src/stream.rs +++ b/datafusion/catalog/src/stream.rs @@ -28,7 +28,7 @@ use std::sync::Arc; use crate::{Session, TableProvider, TableProviderFactory}; use arrow::array::{RecordBatch, RecordBatchReader, RecordBatchWriter}; use arrow::datatypes::SchemaRef; -use datafusion_common::{config_err, plan_err, Constraints, DataFusionError, Result}; +use datafusion_common::{Constraints, DataFusionError, Result, config_err, plan_err}; use datafusion_common_runtime::SpawnedTask; use datafusion_datasource::sink::{DataSink, DataSinkExec}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; diff --git a/datafusion/catalog/src/streaming.rs b/datafusion/catalog/src/streaming.rs index 082e74dab9a15..31669171b291a 100644 --- a/datafusion/catalog/src/streaming.rs +++ b/datafusion/catalog/src/streaming.rs @@ -24,11 +24,11 @@ use crate::Session; use crate::TableProvider; use arrow::datatypes::SchemaRef; -use datafusion_common::{plan_err, DFSchema, Result}; +use datafusion_common::{DFSchema, Result, plan_err}; use datafusion_expr::{Expr, SortExpr, TableType}; -use datafusion_physical_expr::{create_physical_sort_exprs, LexOrdering}; -use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; +use datafusion_physical_expr::{LexOrdering, create_physical_sort_exprs}; use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use async_trait::async_trait; use log::debug; diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index 11c9af01a7a54..cabdb22c62ae5 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -24,7 +24,7 @@ use crate::session::Session; use arrow::datatypes::SchemaRef; use async_trait::async_trait; use datafusion_common::Result; -use datafusion_common::{not_impl_err, Constraints, Statistics}; +use datafusion_common::{Constraints, Statistics, not_impl_err}; use datafusion_expr::Expr; use datafusion_expr::dml::InsertOp; diff --git a/datafusion/catalog/src/view.rs b/datafusion/catalog/src/view.rs index 89c6a4a224511..54c54431a5913 100644 --- a/datafusion/catalog/src/view.rs +++ b/datafusion/catalog/src/view.rs @@ -24,8 +24,8 @@ use crate::TableProvider; use arrow::datatypes::SchemaRef; use async_trait::async_trait; -use datafusion_common::error::Result; use datafusion_common::Column; +use datafusion_common::error::Result; use datafusion_expr::TableType; use datafusion_expr::{Expr, LogicalPlan}; use datafusion_expr::{LogicalPlanBuilder, TableProviderFilterPushDown}; diff --git 
a/datafusion/common-runtime/src/common.rs b/datafusion/common-runtime/src/common.rs index cebd6e04cd1b1..ca618b19ed2f1 100644 --- a/datafusion/common-runtime/src/common.rs +++ b/datafusion/common-runtime/src/common.rs @@ -44,7 +44,7 @@ impl SpawnedTask { R: Send, { // Ok to use spawn here as SpawnedTask handles aborting/cancelling the task on Drop - #[allow(clippy::disallowed_methods)] + #[expect(clippy::disallowed_methods)] let inner = tokio::task::spawn(trace_future(task)); Self { inner } } @@ -56,7 +56,7 @@ impl SpawnedTask { R: Send, { // Ok to use spawn_blocking here as SpawnedTask handles aborting/cancelling the task on Drop - #[allow(clippy::disallowed_methods)] + #[expect(clippy::disallowed_methods)] let inner = tokio::task::spawn_blocking(trace_block(task)); Self { inner } } @@ -115,14 +115,14 @@ impl Drop for SpawnedTask { mod tests { use super::*; - use std::future::{pending, Pending}; + use std::future::{Pending, pending}; use tokio::{runtime::Runtime, sync::oneshot}; #[tokio::test] async fn runtime_shutdown() { let rt = Runtime::new().unwrap(); - #[allow(clippy::async_yields_async)] + #[expect(clippy::async_yields_async)] let task = rt .spawn(async { SpawnedTask::spawn(async { diff --git a/datafusion/common-runtime/src/lib.rs b/datafusion/common-runtime/src/lib.rs index 5d404d99e7760..fdbfe7f2390ca 100644 --- a/datafusion/common-runtime/src/lib.rs +++ b/datafusion/common-runtime/src/lib.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#![cfg_attr(test, allow(clippy::needless_pass_by_value))] +#![deny(clippy::allow_attributes)] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" @@ -31,5 +33,5 @@ mod trace_utils; pub use common::SpawnedTask; pub use join_set::JoinSet; pub use trace_utils::{ - set_join_set_tracer, trace_block, trace_future, JoinSetTracer, JoinSetTracerError, + JoinSetTracer, JoinSetTracerError, set_join_set_tracer, trace_block, trace_future, }; diff --git a/datafusion/common-runtime/src/trace_utils.rs b/datafusion/common-runtime/src/trace_utils.rs index c3a39c355fc88..f8adbe8825bc1 100644 --- a/datafusion/common-runtime/src/trace_utils.rs +++ b/datafusion/common-runtime/src/trace_utils.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. 
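Many hunks in this diff swap `#[allow(...)]` for `#[expect(...)]` and add `#![deny(clippy::allow_attributes)]` at the crate root. The difference, in one self-contained crate (the function is a toy; `clippy::too_many_arguments` fires at eight parameters by default):

```rust
// `allow` would now be rejected by the crate-level lint below, while `expect`
// additionally warns (`unfulfilled_lint_expectations`) once the suppression
// is no longer needed, so stale attributes do not accumulate.
#![deny(clippy::allow_attributes)]

#[expect(clippy::too_many_arguments)]
fn widget(a: u8, b: u8, c: u8, d: u8, e: u8, f: u8, g: u8, h: u8) -> u32 {
    u32::from(a) + u32::from(b) + u32::from(c) + u32::from(d)
        + u32::from(e) + u32::from(f) + u32::from(g) + u32::from(h)
}

fn main() {
    assert_eq!(widget(1, 1, 1, 1, 1, 1, 1, 1), 8);
}
```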
-use futures::future::BoxFuture; use futures::FutureExt; +use futures::future::BoxFuture; use std::any::Any; use std::error::Error; use std::fmt::{Display, Formatter, Result as FmtResult}; diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index b222ae12b92f5..262f50839563a 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -48,15 +48,18 @@ parquet_encryption = [ "parquet/encryption", "dep:hex", ] -pyarrow = ["pyo3", "arrow/pyarrow", "parquet"] force_hash_collisions = [] recursive_protection = ["dep:recursive"] parquet = ["dep:parquet"] sql = ["sqlparser"] +[[bench]] +harness = false +name = "with_hashes" + [dependencies] ahash = { workspace = true } -apache-avro = { version = "0.20", default-features = false, features = [ +apache-avro = { workspace = true, features = [ "bzip", "snappy", "xz", @@ -73,8 +76,7 @@ libc = "0.2.177" log = { workspace = true } object_store = { workspace = true, optional = true } parquet = { workspace = true, optional = true, default-features = true } -paste = "1.0.15" -pyo3 = { version = "0.26", optional = true } +paste = { workspace = true } recursive = { workspace = true, optional = true } sqlparser = { workspace = true, optional = true } tokio = { workspace = true } @@ -84,6 +86,7 @@ web-time = "1.1.0" [dev-dependencies] chrono = { workspace = true } +criterion = { workspace = true } insta = { workspace = true } rand = { workspace = true } sqlparser = { workspace = true } diff --git a/datafusion/common/benches/with_hashes.rs b/datafusion/common/benches/with_hashes.rs new file mode 100644 index 0000000000000..8154c20df88f3 --- /dev/null +++ b/datafusion/common/benches/with_hashes.rs @@ -0,0 +1,209 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Benchmarks for `with_hashes` function + +use ahash::RandomState; +use arrow::array::{ + Array, ArrayRef, ArrowPrimitiveType, DictionaryArray, GenericStringArray, + NullBufferBuilder, OffsetSizeTrait, PrimitiveArray, StringViewArray, make_array, +}; +use arrow::buffer::NullBuffer; +use arrow::datatypes::{ArrowDictionaryKeyType, Int32Type, Int64Type}; +use criterion::{Bencher, Criterion, criterion_group, criterion_main}; +use datafusion_common::hash_utils::with_hashes; +use rand::Rng; +use rand::SeedableRng; +use rand::distr::{Alphanumeric, Distribution, StandardUniform}; +use rand::prelude::StdRng; +use std::sync::Arc; + +const BATCH_SIZE: usize = 8192; + +struct BenchData { + name: &'static str, + array: ArrayRef, +} + +fn criterion_benchmark(c: &mut Criterion) { + let pool = StringPool::new(100, 64); + // poll with small strings for string view tests (<=12 bytes are inlined) + let small_pool = StringPool::new(100, 5); + let cases = [ + BenchData { + name: "int64", + array: primitive_array::(BATCH_SIZE), + }, + BenchData { + name: "utf8", + array: pool.string_array::(BATCH_SIZE), + }, + BenchData { + name: "large_utf8", + array: pool.string_array::(BATCH_SIZE), + }, + BenchData { + name: "utf8_view", + array: pool.string_view_array(BATCH_SIZE), + }, + BenchData { + name: "utf8_view (small)", + array: small_pool.string_view_array(BATCH_SIZE), + }, + BenchData { + name: "dictionary_utf8_int32", + array: pool.dictionary_array::(BATCH_SIZE), + }, + ]; + + for BenchData { name, array } in cases { + // with_hash has different code paths for single vs multiple arrays and nulls vs no nulls + let nullable_array = add_nulls(&array); + c.bench_function(&format!("{name}: single, no nulls"), |b| { + do_hash_test(b, std::slice::from_ref(&array)); + }); + c.bench_function(&format!("{name}: single, nulls"), |b| { + do_hash_test(b, std::slice::from_ref(&nullable_array)); + }); + c.bench_function(&format!("{name}: multiple, no nulls"), |b| { + let arrays = vec![array.clone(), array.clone(), array.clone()]; + do_hash_test(b, &arrays); + }); + c.bench_function(&format!("{name}: multiple, nulls"), |b| { + let arrays = vec![ + nullable_array.clone(), + nullable_array.clone(), + nullable_array.clone(), + ]; + do_hash_test(b, &arrays); + }); + } +} + +fn do_hash_test(b: &mut Bencher, arrays: &[ArrayRef]) { + let state = RandomState::new(); + b.iter(|| { + with_hashes(arrays, &state, |hashes| { + assert_eq!(hashes.len(), BATCH_SIZE); // make sure the result is used + Ok(()) + }) + .unwrap(); + }); +} + +fn create_null_mask(len: usize) -> NullBuffer +where + StandardUniform: Distribution, +{ + let mut rng = make_rng(); + let null_density = 0.03; + let mut builder = NullBufferBuilder::new(len); + for _ in 0..len { + if rng.random::() < null_density { + builder.append_null(); + } else { + builder.append_non_null(); + } + } + builder.finish().expect("should be nulls in buffer") +} + +// Returns an new array that is the same as array, but with nulls +fn add_nulls(array: &ArrayRef) -> ArrayRef { + let array_data = array + .clone() + .into_data() + .into_builder() + .nulls(Some(create_null_mask(array.len()))) + .build() + .unwrap(); + make_array(array_data) +} + +pub fn make_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +/// String pool for generating low cardinality data (for dictionaries and string views) +struct StringPool { + strings: Vec, +} + +impl StringPool { + /// Create a new string pool with the given number of random strings + /// each having between 1 and max_length characters. 
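For readers unfamiliar with the function under test: `with_hashes` computes one combined `u64` hash per row across the given columns and hands the hash buffer to a callback. A minimal usage sketch following the call shape in `do_hash_test` above (the exact closure and return types are assumed from that call site):

```rust
use std::sync::Arc;

use ahash::RandomState;
use arrow::array::{ArrayRef, Int64Array, StringArray};
use datafusion_common::hash_utils::with_hashes;

fn main() -> datafusion_common::Result<()> {
    // Two columns, four rows: each row hashes to a single u64.
    let ints: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 2, 3]));
    let strs: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "b", "c"]));
    let state = RandomState::new();

    with_hashes(&[ints, strs], &state, |hashes| {
        assert_eq!(hashes.len(), 4);
        // Identical rows (index 1 and 2) produce identical combined hashes.
        assert_eq!(hashes[1], hashes[2]);
        Ok(())
    })
}
```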
+ fn new(pool_size: usize, max_length: usize) -> Self { + let mut rng = make_rng(); + let mut strings = Vec::with_capacity(pool_size); + for _ in 0..pool_size { + let len = rng.random_range(1..=max_length); + let value: Vec = + rng.clone().sample_iter(&Alphanumeric).take(len).collect(); + strings.push(String::from_utf8(value).unwrap()); + } + Self { strings } + } + + /// Return an iterator over &str of the given length with values randomly chosen from the pool + fn iter_strings(&self, len: usize) -> impl Iterator { + let mut rng = make_rng(); + (0..len).map(move |_| { + let idx = rng.random_range(0..self.strings.len()); + self.strings[idx].as_str() + }) + } + + /// Return a StringArray of the given length with values randomly chosen from the pool + fn string_array(&self, array_length: usize) -> ArrayRef { + Arc::new(GenericStringArray::::from_iter_values( + self.iter_strings(array_length), + )) + } + + /// Return a StringViewArray of the given length with values randomly chosen from the pool + fn string_view_array(&self, array_length: usize) -> ArrayRef { + Arc::new(StringViewArray::from_iter_values( + self.iter_strings(array_length), + )) + } + + /// Return a DictionaryArray of the given length with values randomly chosen from the pool + fn dictionary_array( + &self, + array_length: usize, + ) -> ArrayRef { + Arc::new(DictionaryArray::::from_iter( + self.iter_strings(array_length), + )) + } +} + +pub fn primitive_array(array_len: usize) -> ArrayRef +where + T: ArrowPrimitiveType, + StandardUniform: Distribution, +{ + let mut rng = make_rng(); + + let array: PrimitiveArray = (0..array_len) + .map(|_| Some(rng.random::())) + .collect(); + Arc::new(array) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index b95167ca13908..29082cc303a70 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -20,11 +20,11 @@ //! but provide an error message rather than a panic, as the corresponding //! kernels in arrow-rs such as `as_boolean_array` do. -use crate::{downcast_value, Result}; +use crate::{Result, downcast_value}; use arrow::array::{ BinaryViewArray, Decimal32Array, Decimal64Array, DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array, - Int16Array, Int8Array, LargeBinaryArray, LargeListViewArray, LargeStringArray, + Int8Array, Int16Array, LargeBinaryArray, LargeListViewArray, LargeStringArray, ListViewArray, StringViewArray, UInt16Array, }; use arrow::{ @@ -37,8 +37,8 @@ use arrow::{ MapArray, NullArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt32Array, UInt64Array, - UInt8Array, UnionArray, + TimestampNanosecondArray, TimestampSecondArray, UInt8Array, UInt32Array, + UInt64Array, UnionArray, }, datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType}, }; diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 0ed499da04757..2bea2ec5a4526 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -157,12 +157,10 @@ macro_rules! config_namespace { // $(#[allow(deprecated)])? { $(let value = $transform(value);)? 
// Apply transformation if specified - #[allow(deprecated)] let ret = self.$field_name.set(rem, value.as_ref()); $(if !$warn.is_empty() { let default: $field_type = $default; - #[allow(deprecated)] if default != self.$field_name { log::warn!($warn); } @@ -181,14 +179,36 @@ macro_rules! config_namespace { $( let key = format!(concat!("{}.", stringify!($field_name)), key_prefix); let desc = concat!($($d),*).trim(); - #[allow(deprecated)] self.$field_name.visit(v, key.as_str(), desc); )* } + + fn reset(&mut self, key: &str) -> $crate::error::Result<()> { + let (key, rem) = key.split_once('.').unwrap_or((key, "")); + match key { + $( + stringify!($field_name) => { + { + if rem.is_empty() { + let default_value: $field_type = $default; + self.$field_name = default_value; + Ok(()) + } else { + self.$field_name.reset(rem) + } + } + }, + )* + _ => $crate::error::_config_err!( + "Config value \"{}\" not found on {}", + key, + stringify!($struct_name) + ), + } + } } impl Default for $struct_name { fn default() -> Self { - #[allow(deprecated)] Self { $($field_name: $default),* } @@ -606,6 +626,29 @@ config_namespace! { /// written, it may be necessary to increase this size to avoid errors from /// the remote end point. pub objectstore_writer_buffer_size: usize, default = 10 * 1024 * 1024 + + /// Whether to enable ANSI SQL mode. + /// + /// The flag is experimental and relevant only for DataFusion Spark built-in functions + /// + /// When `enable_ansi_mode` is set to `true`, the query engine follows ANSI SQL + /// semantics for expressions, casting, and error handling. This means: + /// - **Strict type coercion rules:** implicit casts between incompatible types are disallowed. + /// - **Standard SQL arithmetic behavior:** operations such as division by zero, + /// numeric overflow, or invalid casts raise runtime errors rather than returning + /// `NULL` or adjusted values. + /// - **Consistent ANSI behavior** for string concatenation, comparisons, and `NULL` handling. + /// + /// When `enable_ansi_mode` is `false` (the default), the engine uses a more permissive, + /// non-ANSI mode designed for user convenience and backward compatibility. In this mode: + /// - Implicit casts between types are allowed (e.g., string to integer when possible). + /// - Arithmetic operations are more lenient — for example, `abs()` on the minimum + /// representable integer value returns the input value instead of raising overflow. + /// - Division by zero or invalid casts may return `NULL` instead of failing. + /// + /// # Default + /// `false` — ANSI SQL mode is disabled by default. + pub enable_ansi_mode: bool, default = false } } @@ -651,6 +694,12 @@ config_namespace! { /// the filters are applied in the same order as written in the query pub reorder_filters: bool, default = false + /// (reading) Force the use of RowSelections for filter results, when + /// pushdown_filters is enabled. If false, the reader will automatically + /// choose between a RowSelection and a Bitmap based on the number and + /// pattern of selected rows. + pub force_filter_selections: bool, default = false + /// (reading) If true, parquet reader will read columns of `Utf8/Utf8Large` with `Utf8View`, /// and `Binary/BinaryLarge` with `BinaryView`. pub schema_force_view_types: bool, default = true @@ -861,12 +910,16 @@ config_namespace! { /// into the file scan phase. 
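The `reset` hook added to `config_namespace!` at the start of this hunk (and to `ConfigOptions`/`ConfigField` later in the file) restores a single key to its compiled-in default. A hedged sketch, assuming the new method is reachable through the public `ConfigField` trait and that `batch_size` still defaults to 8192:

```rust
use datafusion_common::config::{ConfigField, ConfigOptions};

fn main() -> datafusion_common::Result<()> {
    let mut options = ConfigOptions::new();

    options.set("datafusion.execution.batch_size", "1024")?;
    assert_eq!(options.execution.batch_size, 1024);

    // New in this diff: put the key back to its default value.
    options.reset("datafusion.execution.batch_size")?;
    assert_eq!(options.execution.batch_size, 8192);

    // Resetting `enable_dynamic_filter_pushdown` also restores the
    // TopK / Join / Aggregate sub-flags, mirroring the special case in set().
    options.reset("datafusion.optimizer.enable_dynamic_filter_pushdown")?;
    Ok(())
}
```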
pub enable_join_dynamic_filter_pushdown: bool, default = true - /// When set to true attempts to push down dynamic filters generated by operators (topk & join) into the file scan phase. + /// When set to true, the optimizer will attempt to push down Aggregate dynamic filters + /// into the file scan phase. + pub enable_aggregate_dynamic_filter_pushdown: bool, default = true + + /// When set to true attempts to push down dynamic filters generated by operators (TopK, Join & Aggregate) into the file scan phase. /// For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer /// will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. /// This means that if we already have 10 timestamps in the year 2025 /// any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. - /// The config will suppress `enable_join_dynamic_filter_pushdown` & `enable_topk_dynamic_filter_pushdown` + /// The config will suppress `enable_join_dynamic_filter_pushdown`, `enable_topk_dynamic_filter_pushdown` & `enable_aggregate_dynamic_filter_pushdown` /// So if you disable `enable_topk_dynamic_filter_pushdown`, then enable `enable_dynamic_filter_pushdown`, the `enable_topk_dynamic_filter_pushdown` will be overridden. pub enable_dynamic_filter_pushdown: bool, default = true @@ -912,6 +965,19 @@ config_namespace! { /// record tables provided to the MemTable on creation. pub repartition_file_scans: bool, default = true + /// Minimum number of distinct partition values required to group files by their + /// Hive partition column values (enabling Hash partitioning declaration). + /// + /// How the option is used: + /// - preserve_file_partitions=0: Disable it. + /// - preserve_file_partitions=1: Always enable it. + /// - preserve_file_partitions=N, actual file partitions=M: Only enable when M >= N. + /// This threshold preserves I/O parallelism when file partitioning is below it. + /// + /// Note: This may reduce parallelism, rooting from the I/O level, if the number of distinct + /// partitions is less than the target_partitions. + pub preserve_file_partitions: usize, default = 0 + /// Should DataFusion repartition data using the partitions keys to execute window /// functions in parallel using the provided `target_partitions` level pub repartition_windows: bool, default = true @@ -934,6 +1000,34 @@ config_namespace! { /// ``` pub repartition_sorts: bool, default = true + /// Partition count threshold for subset satisfaction optimization. + /// + /// When the current partition count is >= this threshold, DataFusion will + /// skip repartitioning if the required partitioning expression is a subset + /// of the current partition expression such as Hash(a) satisfies Hash(a, b). + /// + /// When the current partition count is < this threshold, DataFusion will + /// repartition to increase parallelism even when subset satisfaction applies. + /// + /// Set to 0 to always repartition (disable subset satisfaction optimization). + /// Set to a high value to always use subset satisfaction. 
+ /// + /// Example (subset_repartition_threshold = 4): + /// ```text + /// Hash([a]) satisfies Hash([a, b]) because (Hash([a, b]) is subset of Hash([a]) + /// + /// If current partitions (3) < threshold (4), repartition: + /// AggregateExec: mode=FinalPartitioned, gby=[a, b], aggr=[SUM(x)] + /// RepartitionExec: partitioning=Hash([a, b], 8), input_partitions=3 + /// AggregateExec: mode=Partial, gby=[a, b], aggr=[SUM(x)] + /// DataSourceExec: file_groups={...}, output_partitioning=Hash([a], 3) + /// + /// If current partitions (8) >= threshold (4), use subset satisfaction: + /// AggregateExec: mode=SinglePartitioned, gby=[a, b], aggr=[SUM(x)] + /// DataSourceExec: file_groups={...}, output_partitioning=Hash([a], 8) + /// ``` + pub subset_repartition_threshold: usize, default = 4 + /// When true, DataFusion will opportunistically remove sorts when the data is already sorted, /// (i.e. setting `preserve_order` to true on `RepartitionExec` and /// using `SortPreservingMergeExec`) @@ -971,6 +1065,36 @@ config_namespace! { /// will be collected into a single partition pub hash_join_single_partition_threshold_rows: usize, default = 1024 * 128 + /// Maximum size in bytes for the build side of a hash join to be pushed down as an InList expression for dynamic filtering. + /// Build sides larger than this will use hash table lookups instead. + /// Set to 0 to always use hash table lookups. + /// + /// InList pushdown can be more efficient for small build sides because it can result in better + /// statistics pruning as well as use any bloom filters present on the scan side. + /// InList expressions are also more transparent and easier to serialize over the network in distributed uses of DataFusion. + /// On the other hand InList pushdown requires making a copy of the data and thus adds some overhead to the build side and uses more memory. + /// + /// This setting is per-partition, so we may end up using `hash_join_inlist_pushdown_max_size` * `target_partitions` memory. + /// + /// The default is 128kB per partition. + /// This should allow point lookup joins (e.g. joining on a unique primary key) to use InList pushdown in most cases + /// but avoids excessive memory usage or overhead for larger joins. + pub hash_join_inlist_pushdown_max_size: usize, default = 128 * 1024 + + /// Maximum number of distinct values (rows) in the build side of a hash join to be pushed down as an InList expression for dynamic filtering. + /// Build sides with more rows than this will use hash table lookups instead. + /// Set to 0 to always use hash table lookups. + /// + /// This provides an additional limit beyond `hash_join_inlist_pushdown_max_size` to prevent + /// very large IN lists that might not provide much benefit over hash table lookups. + /// + /// This uses the deduplicated row count once the build side has been evaluated. + /// + /// The default is 150 values per partition. + /// This is inspired by Trino's `max-filter-keys-per-column` setting. + /// See: + pub hash_join_inlist_pushdown_max_distinct_values: usize, default = 150 + /// The default filter selectivity used by Filter Statistics /// when an exact selectivity cannot be determined. Valid values are /// between 0 (no selectivity) and 100 (all rows are selected). @@ -983,6 +1107,21 @@ config_namespace! { /// then the output will be coerced to a non-view. /// Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`. pub expand_views_at_output: bool, default = false + + /// Enable sort pushdown optimization. 
+ /// When enabled, attempts to push sort requirements down to data sources + /// that can natively handle them (e.g., by reversing file/row group read order). + /// + /// Returns **inexact ordering**: Sort operator is kept for correctness, + /// but optimized input enables early termination for TopK queries (ORDER BY ... LIMIT N), + /// providing significant speedup. + /// + /// Memory: No additional overhead (only changes read order). + /// + /// Future: Will add option to detect perfectly sorted data and eliminate Sort completely. + /// + /// Default: true + pub enable_sort_pushdown: bool, default = true } } @@ -1073,7 +1212,7 @@ impl<'a> TryInto> for &'a FormatOptions return _config_err!( "Invalid duration format: {}. Valid values are pretty or iso8601", self.duration_format - ) + ); } }; @@ -1124,6 +1263,15 @@ pub struct ConfigOptions { } impl ConfigField for ConfigOptions { + fn visit(&self, v: &mut V, _key_prefix: &str, _description: &'static str) { + self.catalog.visit(v, "datafusion.catalog", ""); + self.execution.visit(v, "datafusion.execution", ""); + self.optimizer.visit(v, "datafusion.optimizer", ""); + self.explain.visit(v, "datafusion.explain", ""); + self.sql_parser.visit(v, "datafusion.sql_parser", ""); + self.format.visit(v, "datafusion.format", ""); + } + fn set(&mut self, key: &str, value: &str) -> Result<()> { // Extensions are handled in the public `ConfigOptions::set` let (key, rem) = key.split_once('.').unwrap_or((key, "")); @@ -1138,13 +1286,43 @@ impl ConfigField for ConfigOptions { } } - fn visit(&self, v: &mut V, _key_prefix: &str, _description: &'static str) { - self.catalog.visit(v, "datafusion.catalog", ""); - self.execution.visit(v, "datafusion.execution", ""); - self.optimizer.visit(v, "datafusion.optimizer", ""); - self.explain.visit(v, "datafusion.explain", ""); - self.sql_parser.visit(v, "datafusion.sql_parser", ""); - self.format.visit(v, "datafusion.format", ""); + /// Reset a configuration option back to its default value + fn reset(&mut self, key: &str) -> Result<()> { + let Some((prefix, rest)) = key.split_once('.') else { + return _config_err!("could not find config namespace for key \"{key}\""); + }; + + if prefix != "datafusion" { + return _config_err!("Could not find config namespace \"{prefix}\""); + } + + let (section, rem) = rest.split_once('.').unwrap_or((rest, "")); + if rem.is_empty() { + return _config_err!("could not find config field for key \"{key}\""); + } + + match section { + "catalog" => self.catalog.reset(rem), + "execution" => self.execution.reset(rem), + "optimizer" => { + if rem == "enable_dynamic_filter_pushdown" { + let defaults = OptimizerOptions::default(); + self.optimizer.enable_dynamic_filter_pushdown = + defaults.enable_dynamic_filter_pushdown; + self.optimizer.enable_topk_dynamic_filter_pushdown = + defaults.enable_topk_dynamic_filter_pushdown; + self.optimizer.enable_join_dynamic_filter_pushdown = + defaults.enable_join_dynamic_filter_pushdown; + Ok(()) + } else { + self.optimizer.reset(rem) + } + } + "explain" => self.explain.reset(rem), + "sql_parser" => self.sql_parser.reset(rem), + "format" => self.format.reset(rem), + other => _config_err!("Config value \"{other}\" not found on ConfigOptions"), + } } } @@ -1178,6 +1356,7 @@ impl ConfigOptions { self.optimizer.enable_dynamic_filter_pushdown = bool_value; self.optimizer.enable_topk_dynamic_filter_pushdown = bool_value; self.optimizer.enable_join_dynamic_filter_pushdown = bool_value; + self.optimizer.enable_aggregate_dynamic_filter_pushdown = bool_value; } return 
Ok(()); } @@ -1437,6 +1616,14 @@ impl Extensions { let e = self.0.get_mut(T::PREFIX)?; e.0.as_any_mut().downcast_mut() } + + /// Iterates all the config extension entries yielding their prefix and their + /// [ExtensionOptions] implementation. + pub fn iter( + &self, + ) -> impl Iterator)> { + self.0.iter().map(|(k, v)| (*k, &v.0)) + } } #[derive(Debug)] @@ -1454,6 +1641,10 @@ pub trait ConfigField { fn visit(&self, v: &mut V, key: &str, description: &'static str); fn set(&mut self, key: &str, value: &str) -> Result<()>; + + fn reset(&mut self, key: &str) -> Result<()> { + _config_err!("Reset is not supported for this config field, key: {}", key) + } } impl ConfigField for Option { @@ -1467,6 +1658,15 @@ impl ConfigField for Option { fn set(&mut self, key: &str, value: &str) -> Result<()> { self.get_or_insert_with(Default::default).set(key, value) } + + fn reset(&mut self, key: &str) -> Result<()> { + if key.is_empty() { + *self = Default::default(); + Ok(()) + } else { + self.get_or_insert_with(Default::default).reset(key) + } + } } /// Default transformation to parse a [`ConfigField`] for a string. @@ -1531,6 +1731,19 @@ macro_rules! config_field { *self = $transform; Ok(()) } + + fn reset(&mut self, key: &str) -> $crate::error::Result<()> { + if key.is_empty() { + *self = <$t as Default>::default(); + Ok(()) + } else { + $crate::error::_config_err!( + "Config field is a scalar {} and does not have nested field \"{}\"", + stringify!($t), + key + ) + } + } } }; } @@ -1540,6 +1753,7 @@ config_field!(bool, value => default_config_transform(value.to_lowercase().as_st config_field!(usize); config_field!(f64); config_field!(u64); +config_field!(u32); impl ConfigField for u8 { fn visit(&self, v: &mut V, key: &str, description: &'static str) { @@ -1730,8 +1944,7 @@ macro_rules! extensions_options { // Safely apply deprecated attribute if present // $(#[allow(deprecated)])? { - #[allow(deprecated)] - self.$field_name.set(rem, value.as_ref()) + self.$field_name.set(rem, value.as_ref()) } }, )* @@ -1745,7 +1958,6 @@ macro_rules! extensions_options { $( let key = stringify!($field_name).to_string(); let desc = concat!($($d),*).trim(); - #[allow(deprecated)] self.$field_name.visit(v, key.as_str(), desc); )* } @@ -2136,13 +2348,13 @@ impl ConfigField for TableParquetOptions { [_meta] | [_meta, ""] => { return _config_err!( "Invalid metadata key provided, missing key in metadata::" - ) + ); } [_meta, k] => k.into(), _ => { return _config_err!( "Invalid metadata key provided, found too many '::' in \"{key}\"" - ) + ); } }; self.key_value_metadata.insert(k, Some(value.into())); @@ -2188,7 +2400,6 @@ macro_rules! config_namespace_with_hashmap { $( stringify!($field_name) => { // Handle deprecated fields - #[allow(deprecated)] // Allow deprecated fields $(let value = $transform(value);)? self.$field_name.set(rem, value.as_ref()) }, @@ -2204,7 +2415,6 @@ macro_rules! config_namespace_with_hashmap { let key = format!(concat!("{}.", stringify!($field_name)), key_prefix); let desc = concat!($($d),*).trim(); // Handle deprecated fields - #[allow(deprecated)] self.$field_name.visit(v, key.as_str(), desc); )* } @@ -2212,7 +2422,6 @@ macro_rules! config_namespace_with_hashmap { impl Default for $struct_name { fn default() -> Self { - #[allow(deprecated)] Self { $($field_name: $default),* } @@ -2240,7 +2449,6 @@ macro_rules! 
config_namespace_with_hashmap { $( let key = format!("{}.{field}::{}", key_prefix, column_name, field = stringify!($field_name)); let desc = concat!($($d),*).trim(); - #[allow(deprecated)] col_options.$field_name.visit(v, key.as_str(), desc); )* } @@ -2539,7 +2747,7 @@ impl ConfigField for ConfigFileDecryptionProperties { self.footer_signature_verification.set(rem, value.as_ref()) } _ => _config_err!( - "Config value \"{}\" not found on ConfigFileEncryptionProperties", + "Config value \"{}\" not found on ConfigFileDecryptionProperties", key ), } @@ -2665,6 +2873,14 @@ config_namespace! { /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. pub newlines_in_values: Option, default = None pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED + /// Compression level for the output file. The valid range depends on the + /// compression algorithm: + /// - ZSTD: 1 to 22 (default: 3) + /// - GZIP: 0 to 9 (default: 6) + /// - BZIP2: 0 to 9 (default: 6) + /// - XZ: 0 to 9 (default: 6) + /// If not specified, the default level for the compression algorithm is used. + pub compression_level: Option, default = None pub schema_infer_max_rec: Option, default = None pub date_format: Option, default = None pub datetime_format: Option, default = None @@ -2787,6 +3003,14 @@ impl CsvOptions { self } + /// Set the compression level for the output file. + /// The valid range depends on the compression algorithm. + /// If not specified, the default level for the algorithm is used. + pub fn with_compression_level(mut self, level: u32) -> Self { + self.compression_level = Some(level); + self + } + /// The delimiter character. pub fn delimiter(&self) -> u8 { self.delimiter @@ -2812,6 +3036,14 @@ config_namespace! { /// Options controlling JSON format pub struct JsonOptions { pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED + /// Compression level for the output file. The valid range depends on the + /// compression algorithm: + /// - ZSTD: 1 to 22 (default: 3) + /// - GZIP: 0 to 9 (default: 6) + /// - BZIP2: 0 to 9 (default: 6) + /// - XZ: 0 to 9 (default: 6) + /// If not specified, the default level for the compression algorithm is used. + pub compression_level: Option, default = None pub schema_infer_max_rec: Option, default = None } } @@ -2819,7 +3051,7 @@ config_namespace! 
{ pub trait OutputFormatExt: Display {} #[derive(Debug, Clone, PartialEq)] -#[allow(clippy::large_enum_variant)] +#[cfg_attr(feature = "parquet", expect(clippy::large_enum_variant))] pub enum OutputFormat { CSV(CsvOptions), JSON(JsonOptions), @@ -2853,7 +3085,6 @@ mod tests { }; use std::any::Any; use std::collections::HashMap; - use std::sync::Arc; #[derive(Default, Debug, Clone)] pub struct TestExtensionConfig { @@ -2925,6 +3156,16 @@ mod tests { ); } + #[test] + fn iter_test_extension_config() { + let mut extension = Extensions::new(); + extension.insert(TestExtensionConfig::default()); + let table_config = TableOptions::new().with_extensions(extension); + let extensions = table_config.extensions.iter().collect::>(); + assert_eq!(extensions.len(), 1); + assert_eq!(extensions[0].0, TestExtensionConfig::PREFIX); + } + #[test] fn csv_u8_table_options() { let mut table_config = TableOptions::new(); @@ -2968,6 +3209,19 @@ mod tests { assert_eq!(COUNT.load(std::sync::atomic::Ordering::Relaxed), 1); } + #[test] + fn reset_nested_scalar_reports_helpful_error() { + let mut value = true; + let err = ::reset(&mut value, "nested").unwrap_err(); + let message = err.to_string(); + assert!( + message.starts_with( + "Invalid or Unsupported Configuration: Config field is a scalar bool and does not have nested field \"nested\"" + ), + "unexpected error message: {message}" + ); + } + #[cfg(feature = "parquet")] #[test] fn parquet_table_options() { @@ -2990,6 +3244,7 @@ mod tests { }; use parquet::encryption::decrypt::FileDecryptionProperties; use parquet::encryption::encrypt::FileEncryptionProperties; + use std::sync::Arc; let footer_key = b"0123456789012345".to_vec(); // 128bit/16 let column_names = vec!["double_field", "float_field"]; @@ -3143,9 +3398,11 @@ mod tests { .set("format.bloom_filter_enabled::col1", "true") .unwrap(); let entries = table_config.entries(); - assert!(entries - .iter() - .any(|item| item.key == "format.bloom_filter_enabled::col1")) + assert!( + entries + .iter() + .any(|item| item.key == "format.bloom_filter_enabled::col1") + ) } #[cfg(feature = "parquet")] @@ -3159,10 +3416,10 @@ mod tests { ) .unwrap(); let entries = table_parquet_options.entries(); - assert!(entries - .iter() - .any(|item| item.key - == "crypto.file_encryption.column_key_as_hex::double_field")) + assert!( + entries.iter().any(|item| item.key + == "crypto.file_encryption.column_key_as_hex::double_field") + ) } #[cfg(feature = "parquet")] diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs index 674d3386171f8..93169d6a02ff1 100644 --- a/datafusion/common/src/cse.rs +++ b/datafusion/common/src/cse.rs @@ -19,12 +19,12 @@ //! a [`CSEController`], that defines how to eliminate common subtrees from a particular //! [`TreeNode`] tree. 
+use crate::Result; use crate::hash_utils::combine_hashes; use crate::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; -use crate::Result; use indexmap::IndexMap; use std::collections::HashMap; use std::hash::{BuildHasher, Hash, Hasher, RandomState}; @@ -676,13 +676,13 @@ where #[cfg(test)] mod test { + use crate::Result; use crate::alias::AliasGenerator; use crate::cse::{ - CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeEq, - Normalizeable, CSE, + CSE, CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeEq, + Normalizeable, }; use crate::tree_node::tests::TestTreeNode; - use crate::Result; use std::collections::HashSet; use std::hash::{Hash, Hasher}; diff --git a/datafusion/common/src/datatype.rs b/datafusion/common/src/datatype.rs index 65f6395211866..19847f8583505 100644 --- a/datafusion/common/src/datatype.rs +++ b/datafusion/common/src/datatype.rs @@ -15,9 +15,10 @@ // specific language governing permissions and limitations // under the License. -//! [`DataTypeExt`] and [`FieldExt`] extension trait for working with DataTypes to Fields +//! [`DataTypeExt`] and [`FieldExt`] extension trait for working with Arrow [`DataType`] and [`Field`]s use crate::arrow::datatypes::{DataType, Field, FieldRef}; +use crate::metadata::FieldMetadata; use std::sync::Arc; /// DataFusion extension methods for Arrow [`DataType`] @@ -61,7 +62,54 @@ impl DataTypeExt for DataType { } /// DataFusion extension methods for Arrow [`Field`] and [`FieldRef`] +/// +/// This trait is implemented for both [`Field`] and [`FieldRef`] and +/// provides convenience methods for efficiently working with both types. +/// +/// For [`FieldRef`], the methods will attempt to unwrap the `Arc` +/// to avoid unnecessary cloning when possible. pub trait FieldExt { + /// Ensure the field is named `new_name`, returning the given field if the + /// name matches, and a new field if not. + /// + /// This method avoids `clone`ing fields and names if the name is the same + /// as the field's existing name. + /// + /// Example: + /// ``` + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field}; + /// # use datafusion_common::datatype::FieldExt; + /// let int_field = Field::new("my_int", DataType::Int32, true); + /// // rename to "your_int" + /// let renamed_field = int_field.renamed("your_int"); + /// assert_eq!(renamed_field.name(), "your_int"); + /// ``` + fn renamed(self, new_name: &str) -> Self; + + /// Ensure the field has the given data type + /// + /// Note this is different than simply calling [`Field::with_data_type`] as + /// it avoids copying if the data type is already the same. + /// + /// Example: + /// ``` + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field}; + /// # use datafusion_common::datatype::FieldExt; + /// let int_field = Field::new("my_int", DataType::Int32, true); + /// // change to Float64 + /// let retyped_field = int_field.retyped(DataType::Float64); + /// assert_eq!(retyped_field.data_type(), &DataType::Float64); + /// ``` + fn retyped(self, new_data_type: DataType) -> Self; + + /// Add field metadata to the Field + fn with_field_metadata(self, metadata: &FieldMetadata) -> Self; + + /// Add optional field metadata, + fn with_field_metadata_opt(self, metadata: Option<&FieldMetadata>) -> Self; + /// Returns a new Field representing a List of this Field's DataType. 
/// /// For example if input represents an `Int32`, the return value will @@ -130,6 +178,32 @@ pub trait FieldExt { impl FieldExt for Field { + fn renamed(self, new_name: &str) -> Self { + // check if this is a new name before allocating a new Field / copying + // the existing one + if self.name() != new_name { + self.with_name(new_name) + } else { + self + } + } + + fn retyped(self, new_data_type: DataType) -> Self { + self.with_data_type(new_data_type) + } + + fn with_field_metadata(self, metadata: &FieldMetadata) -> Self { + metadata.add_to_field(self) + } + + fn with_field_metadata_opt(self, metadata: Option<&FieldMetadata>) -> Self { + if let Some(metadata) = metadata { + self.with_field_metadata(metadata) + } else { + self + } + } + fn into_list(self) -> Self { DataType::List(Arc::new(self.into_list_item())).into_nullable_field() } @@ -149,6 +223,34 @@ impl FieldExt for Field { } impl FieldExt for Arc { + fn renamed(mut self, new_name: &str) -> Self { + if self.name() != new_name { + // avoid cloning if possible + Arc::make_mut(&mut self).set_name(new_name); + } + self + } + + fn retyped(mut self, new_data_type: DataType) -> Self { + if self.data_type() != &new_data_type { + // avoid cloning if possible + Arc::make_mut(&mut self).set_data_type(new_data_type); + } + self + } + + fn with_field_metadata(self, metadata: &FieldMetadata) -> Self { + metadata.add_to_field_ref(self) + } + + fn with_field_metadata_opt(self, metadata: Option<&FieldMetadata>) -> Self { + if let Some(metadata) = metadata { + self.with_field_metadata(metadata) + } else { + self + } + } + fn into_list(self) -> Self { DataType::List(self.into_list_item()) .into_nullable_field() @@ -161,13 +263,11 @@ impl FieldExt for Arc { .into() } - fn into_list_item(self) -> Self { if self.name() != Field::LIST_FIELD_DEFAULT_NAME { - Arc::unwrap_or_clone(self) - .with_name(Field::LIST_FIELD_DEFAULT_NAME) - .into() - } else { - self + fn into_list_item(mut self) -> Self { + if self.name() != Field::LIST_FIELD_DEFAULT_NAME { + // avoid cloning if possible + Arc::make_mut(&mut self).set_name(Field::LIST_FIELD_DEFAULT_NAME); } + self } } diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 24d152a7dba8c..55a031d870122 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -23,10 +23,10 @@ use std::fmt::{Display, Formatter}; use std::hash::Hash; use std::sync::Arc; -use crate::error::{DataFusionError, Result, _plan_err, _schema_err}; +use crate::error::{_plan_err, _schema_err, DataFusionError, Result}; use crate::{ - field_not_found, unqualified_field_not_found, Column, FunctionalDependencies, - SchemaError, TableReference, + Column, FunctionalDependencies, SchemaError, TableReference, field_not_found, + unqualified_field_not_found, }; use arrow::compute::can_cast_types; @@ -37,7 +37,7 @@ use arrow::datatypes::{ /// A reference-counted reference to a [DFSchema]. pub type DFSchemaRef = Arc; -/// DFSchema wraps an Arrow schema and adds relation names. +/// DFSchema wraps an Arrow schema and adds a relation (table) name. /// /// The schema may hold the fields across multiple tables. Some fields may be /// qualified and some unqualified. A qualified field is a field that has a @@ -47,8 +47,14 @@ pub type DFSchemaRef = Arc; /// have a distinct name from any qualified field names. This allows finding a /// qualified field by name to be possible, so long as there aren't multiple /// qualified fields with the same name.
+/// +/// # See Also +/// * [DFSchemaRef], an alias to `Arc` +/// * [DataTypeExt], common methods for working with Arrow [DataType]s +/// * [FieldExt], extension methods for working with Arrow [Field]s /// -/// There is an alias to `Arc` named [DFSchemaRef]. +/// [DataTypeExt]: crate::datatype::DataTypeExt +/// [FieldExt]: crate::datatype::FieldExt /// /// # Creating qualified schemas /// @@ -346,20 +352,22 @@ impl DFSchema { self.field_qualifiers.extend(qualifiers); } - /// Get a list of fields + /// Get a list of fields for this schema pub fn fields(&self) -> &Fields { &self.inner.fields } - /// Returns an immutable reference of a specific `Field` instance selected using an - /// offset within the internal `fields` vector - pub fn field(&self, i: usize) -> &Field { + /// Returns a reference to [`FieldRef`] for a column at a specific index + /// within the schema. + /// + /// See also [Self::qualified_field] to get both qualifier and field + pub fn field(&self, i: usize) -> &FieldRef { &self.inner.fields[i] } - /// Returns an immutable reference of a specific `Field` instance selected using an - /// offset within the internal `fields` vector and its qualifier - pub fn qualified_field(&self, i: usize) -> (Option<&TableReference>, &Field) { + /// Returns the qualifier (if any) and [`FieldRef`] for a column at a specific + /// index within the schema. + pub fn qualified_field(&self, i: usize) -> (Option<&TableReference>, &FieldRef) { (self.field_qualifiers[i].as_ref(), self.field(i)) } @@ -410,12 +418,12 @@ impl DFSchema { .is_some() } - /// Find the field with the given name + /// Find the [`FieldRef`] with the given name and optional qualifier pub fn field_with_name( &self, qualifier: Option<&TableReference>, name: &str, - ) -> Result<&Field> { + ) -> Result<&FieldRef> { if let Some(qualifier) = qualifier { self.field_with_qualified_name(qualifier, name) } else { @@ -428,7 +436,7 @@ impl DFSchema { &self, qualifier: Option<&TableReference>, name: &str, - ) -> Result<(Option<&TableReference>, &Field)> { + ) -> Result<(Option<&TableReference>, &FieldRef)> { if let Some(qualifier) = qualifier { let idx = self .index_of_column_by_name(Some(qualifier), name) @@ -440,10 +448,10 @@ impl DFSchema { } /// Find all fields having the given qualifier - pub fn fields_with_qualified(&self, qualifier: &TableReference) -> Vec<&Field> { + pub fn fields_with_qualified(&self, qualifier: &TableReference) -> Vec<&FieldRef> { self.iter() .filter(|(q, _)| q.map(|q| q.eq(qualifier)).unwrap_or(false)) - .map(|(_, f)| f.as_ref()) + .map(|(_, f)| f) .collect() } @@ -459,11 +467,10 @@ impl DFSchema { } /// Find all fields that match the given name - pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&Field> { + pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&FieldRef> { self.fields() .iter() .filter(|field| field.name() == name) - .map(|f| f.as_ref()) .collect() } @@ -471,10 +478,9 @@ impl DFSchema { pub fn qualified_fields_with_unqualified_name( &self, name: &str, - ) -> Vec<(Option<&TableReference>, &Field)> { + ) -> Vec<(Option<&TableReference>, &FieldRef)> { self.iter() .filter(|(_, field)| field.name() == name) - .map(|(qualifier, field)| (qualifier, field.as_ref())) .collect() } @@ -499,7 +505,7 @@ impl DFSchema { pub fn qualified_field_with_unqualified_name( &self, name: &str, - ) -> Result<(Option<&TableReference>, &Field)> { + ) -> Result<(Option<&TableReference>, &FieldRef)> { let matches = self.qualified_fields_with_unqualified_name(name); match matches.len() { 0 =>
Err(unqualified_field_not_found(name, self)), @@ -528,7 +534,7 @@ impl DFSchema { } /// Find the field with the given name - pub fn field_with_unqualified_name(&self, name: &str) -> Result<&Field> { + pub fn field_with_unqualified_name(&self, name: &str) -> Result<&FieldRef> { self.qualified_field_with_unqualified_name(name) .map(|(_, field)| field) } @@ -538,7 +544,7 @@ impl DFSchema { &self, qualifier: &TableReference, name: &str, - ) -> Result<&Field> { + ) -> Result<&FieldRef> { let idx = self .index_of_column_by_name(Some(qualifier), name) .ok_or_else(|| field_not_found(Some(qualifier.clone()), name, self))?; @@ -550,7 +556,7 @@ impl DFSchema { pub fn qualified_field_from_column( &self, column: &Column, - ) -> Result<(Option<&TableReference>, &Field)> { + ) -> Result<(Option<&TableReference>, &FieldRef)> { self.qualified_field_with_name(column.relation.as_ref(), &column.name) } @@ -982,36 +988,35 @@ fn format_field_with_indent( result.push_str(&format!( "{indent}|-- {field_name}: map (nullable = {nullable_str})\n" )); - if let DataType::Struct(inner_fields) = field.data_type() { - if inner_fields.len() == 2 { - format_field_with_indent( - result, - "key", - inner_fields[0].data_type(), - inner_fields[0].is_nullable(), - &child_indent, - ); - let value_contains_null = - field.is_nullable().to_string().to_lowercase(); - // Handle complex value types properly - match inner_fields[1].data_type() { - DataType::Struct(_) - | DataType::List(_) - | DataType::LargeList(_) - | DataType::FixedSizeList(_, _) - | DataType::Map(_, _) => { - format_field_with_indent( - result, - "value", - inner_fields[1].data_type(), - inner_fields[1].is_nullable(), - &child_indent, - ); - } - _ => { - result.push_str(&format!("{child_indent}|-- value: {} (nullable = {value_contains_null})\n", + if let DataType::Struct(inner_fields) = field.data_type() + && inner_fields.len() == 2 + { + format_field_with_indent( + result, + "key", + inner_fields[0].data_type(), + inner_fields[0].is_nullable(), + &child_indent, + ); + let value_contains_null = field.is_nullable().to_string().to_lowercase(); + // Handle complex value types properly + match inner_fields[1].data_type() { + DataType::Struct(_) + | DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) + | DataType::Map(_, _) => { + format_field_with_indent( + result, + "value", + inner_fields[1].data_type(), + inner_fields[1].is_nullable(), + &child_indent, + ); + } + _ => { + result.push_str(&format!("{child_indent}|-- value: {} (nullable = {value_contains_null})\n", format_simple_data_type(inner_fields[1].data_type()))); - } } } } @@ -1221,7 +1226,7 @@ pub trait ExprSchema: std::fmt::Debug { } // Return the column's field - fn field_from_column(&self, col: &Column) -> Result<&Field>; + fn field_from_column(&self, col: &Column) -> Result<&FieldRef>; } // Implement `ExprSchema` for `Arc` @@ -1242,13 +1247,13 @@ impl + std::fmt::Debug> ExprSchema for P { self.as_ref().data_type_and_nullable(col) } - fn field_from_column(&self, col: &Column) -> Result<&Field> { + fn field_from_column(&self, col: &Column) -> Result<&FieldRef> { self.as_ref().field_from_column(col) } } impl ExprSchema for DFSchema { - fn field_from_column(&self, col: &Column) -> Result<&Field> { + fn field_from_column(&self, col: &Column) -> Result<&FieldRef> { match &col.relation { Some(r) => self.field_with_qualified_name(r, &col.name), None => self.field_with_unqualified_name(&col.name), @@ -1433,12 +1438,14 @@ mod tests { join.to_string() ); // test valid access - assert!(join - 
.field_with_qualified_name(&TableReference::bare("t1"), "c0") - .is_ok()); - assert!(join - .field_with_qualified_name(&TableReference::bare("t2"), "c0") - .is_ok()); + assert!( + join.field_with_qualified_name(&TableReference::bare("t1"), "c0") + .is_ok() + ); + assert!( + join.field_with_qualified_name(&TableReference::bare("t2"), "c0") + .is_ok() + ); // test invalid access assert!(join.field_with_unqualified_name("c0").is_err()); assert!(join.field_with_unqualified_name("t1.c0").is_err()); @@ -1480,18 +1487,20 @@ mod tests { join.to_string() ); // test valid access - assert!(join - .field_with_qualified_name(&TableReference::bare("t1"), "c0") - .is_ok()); + assert!( + join.field_with_qualified_name(&TableReference::bare("t1"), "c0") + .is_ok() + ); assert!(join.field_with_unqualified_name("c0").is_ok()); assert!(join.field_with_unqualified_name("c100").is_ok()); assert!(join.field_with_name(None, "c100").is_ok()); // test invalid access assert!(join.field_with_unqualified_name("t1.c0").is_err()); assert!(join.field_with_unqualified_name("t1.c100").is_err()); - assert!(join - .field_with_qualified_name(&TableReference::bare(""), "c100") - .is_err()); + assert!( + join.field_with_qualified_name(&TableReference::bare(""), "c100") + .is_err() + ); Ok(()) } @@ -1500,9 +1509,11 @@ mod tests { let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let right = DFSchema::try_from(test_schema_1())?; let join = left.join(&right); - assert_contains!(join.unwrap_err().to_string(), - "Schema error: Schema contains qualified \ - field name t1.c0 and unqualified field name c0 which would be ambiguous"); + assert_contains!( + join.unwrap_err().to_string(), + "Schema error: Schema contains qualified \ + field name t1.c0 and unqualified field name c0 which would be ambiguous" + ); Ok(()) } @@ -2059,7 +2070,7 @@ mod tests { fn test_print_schema_empty() { let schema = DFSchema::empty(); let output = schema.tree_string(); - insta::assert_snapshot!(output, @r###"root"###); + insta::assert_snapshot!(output, @"root"); } #[test] diff --git a/datafusion/common/src/display/human_readable.rs b/datafusion/common/src/display/human_readable.rs new file mode 100644 index 0000000000000..0e0d677bd8904 --- /dev/null +++ b/datafusion/common/src/display/human_readable.rs @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Helpers for rendering sizes, counts, and durations in human readable form. 
+ +/// Common data size units +pub mod units { + pub const TB: u64 = 1 << 40; + pub const GB: u64 = 1 << 30; + pub const MB: u64 = 1 << 20; + pub const KB: u64 = 1 << 10; +} + +/// Present size in human-readable form +pub fn human_readable_size(size: usize) -> String { + use units::*; + + let size = size as u64; + let (value, unit) = { + if size >= 2 * TB { + (size as f64 / TB as f64, "TB") + } else if size >= 2 * GB { + (size as f64 / GB as f64, "GB") + } else if size >= 2 * MB { + (size as f64 / MB as f64, "MB") + } else if size >= 2 * KB { + (size as f64 / KB as f64, "KB") + } else { + (size as f64, "B") + } + }; + format!("{value:.1} {unit}") +} + +/// Present count in human-readable form with K, M, B, T suffixes +pub fn human_readable_count(count: usize) -> String { + let count = count as u64; + let (value, unit) = { + if count >= 1_000_000_000_000 { + (count as f64 / 1_000_000_000_000.0, " T") + } else if count >= 1_000_000_000 { + (count as f64 / 1_000_000_000.0, " B") + } else if count >= 1_000_000 { + (count as f64 / 1_000_000.0, " M") + } else if count >= 1_000 { + (count as f64 / 1_000.0, " K") + } else { + return count.to_string(); + } + }; + + // Format with appropriate precision + // For values >= 100, show 1 decimal place (e.g., 123.4 K) + // For values < 100, show 2 decimal places (e.g., 10.12 K) + if value >= 100.0 { + format!("{value:.1}{unit}") + } else { + format!("{value:.2}{unit}") + } +} + +/// Present duration in human-readable form with 2 decimal places +pub fn human_readable_duration(nanos: u64) -> String { + const NANOS_PER_SEC: f64 = 1_000_000_000.0; + const NANOS_PER_MILLI: f64 = 1_000_000.0; + const NANOS_PER_MICRO: f64 = 1_000.0; + + let nanos_f64 = nanos as f64; + + if nanos >= 1_000_000_000 { + // >= 1 second: show in seconds + format!("{:.2}s", nanos_f64 / NANOS_PER_SEC) + } else if nanos >= 1_000_000 { + // >= 1 millisecond: show in milliseconds + format!("{:.2}ms", nanos_f64 / NANOS_PER_MILLI) + } else if nanos >= 1_000 { + // >= 1 microsecond: show in microseconds + format!("{:.2}µs", nanos_f64 / NANOS_PER_MICRO) + } else { + // < 1 microsecond: show in nanoseconds + format!("{nanos}ns") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_human_readable_count() { + assert_eq!(human_readable_count(0), "0"); + assert_eq!(human_readable_count(1), "1"); + assert_eq!(human_readable_count(999), "999"); + assert_eq!(human_readable_count(1_000), "1.00 K"); + assert_eq!(human_readable_count(10_100), "10.10 K"); + assert_eq!(human_readable_count(1_532), "1.53 K"); + assert_eq!(human_readable_count(99_999), "100.00 K"); + assert_eq!(human_readable_count(1_000_000), "1.00 M"); + assert_eq!(human_readable_count(1_532_000), "1.53 M"); + assert_eq!(human_readable_count(99_000_000), "99.00 M"); + assert_eq!(human_readable_count(123_456_789), "123.5 M"); + assert_eq!(human_readable_count(1_000_000_000), "1.00 B"); + assert_eq!(human_readable_count(1_532_000_000), "1.53 B"); + assert_eq!(human_readable_count(999_999_999_999), "1000.0 B"); + assert_eq!(human_readable_count(1_000_000_000_000), "1.00 T"); + assert_eq!(human_readable_count(42_000_000_000_000), "42.00 T"); + } + + #[test] + fn test_human_readable_duration() { + assert_eq!(human_readable_duration(0), "0ns"); + assert_eq!(human_readable_duration(1), "1ns"); + assert_eq!(human_readable_duration(999), "999ns"); + assert_eq!(human_readable_duration(1_000), "1.00µs"); + assert_eq!(human_readable_duration(1_234), "1.23µs"); + assert_eq!(human_readable_duration(999_999), "1000.00µs"); + 
assert_eq!(human_readable_duration(1_000_000), "1.00ms"); + assert_eq!(human_readable_duration(11_295_377), "11.30ms"); + assert_eq!(human_readable_duration(1_234_567), "1.23ms"); + assert_eq!(human_readable_duration(999_999_999), "1000.00ms"); + assert_eq!(human_readable_duration(1_000_000_000), "1.00s"); + assert_eq!(human_readable_duration(1_234_567_890), "1.23s"); + assert_eq!(human_readable_duration(42_000_000_000), "42.00s"); + } +} diff --git a/datafusion/common/src/display/mod.rs b/datafusion/common/src/display/mod.rs index bad51c45f8ee8..a6a97b243f06a 100644 --- a/datafusion/common/src/display/mod.rs +++ b/datafusion/common/src/display/mod.rs @@ -18,6 +18,7 @@ //! Types for plan display mod graphviz; +pub mod human_readable; pub use graphviz::*; use std::{ diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index fde52944d0497..4f681896dfc66 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -15,7 +15,25 @@ // specific language governing permissions and limitations // under the License. -//! DataFusion error types +//! # Error Handling in DataFusion +//! +//! In DataFusion, there are two types of errors that can be raised: +//! +//! 1. Expected errors – These indicate invalid operations performed by the caller, +//! such as attempting to open a non-existent file. Different categories exist to +//! distinguish their sources (e.g., [`DataFusionError::ArrowError`], +//! [`DataFusionError::IoError`], etc.). +//! +//! 2. Unexpected errors – Represented by [`DataFusionError::Internal`], these +//! indicate that an internal invariant has been broken, suggesting a potential +//! bug in the system. +//! +//! There are several convenient macros for throwing errors. For example, use +//! `exec_err!` for expected errors. +//! For invariant checks, you can use `assert_or_internal_err!`, +//! `assert_eq_or_internal_err!`, `assert_ne_or_internal_err!` for easier assertions. +//! On the performance-critical path, use `debug_assert!` instead to reduce overhead. + #[cfg(feature = "backtrace")] use std::backtrace::{Backtrace, BacktraceStatus}; @@ -153,6 +171,10 @@ pub enum DataFusionError { /// to multiple receivers. For example, when the source of a repartition /// errors and the error is propagated to multiple consumers. Shared(Arc), + /// An error that originated during a foreign function interface call. + /// Transferring errors across the FFI boundary is difficult, so the original + /// error will be converted to a string. + Ffi(String), } #[macro_export] @@ -395,6 +417,7 @@ impl Error for DataFusionError { // can't be executed. DataFusionError::Collection(errs) => errs.first().map(|e| e as &dyn Error), DataFusionError::Shared(e) => Some(e.as_ref()), + DataFusionError::Ffi(_) => None, } } } @@ -526,6 +549,7 @@ impl DataFusionError { errs.first().expect("cannot construct DataFusionError::Collection with 0 errors, but got one such case").error_prefix() } DataFusionError::Shared(_) => "", + DataFusionError::Ffi(_) => "FFI error: ", } } @@ -578,6 +602,7 @@ impl DataFusionError { .expect("cannot construct DataFusionError::Collection with 0 errors") .message(), DataFusionError::Shared(ref desc) => Cow::Owned(desc.to_string()), + DataFusionError::Ffi(ref desc) => Cow::Owned(desc.to_string()), } } @@ -750,7 +775,7 @@ impl DataFusionErrorBuilder { macro_rules! 
unwrap_or_internal_err { ($Value: ident) => { $Value.ok_or_else(|| { - DataFusionError::Internal(format!( + $crate::DataFusionError::Internal(format!( "{} should not be None", stringify!($Value) )) @@ -758,6 +783,116 @@ macro_rules! unwrap_or_internal_err { }; } +/// Assert a condition, returning `DataFusionError::Internal` on failure. +/// +/// # Examples +/// +/// ```text +/// assert_or_internal_err!(predicate); +/// assert_or_internal_err!(predicate, "human readable message"); +/// assert_or_internal_err!(predicate, format!("details: {}", value)); +/// ``` +#[macro_export] +macro_rules! assert_or_internal_err { + ($cond:expr) => { + if !$cond { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {}", + stringify!($cond) + ))); + } + }; + ($cond:expr, $($arg:tt)+) => { + if !$cond { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {}: {}", + stringify!($cond), + format!($($arg)+) + ))); + } + }; +} + +/// Assert equality, returning `DataFusionError::Internal` on failure. +/// +/// # Examples +/// +/// ```text +/// assert_eq_or_internal_err!(actual, expected); +/// assert_eq_or_internal_err!(left_expr, right_expr, "values must match"); +/// assert_eq_or_internal_err!(lhs, rhs, "metadata: {}", extra); +/// ``` +#[macro_export] +macro_rules! assert_eq_or_internal_err { + ($left:expr, $right:expr $(,)?) => {{ + let left_val = &$left; + let right_val = &$right; + if left_val != right_val { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {} == {} (left: {:?}, right: {:?})", + stringify!($left), + stringify!($right), + left_val, + right_val + ))); + } + }}; + ($left:expr, $right:expr, $($arg:tt)+) => {{ + let left_val = &$left; + let right_val = &$right; + if left_val != right_val { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {} == {} (left: {:?}, right: {:?}): {}", + stringify!($left), + stringify!($right), + left_val, + right_val, + format!($($arg)+) + ))); + } + }}; +} + +/// Assert inequality, returning `DataFusionError::Internal` on failure. +/// +/// # Examples +/// +/// ```text +/// assert_ne_or_internal_err!(left, right); +/// assert_ne_or_internal_err!(lhs_expr, rhs_expr, "values must differ"); +/// assert_ne_or_internal_err!(a, b, "context {}", info); +/// ``` +#[macro_export] +macro_rules! assert_ne_or_internal_err { + ($left:expr, $right:expr $(,)?) => {{ + let left_val = &$left; + let right_val = &$right; + if left_val == right_val { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {} != {} (left: {:?}, right: {:?})", + stringify!($left), + stringify!($right), + left_val, + right_val + ))); + } + }}; + ($left:expr, $right:expr, $($arg:tt)+) => {{ + let left_val = &$left; + let right_val = &$right; + if left_val == right_val { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {} != {} (left: {:?}, right: {:?}): {}", + stringify!($left), + stringify!($right), + left_val, + right_val, + format!($($arg)+) + ))); + } + }}; +} + /// Add a macros for concise DataFusionError::* errors declaration /// supports placeholders the same way as `format!` /// Examples: @@ -807,14 +942,9 @@ macro_rules! make_error { } - // Note: Certain macros are used in this crate, but not all. 
- // This macro generates a use or all of them in case they are needed - // so we allow unused code to avoid warnings when they are not used #[doc(hidden)] - #[allow(unused)] pub use $NAME_ERR as [<_ $NAME_ERR>]; #[doc(hidden)] - #[allow(unused)] pub use $NAME_DF_ERR as [<_ $NAME_DF_ERR>]; } }; @@ -841,11 +971,14 @@ make_error!(substrait_err, substrait_datafusion_err, Substrait); // Exposes a macro to create `DataFusionError::ResourcesExhausted` with optional backtrace make_error!(resources_err, resources_datafusion_err, ResourcesExhausted); +// Exposes a macro to create `DataFusionError::Ffi` with optional backtrace +make_error!(ffi_err, ffi_datafusion_err, Ffi); + // Exposes a macro to create `DataFusionError::SQL` with optional backtrace #[macro_export] macro_rules! sql_datafusion_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ - let err = DataFusionError::SQL(Box::new($ERR), Some(DataFusionError::get_back_trace())); + let err = $crate::DataFusionError::SQL(Box::new($ERR), Some($crate::DataFusionError::get_back_trace())); $( let err = err.with_diagnostic($DIAG); )? @@ -857,7 +990,7 @@ macro_rules! sql_datafusion_err { #[macro_export] macro_rules! sql_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ - let err = datafusion_common::sql_datafusion_err!($ERR); + let err = $crate::sql_datafusion_err!($ERR); $( let err = err.with_diagnostic($DIAG); )? @@ -869,7 +1002,7 @@ macro_rules! sql_err { #[macro_export] macro_rules! arrow_datafusion_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ - let err = DataFusionError::ArrowError(Box::new($ERR), Some(DataFusionError::get_back_trace())); + let err = $crate::DataFusionError::ArrowError(Box::new($ERR), Some($crate::DataFusionError::get_back_trace())); $( let err = err.with_diagnostic($DIAG); )? @@ -882,7 +1015,7 @@ macro_rules! arrow_datafusion_err { macro_rules! arrow_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) => { { - let err = datafusion_common::arrow_datafusion_err!($ERR); + let err = $crate::arrow_datafusion_err!($ERR); $( let err = err.with_diagnostic($DIAG); )? @@ -894,9 +1027,9 @@ macro_rules! arrow_err { #[macro_export] macro_rules! schema_datafusion_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ - let err = $crate::error::DataFusionError::SchemaError( + let err = $crate::DataFusionError::SchemaError( Box::new($ERR), - Box::new(Some($crate::error::DataFusionError::get_back_trace())), + Box::new(Some($crate::DataFusionError::get_back_trace())), ); $( let err = err.with_diagnostic($DIAG); @@ -909,9 +1042,9 @@ macro_rules! schema_datafusion_err { #[macro_export] macro_rules! schema_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) 
=> {{ - let err = $crate::error::DataFusionError::SchemaError( + let err = $crate::DataFusionError::SchemaError( Box::new($ERR), - Box::new(Some($crate::error::DataFusionError::get_back_trace())), + Box::new(Some($crate::DataFusionError::get_back_trace())), ); $( let err = err.with_diagnostic($DIAG); @@ -974,6 +1107,115 @@ mod test { use std::sync::Arc; use arrow::error::ArrowError; + use insta::assert_snapshot; + + fn ok_result() -> Result<()> { + Ok(()) + } + + #[test] + fn test_assert_eq_or_internal_err_passes() -> Result<()> { + assert_eq_or_internal_err!(1, 1); + ok_result() + } + + #[test] + fn test_assert_eq_or_internal_err_fails() { + fn check() -> Result<()> { + assert_eq_or_internal_err!(1, 2, "expected equality"); + ok_result() + } + + let err = check().unwrap_err(); + assert_snapshot!( + err.to_string(), + @r" + Internal error: Assertion failed: 1 == 2 (left: 1, right: 2): expected equality. + This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues + " + ); + } + + #[test] + fn test_assert_ne_or_internal_err_passes() -> Result<()> { + assert_ne_or_internal_err!(1, 2); + ok_result() + } + + #[test] + fn test_assert_ne_or_internal_err_fails() { + fn check() -> Result<()> { + assert_ne_or_internal_err!(3, 3, "values must differ"); + ok_result() + } + + let err = check().unwrap_err(); + assert_snapshot!( + err.to_string(), + @r" + Internal error: Assertion failed: 3 != 3 (left: 3, right: 3): values must differ. + This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues + " + ); + } + + #[test] + fn test_assert_or_internal_err_passes() -> Result<()> { + assert_or_internal_err!(true); + assert_or_internal_err!(true, "message"); + ok_result() + } + + #[test] + fn test_assert_or_internal_err_fails_default() { + fn check() -> Result<()> { + assert_or_internal_err!(false); + ok_result() + } + + let err = check().unwrap_err(); + assert_snapshot!( + err.to_string(), + @r" + Internal error: Assertion failed: false. + This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues + " + ); + } + + #[test] + fn test_assert_or_internal_err_fails_with_message() { + fn check() -> Result<()> { + assert_or_internal_err!(false, "custom message"); + ok_result() + } + + let err = check().unwrap_err(); + assert_snapshot!( + err.to_string(), + @r" + Internal error: Assertion failed: false: custom message. + This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues + " + ); + } + + #[test] + fn test_assert_or_internal_err_with_format_arguments() { + fn check() -> Result<()> { + assert_or_internal_err!(false, "custom {}", 42); + ok_result() + } + + let err = check().unwrap_err(); + assert_snapshot!( + err.to_string(), + @r" + Internal error: Assertion failed: false: custom 42. + This issue was likely caused by a bug in DataFusion's code. 
Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues + " + ); + } #[test] fn test_error_size() { @@ -986,9 +1228,10 @@ mod test { #[test] fn datafusion_error_to_arrow() { let res = return_arrow_error().unwrap_err(); - assert!(res - .to_string() - .starts_with("External error: Error during planning: foo")); + assert!( + res.to_string() + .starts_with("External error: Error during planning: foo") + ); } #[test] @@ -1000,7 +1243,7 @@ mod test { // To pass the test the environment variable RUST_BACKTRACE should be set to 1 to enforce backtrace #[cfg(feature = "backtrace")] #[test] - #[allow(clippy::unnecessary_literal_unwrap)] + #[expect(clippy::unnecessary_literal_unwrap)] fn test_enabled_backtrace() { match std::env::var("RUST_BACKTRACE") { Ok(val) if val == "1" => {} @@ -1017,17 +1260,17 @@ mod test { .unwrap(), &"Error during planning: Err" ); - assert!(!err - .split(DataFusionError::BACK_TRACE_SEP) - .collect::>() - .get(1) - .unwrap() - .is_empty()); + assert!( + !err.split(DataFusionError::BACK_TRACE_SEP) + .collect::>() + .get(1) + .unwrap() + .is_empty() + ); } #[cfg(not(feature = "backtrace"))] #[test] - #[allow(clippy::unnecessary_literal_unwrap)] fn test_disabled_backtrace() { let res: Result<(), DataFusionError> = plan_err!("Err"); let res = res.unwrap_err().to_string(); @@ -1097,7 +1340,6 @@ mod test { } #[test] - #[allow(clippy::unnecessary_literal_unwrap)] fn test_make_error_parse_input() { let res: Result<(), DataFusionError> = plan_err!("Err"); let res = res.unwrap_err(); @@ -1166,9 +1408,11 @@ mod test { let external_error_2: DataFusionError = generic_error_2.into(); println!("{external_error_2}"); - assert!(external_error_2 - .to_string() - .starts_with("External error: io error")); + assert!( + external_error_2 + .to_string() + .starts_with("External error: io error") + ); } /// Model what happens when implementing SendableRecordBatchStream: diff --git a/datafusion/common/src/file_options/csv_writer.rs b/datafusion/common/src/file_options/csv_writer.rs index 943288af91642..4e6f74a4448af 100644 --- a/datafusion/common/src/file_options/csv_writer.rs +++ b/datafusion/common/src/file_options/csv_writer.rs @@ -31,6 +31,8 @@ pub struct CsvWriterOptions { /// Compression to apply after ArrowWriter serializes RecordBatches. /// This compression is applied by DataFusion not the ArrowWriter itself. pub compression: CompressionTypeVariant, + /// Compression level for the output file. + pub compression_level: Option, } impl CsvWriterOptions { @@ -41,6 +43,20 @@ impl CsvWriterOptions { Self { writer_options, compression, + compression_level: None, + } + } + + /// Create a new `CsvWriterOptions` with the specified compression level. 
+ pub fn new_with_level( + writer_options: WriterBuilder, + compression: CompressionTypeVariant, + compression_level: u32, + ) -> Self { + Self { + writer_options, + compression, + compression_level: Some(compression_level), } } } @@ -81,6 +97,7 @@ impl TryFrom<&CsvOptions> for CsvWriterOptions { Ok(CsvWriterOptions { writer_options: builder, compression: value.compression, + compression_level: value.compression_level, }) } } diff --git a/datafusion/common/src/file_options/json_writer.rs b/datafusion/common/src/file_options/json_writer.rs index 750d2972329bb..a537192c8128a 100644 --- a/datafusion/common/src/file_options/json_writer.rs +++ b/datafusion/common/src/file_options/json_writer.rs @@ -27,11 +27,26 @@ use crate::{ #[derive(Clone, Debug)] pub struct JsonWriterOptions { pub compression: CompressionTypeVariant, + pub compression_level: Option, } impl JsonWriterOptions { pub fn new(compression: CompressionTypeVariant) -> Self { - Self { compression } + Self { + compression, + compression_level: None, + } + } + + /// Create a new `JsonWriterOptions` with the specified compression and level. + pub fn new_with_level( + compression: CompressionTypeVariant, + compression_level: u32, + ) -> Self { + Self { + compression, + compression_level: Some(compression_level), + } } } @@ -41,6 +56,7 @@ impl TryFrom<&JsonOptions> for JsonWriterOptions { fn try_from(value: &JsonOptions) -> Result { Ok(JsonWriterOptions { compression: value.compression, + compression_level: value.compression_level, }) } } diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs index 02667e0165717..c7374949ecef5 100644 --- a/datafusion/common/src/file_options/mod.rs +++ b/datafusion/common/src/file_options/mod.rs @@ -31,10 +31,10 @@ mod tests { use std::collections::HashMap; use crate::{ + Result, config::{ConfigFileType, TableOptions}, file_options::{csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions}, parsers::CompressionTypeVariant, - Result, }; use parquet::{ diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index 564929c61bab0..8aa0134d09ec8 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -20,22 +20,20 @@ use std::sync::Arc; use crate::{ + _internal_datafusion_err, DataFusionError, Result, config::{ParquetOptions, TableParquetOptions}, - DataFusionError, Result, _internal_datafusion_err, }; use arrow::datatypes::Schema; use parquet::arrow::encode_arrow_schema; -// TODO: handle once deprecated -#[allow(deprecated)] use parquet::{ arrow::ARROW_SCHEMA_META_KEY, basic::{BrotliLevel, GzipLevel, ZstdLevel}, file::{ metadata::KeyValue, properties::{ - EnabledStatistics, WriterProperties, WriterPropertiesBuilder, WriterVersion, - DEFAULT_STATISTICS_ENABLED, + DEFAULT_STATISTICS_ENABLED, EnabledStatistics, WriterProperties, + WriterPropertiesBuilder, WriterVersion, }, }, schema::types::ColumnPath, @@ -106,7 +104,9 @@ impl TryFrom<&TableParquetOptions> for WriterPropertiesBuilder { if !global.skip_arrow_metadata && !key_value_metadata.contains_key(ARROW_SCHEMA_META_KEY) { - return Err(_internal_datafusion_err!("arrow schema was not added to the kv_metadata, even though it is required by configuration settings")); + return Err(_internal_datafusion_err!( + "arrow schema was not added to the kv_metadata, even though it is required by configuration settings" + )); } // add kv_meta, if any @@ -174,7 +174,6 @@ impl ParquetOptions { 
/// /// Note that this method does not include the key_value_metadata from [`TableParquetOptions`]. pub fn into_writer_properties_builder(&self) -> Result { - #[allow(deprecated)] let ParquetOptions { data_pagesize_limit, write_batch_size, @@ -200,6 +199,7 @@ impl ParquetOptions { metadata_size_hint: _, pushdown_filters: _, reorder_filters: _, + force_filter_selections: _, // not used for writer props allow_single_file_parallelism: _, maximum_parallel_row_group_writers: _, maximum_buffered_record_batches_per_stream: _, @@ -261,7 +261,7 @@ pub(crate) fn parse_encoding_string( "plain" => Ok(parquet::basic::Encoding::PLAIN), "plain_dictionary" => Ok(parquet::basic::Encoding::PLAIN_DICTIONARY), "rle" => Ok(parquet::basic::Encoding::RLE), - #[allow(deprecated)] + #[expect(deprecated)] "bit_packed" => Ok(parquet::basic::Encoding::BIT_PACKED), "delta_binary_packed" => Ok(parquet::basic::Encoding::DELTA_BINARY_PACKED), "delta_length_byte_array" => { @@ -402,14 +402,13 @@ pub(crate) fn parse_statistics_string(str_setting: &str) -> Result ParquetColumnOptions { let bloom_filter_default_props = props.bloom_filter_properties(&col); - #[allow(deprecated)] // max_statistics_size ParquetColumnOptions { bloom_filter_enabled: Some(bloom_filter_default_props.is_some()), encoding: props.encoding(&col).map(|s| s.to_string()), @@ -545,7 +543,6 @@ mod tests { #[cfg(not(feature = "parquet_encryption"))] let fep = None; - #[allow(deprecated)] // max_statistics_size TableParquetOptions { global: ParquetOptions { // global options @@ -577,6 +574,7 @@ mod tests { metadata_size_hint: global_options_defaults.metadata_size_hint, pushdown_filters: global_options_defaults.pushdown_filters, reorder_filters: global_options_defaults.reorder_filters, + force_filter_selections: global_options_defaults.force_filter_selections, allow_single_file_parallelism: global_options_defaults .allow_single_file_parallelism, maximum_parallel_row_group_writers: global_options_defaults @@ -674,8 +672,7 @@ mod tests { let mut default_table_writer_opts = TableParquetOptions::default(); let default_parquet_opts = ParquetOptions::default(); assert_eq!( - default_table_writer_opts.global, - default_parquet_opts, + default_table_writer_opts.global, default_parquet_opts, "should have matching defaults for TableParquetOptions.global and ParquetOptions", ); @@ -699,7 +696,9 @@ mod tests { "should have different created_by sources", ); assert!( - default_writer_props.created_by().starts_with("parquet-rs version"), + default_writer_props + .created_by() + .starts_with("parquet-rs version"), "should indicate that writer_props defaults came from the extern parquet crate", ); assert!( @@ -733,8 +732,7 @@ mod tests { from_extern_parquet.global.skip_arrow_metadata = true; assert_eq!( - default_table_writer_opts, - from_extern_parquet, + default_table_writer_opts, from_extern_parquet, "the default writer_props should have the same configuration as the session's default TableParquetOptions", ); } diff --git a/datafusion/common/src/format.rs b/datafusion/common/src/format.rs index 764190e1189bf..a505bd0e1c74e 100644 --- a/datafusion/common/src/format.rs +++ b/datafusion/common/src/format.rs @@ -176,9 +176,9 @@ impl FromStr for ExplainFormat { "tree" => Ok(ExplainFormat::Tree), "pgjson" => Ok(ExplainFormat::PostgresJSON), "graphviz" => Ok(ExplainFormat::Graphviz), - _ => { - Err(DataFusionError::Configuration(format!("Invalid explain format. Expected 'indent', 'tree', 'pgjson' or 'graphviz'. 
Got '{format}'"))) - } + _ => Err(DataFusionError::Configuration(format!( + "Invalid explain format. Expected 'indent', 'tree', 'pgjson' or 'graphviz'. Got '{format}'" + ))), } } } diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index d60189fb6fa3f..98dd1f235aee7 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -28,11 +28,11 @@ use arrow::{downcast_dictionary_array, downcast_primitive_array}; use crate::cast::{ as_binary_view_array, as_boolean_array, as_fixed_size_list_array, as_generic_binary_array, as_large_list_array, as_list_array, as_map_array, - as_string_array, as_string_view_array, as_struct_array, + as_string_array, as_string_view_array, as_struct_array, as_union_array, }; use crate::error::Result; -#[cfg(not(feature = "force_hash_collisions"))] -use crate::error::_internal_err; +use crate::error::{_internal_datafusion_err, _internal_err}; +use std::cell::RefCell; // Combines two hashes into one hash #[inline] @@ -41,6 +41,94 @@ pub fn combine_hashes(l: u64, r: u64) -> u64 { hash.wrapping_mul(37).wrapping_add(r) } +/// Maximum size for the thread-local hash buffer before truncation (4MB = 524,288 u64 elements). +/// The goal of this is to avoid unbounded memory growth that would appear as a memory leak. +/// We allow temporary allocations beyond this size, but after use the buffer is truncated +/// to this size. +const MAX_BUFFER_SIZE: usize = 524_288; + +thread_local! { + /// Thread-local buffer for hash computations to avoid repeated allocations. + /// The buffer is reused across calls and truncated if it exceeds MAX_BUFFER_SIZE. + /// Defaults to a capacity of 8192 u64 elements which is the default batch size. + /// This corresponds to 64KB of memory. + static HASH_BUFFER: RefCell> = const { RefCell::new(Vec::new()) }; +} + +/// Creates hashes for the given arrays using a thread-local buffer, then calls the provided callback +/// with an immutable reference to the computed hashes. +/// +/// This function manages a thread-local buffer to avoid repeated allocations. The buffer is automatically +/// truncated if it exceeds `MAX_BUFFER_SIZE` after use. 
+/// +/// # Arguments +/// * `arrays` - The arrays to hash (must contain at least one array) +/// * `random_state` - The random state for hashing +/// * `callback` - A function that receives an immutable reference to the hash slice and returns a result +/// +/// # Errors +/// Returns an error if: +/// - No arrays are provided +/// - The function is called reentrantly (i.e., the callback invokes `with_hashes` again on the same thread) +/// - The function is called during or after thread destruction +/// +/// # Example +/// ```ignore +/// use datafusion_common::hash_utils::{with_hashes, RandomState}; +/// use arrow::array::{Int32Array, ArrayRef}; +/// use std::sync::Arc; +/// +/// let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); +/// let random_state = RandomState::new(); +/// +/// let result = with_hashes([&array], &random_state, |hashes| { +/// // Use the hashes here +/// Ok(hashes.len()) +/// })?; +/// ``` +pub fn with_hashes( + arrays: I, + random_state: &RandomState, + callback: F, +) -> Result +where + I: IntoIterator, + T: AsDynArray, + F: FnOnce(&[u64]) -> Result, +{ + // Peek at the first array to determine buffer size without fully collecting + let mut iter = arrays.into_iter().peekable(); + + // Get the required size from the first array + let required_size = match iter.peek() { + Some(arr) => arr.as_dyn_array().len(), + None => return _internal_err!("with_hashes requires at least one array"), + }; + + HASH_BUFFER.try_with(|cell| { + let mut buffer = cell.try_borrow_mut() + .map_err(|_| _internal_datafusion_err!("with_hashes cannot be called reentrantly on the same thread"))?; + + // Ensure buffer has sufficient length, clearing old values + buffer.clear(); + buffer.resize(required_size, 0); + + // Create hashes in the buffer - this consumes the iterator + create_hashes(iter, random_state, &mut buffer[..required_size])?; + + // Execute the callback with an immutable slice + let result = callback(&buffer[..required_size])?; + + // Cleanup: truncate if buffer grew too large + if buffer.capacity() > MAX_BUFFER_SIZE { + buffer.truncate(MAX_BUFFER_SIZE); + buffer.shrink_to_fit(); + } + + Ok(result) + }).map_err(|_| _internal_datafusion_err!("with_hashes cannot access thread-local storage during or after thread destruction"))? +} + #[cfg(not(feature = "force_hash_collisions"))] fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: bool) { if mul_col { @@ -74,7 +162,7 @@ macro_rules! hash_value { })+ }; } -hash_value!(i8, i16, i32, i64, i128, i256, u8, u16, u32, u64); +hash_value!(i8, i16, i32, i64, i128, i256, u8, u16, u32, u64, u128); hash_value!(bool, str, [u8], IntervalDayTime, IntervalMonthDayNano); macro_rules! hash_float_value { @@ -181,6 +269,127 @@ fn hash_array( } } +/// Hash a StringView or BytesView array +/// +/// Templated to optimize inner loop based on presence of nulls and external buffers. 
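The `with_hashes` helper above pairs a thread-local scratch buffer with a truncate-after-use cap so per-batch hashing does not reallocate on every call. A minimal, standalone sketch of that reuse-and-truncate pattern (std only; the names are hypothetical and the reentrancy handling is simplified to a panic rather than the error the real helper returns):

```rust
use std::cell::RefCell;

// Same cap as MAX_BUFFER_SIZE above: 524_288 u64s = 4MB.
const MAX_BUFFER_SIZE: usize = 524_288;

thread_local! {
    static SCRATCH: RefCell<Vec<u64>> = const { RefCell::new(Vec::new()) };
}

/// Borrow the thread-local buffer, size it for this call, run the callback on
/// an immutable slice, then shrink the buffer if it grew past the cap.
fn with_scratch<R>(len: usize, f: impl FnOnce(&[u64]) -> R) -> R {
    SCRATCH.with(|cell| {
        let mut buf = cell.borrow_mut(); // panics if called reentrantly
        buf.clear();
        buf.resize(len, 0);
        let result = f(&buf[..len]);
        if buf.capacity() > MAX_BUFFER_SIZE {
            buf.truncate(MAX_BUFFER_SIZE);
            buf.shrink_to_fit();
        }
        result
    })
}

fn main() {
    // Two calls on the same thread reuse one allocation.
    let a = with_scratch(4, |s| s.len());
    let b = with_scratch(8, |s| s.len());
    assert_eq!((a, b), (4, 8));
}
```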
+/// +/// HAS_NULLS: do we have to check null in the inner loop +/// HAS_BUFFERS: if true, array has external buffers; if false, all strings are inlined/ less then 12 bytes +/// REHASH: if true, combining with existing hash, otherwise initializing +#[inline(never)] +fn hash_string_view_array_inner< + T: ByteViewType, + const HAS_NULLS: bool, + const HAS_BUFFERS: bool, + const REHASH: bool, +>( + array: &GenericByteViewArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) { + assert_eq!( + hashes_buffer.len(), + array.len(), + "hashes_buffer and array should be of equal length" + ); + + let buffers = array.data_buffers(); + let view_bytes = |view_len: u32, view: u128| { + let view = ByteView::from(view); + let offset = view.offset as usize; + // SAFETY: view is a valid view as it came from the array + unsafe { + let data = buffers.get_unchecked(view.buffer_index as usize); + data.get_unchecked(offset..offset + view_len as usize) + } + }; + + let hashes_and_views = hashes_buffer.iter_mut().zip(array.views().iter()); + for (i, (hash, &v)) in hashes_and_views.enumerate() { + if HAS_NULLS && array.is_null(i) { + continue; + } + let view_len = v as u32; + // all views are inlined, no need to access external buffers + if !HAS_BUFFERS || view_len <= 12 { + if REHASH { + *hash = combine_hashes(v.hash_one(random_state), *hash); + } else { + *hash = v.hash_one(random_state); + } + continue; + } + // view is not inlined, so we need to hash the bytes as well + let value = view_bytes(view_len, v); + if REHASH { + *hash = combine_hashes(value.hash_one(random_state), *hash); + } else { + *hash = value.hash_one(random_state); + } + } +} + +/// Builds hash values for array views and writes them into `hashes_buffer` +/// If `rehash==true` this combines the previous hash value in the buffer +/// with the new hash using `combine_hashes` +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_generic_byte_view_array( + array: &GenericByteViewArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], + rehash: bool, +) { + // instantiate the correct version based on presence of nulls and external buffers + match ( + array.null_count() != 0, + !array.data_buffers().is_empty(), + rehash, + ) { + // no nulls or buffers ==> hash the inlined views directly + // don't call the inner function as Rust seems better able to inline this simpler code (2-3% faster) + (false, false, false) => { + for (hash, &view) in hashes_buffer.iter_mut().zip(array.views().iter()) { + *hash = view.hash_one(random_state); + } + } + (false, false, true) => { + for (hash, &view) in hashes_buffer.iter_mut().zip(array.views().iter()) { + *hash = combine_hashes(view.hash_one(random_state), *hash); + } + } + (false, true, false) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + (false, true, true) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + (true, false, false) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + (true, false, true) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + (true, true, false) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + (true, true, true) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + } +} + /// Helper function to update hash for a dictionary key if the value is valid #[cfg(not(feature = "force_hash_collisions"))] #[inline] @@ -329,6 +538,40 @@ where Ok(()) } +#[cfg(not(feature = 
"force_hash_collisions"))] +fn hash_union_array( + array: &UnionArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + use std::collections::HashMap; + + let DataType::Union(union_fields, _mode) = array.data_type() else { + unreachable!() + }; + + let mut child_hashes = HashMap::with_capacity(union_fields.len()); + + for (type_id, _field) in union_fields.iter() { + let child = array.child(type_id); + let mut child_hash_buffer = vec![0; child.len()]; + create_hashes([child], random_state, &mut child_hash_buffer)?; + + child_hashes.insert(type_id, child_hash_buffer); + } + + #[expect(clippy::needless_range_loop)] + for i in 0..array.len() { + let type_id = array.type_id(i); + let child_offset = array.value_offset(i); + + let child_hash = child_hashes.get(&type_id).expect("invalid type_id"); + hashes_buffer[i] = combine_hashes(hashes_buffer[i], child_hash[child_offset]); + } + + Ok(()) +} + #[cfg(not(feature = "force_hash_collisions"))] fn hash_fixed_list_array( array: &FixedSizeListArray, @@ -362,6 +605,76 @@ fn hash_fixed_list_array( Ok(()) } +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_run_array( + array: &RunArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], + rehash: bool, +) -> Result<()> { + // We find the relevant runs that cover potentially sliced arrays, so we can only hash those + // values. Then we find the runs that refer to the original runs and ensure that we apply + // hashes correctly to the sliced, whether sliced at the start, end, or both. + let array_offset = array.offset(); + let array_len = array.len(); + + if array_len == 0 { + return Ok(()); + } + + let run_ends = array.run_ends(); + let run_ends_values = run_ends.values(); + let values = array.values(); + + let start_physical_index = array.get_start_physical_index(); + // get_end_physical_index returns the inclusive last index, but we need the exclusive range end + // for the operations we use below. + let end_physical_index = array.get_end_physical_index() + 1; + + let sliced_values = values.slice( + start_physical_index, + end_physical_index - start_physical_index, + ); + let mut values_hashes = vec![0u64; sliced_values.len()]; + create_hashes( + std::slice::from_ref(&sliced_values), + random_state, + &mut values_hashes, + )?; + + let mut start_in_slice = 0; + for (adjusted_physical_index, &absolute_run_end) in run_ends_values + [start_physical_index..end_physical_index] + .iter() + .enumerate() + { + let is_null_value = sliced_values.is_null(adjusted_physical_index); + let absolute_run_end = absolute_run_end.as_usize(); + + let end_in_slice = (absolute_run_end - array_offset).min(array_len); + + if rehash { + if !is_null_value { + let value_hash = values_hashes[adjusted_physical_index]; + for hash in hashes_buffer + .iter_mut() + .take(end_in_slice) + .skip(start_in_slice) + { + *hash = combine_hashes(value_hash, *hash); + } + } + } else { + let value_hash = values_hashes[adjusted_physical_index]; + hashes_buffer[start_in_slice..end_in_slice].fill(value_hash); + } + + start_in_slice = end_in_slice; + } + + Ok(()) +} + /// Internal helper function that hashes a single array and either initializes or combines /// the hash values in the buffer. 
#[cfg(not(feature = "force_hash_collisions"))] @@ -376,10 +689,10 @@ fn hash_single_array( DataType::Null => hash_null(random_state, hashes_buffer, rehash), DataType::Boolean => hash_array(&as_boolean_array(array)?, random_state, hashes_buffer, rehash), DataType::Utf8 => hash_array(&as_string_array(array)?, random_state, hashes_buffer, rehash), - DataType::Utf8View => hash_array(&as_string_view_array(array)?, random_state, hashes_buffer, rehash), + DataType::Utf8View => hash_generic_byte_view_array(as_string_view_array(array)?, random_state, hashes_buffer, rehash), DataType::LargeUtf8 => hash_array(&as_largestring_array(array), random_state, hashes_buffer, rehash), DataType::Binary => hash_array(&as_generic_binary_array::(array)?, random_state, hashes_buffer, rehash), - DataType::BinaryView => hash_array(&as_binary_view_array(array)?, random_state, hashes_buffer, rehash), + DataType::BinaryView => hash_generic_byte_view_array(as_binary_view_array(array)?, random_state, hashes_buffer, rehash), DataType::LargeBinary => hash_array(&as_generic_binary_array::(array)?, random_state, hashes_buffer, rehash), DataType::FixedSizeBinary(_) => { let array: &FixedSizeBinaryArray = array.as_any().downcast_ref().unwrap(); @@ -409,6 +722,14 @@ fn hash_single_array( let array = as_fixed_size_list_array(array)?; hash_fixed_list_array(array, random_state, hashes_buffer)?; } + DataType::Union(_, _) => { + let array = as_union_array(array)?; + hash_union_array(array, random_state, hashes_buffer)?; + } + DataType::RunEndEncoded(_, _) => downcast_run_array! { + array => hash_run_array(array, random_state, hashes_buffer, rehash)?, + _ => unreachable!() + } _ => { // This is internal because we should have caught this before. return _internal_err!( @@ -478,8 +799,8 @@ impl AsDynArray for &ArrayRef { pub fn create_hashes<'a, I, T>( arrays: I, random_state: &RandomState, - hashes_buffer: &'a mut Vec, -) -> Result<&'a mut Vec> + hashes_buffer: &'a mut [u64], +) -> Result<&'a mut [u64]> where I: IntoIterator, T: AsDynArray, @@ -522,7 +843,7 @@ mod tests { fn create_hashes_for_empty_fixed_size_lit() -> Result<()> { let empty_array = FixedSizeListBuilder::new(StringBuilder::new(), 1).finish(); let random_state = RandomState::with_seeds(0, 0, 0, 0); - let hashes_buff = &mut vec![0; 0]; + let hashes_buff = &mut [0; 0]; let hashes = create_hashes( &[Arc::new(empty_array) as ArrayRef], &random_state, @@ -567,8 +888,6 @@ mod tests { let binary_array: ArrayRef = Arc::new(binary.iter().cloned().collect::<$ARRAY>()); - let ref_array: ArrayRef = - Arc::new(binary.iter().cloned().collect::()); let random_state = RandomState::with_seeds(0, 0, 0, 0); @@ -576,9 +895,6 @@ mod tests { create_hashes(&[binary_array], &random_state, &mut binary_hashes) .unwrap(); - let mut ref_hashes = vec![0; binary.len()]; - create_hashes(&[ref_array], &random_state, &mut ref_hashes).unwrap(); - // Null values result in a zero hash, for (val, hash) in binary.iter().zip(binary_hashes.iter()) { match val { @@ -587,9 +903,6 @@ mod tests { } } - // same logical values should hash to the same hash value - assert_eq!(binary_hashes, ref_hashes); - // Same values should map to same hash values assert_eq!(binary[0], binary[5]); assert_eq!(binary[4], binary[6]); @@ -601,6 +914,7 @@ mod tests { } create_hash_binary!(binary_array, BinaryArray); + create_hash_binary!(large_binary_array, LargeBinaryArray); create_hash_binary!(binary_view_array, BinaryViewArray); #[test] @@ -677,6 +991,74 @@ mod tests { create_hash_string!(string_view_array, StringArray); 
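The `create_hashes` signature change above (from `&mut Vec<u64>` to `&mut [u64]`) is source-compatible for most callers. A quick sketch of why `Vec`-based call sites keep compiling while stack buffers and sub-slices become usable too (the `fill_hashes` stand-in is hypothetical, not the real function):

```rust
// Same shape as the updated `create_hashes` parameter: borrow a slice, not a Vec.
fn fill_hashes(hashes_buffer: &mut [u64]) -> &mut [u64] {
    for (i, h) in hashes_buffer.iter_mut().enumerate() {
        *h = i as u64; // placeholder for real hashing
    }
    hashes_buffer
}

fn main() {
    // Existing callers that own a Vec keep working via deref coercion...
    let mut owned = vec![0u64; 4];
    fill_hashes(&mut owned);

    // ...and callers can now also pass stack buffers or sub-slices directly.
    let mut stack_buf = [0u64; 4];
    fill_hashes(&mut stack_buf);
    fill_hashes(&mut owned[..2]);

    assert_eq!(owned, vec![0, 1, 2, 3]);
}
```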
create_hash_string!(dict_string_array, DictionaryArray); + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_run_array() -> Result<()> { + let values = Arc::new(Int32Array::from(vec![10, 20, 30])); + let run_ends = Arc::new(Int32Array::from(vec![2, 5, 7])); + let array = Arc::new(RunArray::try_new(&run_ends, values.as_ref()).unwrap()); + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let hashes_buff = &mut vec![0; array.len()]; + let hashes = create_hashes( + &[Arc::clone(&array) as ArrayRef], + &random_state, + hashes_buff, + )?; + + assert_eq!(hashes.len(), 7); + assert_eq!(hashes[0], hashes[1]); + assert_eq!(hashes[2], hashes[3]); + assert_eq!(hashes[3], hashes[4]); + assert_eq!(hashes[5], hashes[6]); + assert_ne!(hashes[0], hashes[2]); + assert_ne!(hashes[2], hashes[5]); + assert_ne!(hashes[0], hashes[5]); + + Ok(()) + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_multi_column_hash_with_run_array() -> Result<()> { + let int_array = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7])); + let values = Arc::new(StringArray::from(vec!["foo", "bar", "baz"])); + let run_ends = Arc::new(Int32Array::from(vec![2, 5, 7])); + let run_array = Arc::new(RunArray::try_new(&run_ends, values.as_ref()).unwrap()); + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut one_col_hashes = vec![0; int_array.len()]; + create_hashes( + &[Arc::clone(&int_array) as ArrayRef], + &random_state, + &mut one_col_hashes, + )?; + + let mut two_col_hashes = vec![0; int_array.len()]; + create_hashes( + &[ + Arc::clone(&int_array) as ArrayRef, + Arc::clone(&run_array) as ArrayRef, + ], + &random_state, + &mut two_col_hashes, + )?; + + assert_eq!(one_col_hashes.len(), 7); + assert_eq!(two_col_hashes.len(), 7); + assert_ne!(one_col_hashes, two_col_hashes); + + let diff_0_vs_1_one_col = one_col_hashes[0] != one_col_hashes[1]; + let diff_0_vs_1_two_col = two_col_hashes[0] != two_col_hashes[1]; + assert_eq!(diff_0_vs_1_one_col, diff_0_vs_1_two_col); + + let diff_2_vs_3_one_col = one_col_hashes[2] != one_col_hashes[3]; + let diff_2_vs_3_two_col = two_col_hashes[2] != two_col_hashes[3]; + assert_eq!(diff_2_vs_3_one_col, diff_2_vs_3_two_col); + + Ok(()) + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] @@ -1000,4 +1382,297 @@ mod tests { assert_eq!(hashes1, hashes2); } + + #[test] + fn test_with_hashes() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let random_state = RandomState::with_seeds(0, 0, 0, 0); + + // Test that with_hashes produces the same results as create_hashes + let mut expected_hashes = vec![0; array.len()]; + create_hashes([&array], &random_state, &mut expected_hashes).unwrap(); + + let result = with_hashes([&array], &random_state, |hashes| { + assert_eq!(hashes.len(), 4); + // Verify hashes match expected values + assert_eq!(hashes, &expected_hashes[..]); + // Return a copy of the hashes + Ok(hashes.to_vec()) + }) + .unwrap(); + + // Verify callback result is returned correctly + assert_eq!(result, expected_hashes); + } + + #[test] + fn test_with_hashes_multi_column() { + let int_array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let str_array: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"])); + let random_state = RandomState::with_seeds(0, 0, 0, 0); + + // Test multi-column hashing + let mut expected_hashes = vec![0; int_array.len()]; + create_hashes( + [&int_array, &str_array], + 
&random_state, + &mut expected_hashes, + ) + .unwrap(); + + with_hashes([&int_array, &str_array], &random_state, |hashes| { + assert_eq!(hashes.len(), 3); + assert_eq!(hashes, &expected_hashes[..]); + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_with_hashes_empty_arrays() { + let random_state = RandomState::with_seeds(0, 0, 0, 0); + + // Test that passing no arrays returns an error + let empty: [&ArrayRef; 0] = []; + let result = with_hashes(empty, &random_state, |_hashes| Ok(())); + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("requires at least one array") + ); + } + + #[test] + fn test_with_hashes_reentrancy() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let array2: ArrayRef = Arc::new(Int32Array::from(vec![4, 5, 6])); + let random_state = RandomState::with_seeds(0, 0, 0, 0); + + // Test that reentrant calls return an error instead of panicking + let result = with_hashes([&array], &random_state, |_hashes| { + // Try to call with_hashes again inside the callback + with_hashes([&array2], &random_state, |_inner_hashes| Ok(())) + }); + + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("reentrantly") || err_msg.contains("cannot be called"), + "Error message should mention reentrancy: {err_msg}", + ); + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_sparse_union_arrays() { + // logical array: [int(5), str("foo"), int(10), int(5)] + let int_array = Int32Array::from(vec![Some(5), None, Some(10), Some(5)]); + let str_array = StringArray::from(vec![None, Some("foo"), None, None]); + + let type_ids = vec![0_i8, 1, 0, 0].into(); + let children = vec![ + Arc::new(int_array) as ArrayRef, + Arc::new(str_array) as ArrayRef, + ]; + + let union_fields = [ + (0, Arc::new(Field::new("a", DataType::Int32, true))), + (1, Arc::new(Field::new("b", DataType::Utf8, true))), + ] + .into_iter() + .collect(); + + let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap(); + let array_ref = Arc::new(array) as ArrayRef; + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; array_ref.len()]; + create_hashes(&[array_ref], &random_state, &mut hashes).unwrap(); + + // Rows 0 and 3 both have type_id=0 (int) with value 5 + assert_eq!(hashes[0], hashes[3]); + // Row 0 (int 5) vs Row 2 (int 10) - different values + assert_ne!(hashes[0], hashes[2]); + // Row 0 (int) vs Row 1 (string) - different types + assert_ne!(hashes[0], hashes[1]); + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_sparse_union_arrays_with_nulls() { + // logical array: [int(5), str("foo"), int(null), str(null)] + let int_array = Int32Array::from(vec![Some(5), None, None, None]); + let str_array = StringArray::from(vec![None, Some("foo"), None, None]); + + let type_ids = vec![0, 1, 0, 1].into(); + let children = vec![ + Arc::new(int_array) as ArrayRef, + Arc::new(str_array) as ArrayRef, + ]; + + let union_fields = [ + (0, Arc::new(Field::new("a", DataType::Int32, true))), + (1, Arc::new(Field::new("b", DataType::Utf8, true))), + ] + .into_iter() + .collect(); + + let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap(); + let array_ref = Arc::new(array) as ArrayRef; + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; array_ref.len()]; + create_hashes(&[array_ref], &random_state, &mut hashes).unwrap(); + + // row 2 (int null) and row 3 
(str null) should have the same hash + // because they are both null values + assert_eq!(hashes[2], hashes[3]); + + // row 0 (int 5) vs row 2 (int null) - different (value vs null) + assert_ne!(hashes[0], hashes[2]); + + // row 1 (str "foo") vs row 3 (str null) - different (value vs null) + assert_ne!(hashes[1], hashes[3]); + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_dense_union_arrays() { + // creates a dense union array with int and string types + // [67, "norm", 100, "macdonald", 67] + let int_array = Int32Array::from(vec![67, 100, 67]); + let str_array = StringArray::from(vec!["norm", "macdonald"]); + + let type_ids = vec![0, 1, 0, 1, 0].into(); + let offsets = vec![0, 0, 1, 1, 2].into(); + let children = vec![ + Arc::new(int_array) as ArrayRef, + Arc::new(str_array) as ArrayRef, + ]; + + let union_fields = [ + (0, Arc::new(Field::new("a", DataType::Int32, false))), + (1, Arc::new(Field::new("b", DataType::Utf8, false))), + ] + .into_iter() + .collect(); + + let array = + UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap(); + let array_ref = Arc::new(array) as ArrayRef; + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; array_ref.len()]; + create_hashes(&[array_ref], &random_state, &mut hashes).unwrap(); + + // 67 vs "norm" + assert_ne!(hashes[0], hashes[1]); + // 67 vs 100 + assert_ne!(hashes[0], hashes[2]); + // "norm" vs "macdonald" + assert_ne!(hashes[1], hashes[3]); + // 100 vs "macdonald" + assert_ne!(hashes[2], hashes[3]); + // 67 vs 67 + assert_eq!(hashes[0], hashes[4]); + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_sliced_run_array() -> Result<()> { + let values = Arc::new(Int32Array::from(vec![10, 20, 30])); + let run_ends = Arc::new(Int32Array::from(vec![2, 5, 7])); + let array = Arc::new(RunArray::try_new(&run_ends, values.as_ref()).unwrap()); + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut full_hashes = vec![0; array.len()]; + create_hashes( + &[Arc::clone(&array) as ArrayRef], + &random_state, + &mut full_hashes, + )?; + + let array_ref: ArrayRef = Arc::clone(&array) as ArrayRef; + let sliced_array = array_ref.slice(2, 3); + + let mut sliced_hashes = vec![0; sliced_array.len()]; + create_hashes( + std::slice::from_ref(&sliced_array), + &random_state, + &mut sliced_hashes, + )?; + + assert_eq!(sliced_hashes.len(), 3); + assert_eq!(sliced_hashes[0], sliced_hashes[1]); + assert_eq!(sliced_hashes[1], sliced_hashes[2]); + assert_eq!(&sliced_hashes, &full_hashes[2..5]); + + Ok(()) + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn test_run_array_with_nulls() -> Result<()> { + let values = Arc::new(Int32Array::from(vec![Some(10), None, Some(20)])); + let run_ends = Arc::new(Int32Array::from(vec![2, 4, 6])); + let array = Arc::new(RunArray::try_new(&run_ends, values.as_ref()).unwrap()); + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; array.len()]; + create_hashes( + &[Arc::clone(&array) as ArrayRef], + &random_state, + &mut hashes, + )?; + + assert_eq!(hashes[0], hashes[1]); + assert_ne!(hashes[0], 0); + assert_eq!(hashes[2], hashes[3]); + assert_eq!(hashes[2], 0); + assert_eq!(hashes[4], hashes[5]); + assert_ne!(hashes[4], 0); + assert_ne!(hashes[0], hashes[4]); + + Ok(()) + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn test_run_array_with_nulls_multicolumn() -> Result<()> { + let primitive_array = 
Arc::new(Int32Array::from(vec![Some(10), None, Some(20)])); + let run_values = Arc::new(Int32Array::from(vec![Some(10), None, Some(20)])); + let run_ends = Arc::new(Int32Array::from(vec![1, 2, 3])); + let run_array = + Arc::new(RunArray::try_new(&run_ends, run_values.as_ref()).unwrap()); + let second_col = Arc::new(Int32Array::from(vec![100, 200, 300])); + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + + let mut primitive_hashes = vec![0; 3]; + create_hashes( + &[ + Arc::clone(&primitive_array) as ArrayRef, + Arc::clone(&second_col) as ArrayRef, + ], + &random_state, + &mut primitive_hashes, + )?; + + let mut run_hashes = vec![0; 3]; + create_hashes( + &[ + Arc::clone(&run_array) as ArrayRef, + Arc::clone(&second_col) as ArrayRef, + ], + &random_state, + &mut run_hashes, + )?; + + assert_eq!(primitive_hashes, run_hashes); + + Ok(()) + } } diff --git a/datafusion/common/src/instant.rs b/datafusion/common/src/instant.rs index 42f21c061c0c2..a5dfb28292581 100644 --- a/datafusion/common/src/instant.rs +++ b/datafusion/common/src/instant.rs @@ -22,7 +22,7 @@ /// under `wasm` feature gate. It provides the same API as [`std::time::Instant`]. pub type Instant = web_time::Instant; -#[allow(clippy::disallowed_types)] +#[expect(clippy::disallowed_types)] #[cfg(not(target_family = "wasm"))] /// DataFusion wrapper around [`std::time::Instant`]. This is only a type alias. pub type Instant = std::time::Instant; diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 549c265024f91..3bec9bd35cbd0 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -23,17 +23,14 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] +#![deny(clippy::allow_attributes)] mod column; mod dfschema; mod functional_dependencies; mod join_type; mod param_value; -#[cfg(feature = "pyarrow")] -mod pyarrow; mod schema_reference; mod table_reference; mod unnest; @@ -69,21 +66,24 @@ pub mod utils; pub use arrow; pub use column::Column; pub use dfschema::{ - qualified_name, DFSchema, DFSchemaRef, ExprSchema, SchemaExt, ToDFSchema, + DFSchema, DFSchemaRef, ExprSchema, SchemaExt, ToDFSchema, qualified_name, }; pub use diagnostic::Diagnostic; +pub use display::human_readable::{ + human_readable_count, human_readable_duration, human_readable_size, units, +}; pub use error::{ - field_not_found, unqualified_field_not_found, DataFusionError, Result, SchemaError, - SharedResult, + DataFusionError, Result, SchemaError, SharedResult, field_not_found, + unqualified_field_not_found, }; pub use file_options::file_type::{ - GetExt, DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, DEFAULT_CSV_EXTENSION, - DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, + DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, DEFAULT_CSV_EXTENSION, + DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, GetExt, }; pub use functional_dependencies::{ + Constraint, Constraints, Dependency, FunctionalDependence, FunctionalDependencies, aggregate_functional_dependencies, get_required_group_by_exprs_indices, - get_target_functional_dependencies, Constraint, Constraints, Dependency, - FunctionalDependence, FunctionalDependencies, + get_target_functional_dependencies, }; use hashbrown::hash_map::DefaultHashBuilder; pub use join_type::{JoinConstraint, JoinSide, JoinType}; @@ -105,9 
+105,9 @@ pub use utils::project_schema; // https://github.com/rust-lang/rust/pull/52234#issuecomment-976702997 #[doc(hidden)] pub use error::{ - _config_datafusion_err, _exec_datafusion_err, _internal_datafusion_err, - _not_impl_datafusion_err, _plan_datafusion_err, _resources_datafusion_err, - _substrait_datafusion_err, + _config_datafusion_err, _exec_datafusion_err, _ffi_datafusion_err, + _internal_datafusion_err, _not_impl_datafusion_err, _plan_datafusion_err, + _resources_datafusion_err, _substrait_datafusion_err, }; // The HashMap and HashSet implementations that should be used as the uniform defaults @@ -139,10 +139,10 @@ macro_rules! downcast_value { // Not public API. #[doc(hidden)] pub mod __private { - use crate::error::_internal_datafusion_err; use crate::Result; + use crate::error::_internal_datafusion_err; use arrow::array::Array; - use std::any::{type_name, Any}; + use std::any::{Any, type_name}; #[doc(hidden)] pub trait DowncastArrayHelper { @@ -193,7 +193,7 @@ mod tests { assert_starts_with( error.to_string(), - "Internal error: could not cast array of type Int32 to arrow_array::array::primitive_array::PrimitiveArray" + "Internal error: could not cast array of type Int32 to arrow_array::array::primitive_array::PrimitiveArray", ); } diff --git a/datafusion/common/src/metadata.rs b/datafusion/common/src/metadata.rs index 3a10cc2b42f9f..eb687bde07d0b 100644 --- a/datafusion/common/src/metadata.rs +++ b/datafusion/common/src/metadata.rs @@ -17,10 +17,10 @@ use std::{collections::BTreeMap, sync::Arc}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use hashbrown::HashMap; -use crate::{error::_plan_err, DataFusionError, ScalarValue}; +use crate::{DataFusionError, ScalarValue, error::_plan_err}; /// A [`ScalarValue`] with optional [`FieldMetadata`] #[derive(Debug, Clone)] @@ -320,6 +320,16 @@ impl FieldMetadata { field.with_metadata(self.to_hashmap()) } + + /// Updates the metadata on the FieldRef with this metadata, if it is not empty. + pub fn add_to_field_ref(&self, mut field_ref: FieldRef) -> FieldRef { + if self.inner.is_empty() { + return field_ref; + } + + Arc::make_mut(&mut field_ref).set_metadata(self.to_hashmap()); + field_ref + } } impl From<&Field> for FieldMetadata { diff --git a/datafusion/common/src/nested_struct.rs b/datafusion/common/src/nested_struct.rs index d43816f75b0ed..086d96e85230d 100644 --- a/datafusion/common/src/nested_struct.rs +++ b/datafusion/common/src/nested_struct.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::error::{Result, _plan_err}; +use crate::error::{_plan_err, Result}; use arrow::{ - array::{new_null_array, Array, ArrayRef, StructArray}, - compute::{cast_with_options, CastOptions}, + array::{Array, ArrayRef, StructArray, new_null_array}, + compute::{CastOptions, cast_with_options}, datatypes::{DataType::Struct, Field, FieldRef}, }; use std::sync::Arc; diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs index ebf68e4dd210d..0fac6b529eb0f 100644 --- a/datafusion/common/src/param_value.rs +++ b/datafusion/common/src/param_value.rs @@ -16,7 +16,7 @@ // under the License. 
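The new `FieldMetadata::add_to_field_ref` above relies on `Arc::make_mut`, so metadata is written in place when the `FieldRef` is uniquely owned and onto a private copy when it is shared. A standalone sketch of that clone-on-write behaviour (arrow crate only):

```rust
use std::collections::HashMap;
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, FieldRef};

fn main() {
    let meta: HashMap<String, String> =
        [("key".to_string(), "value".to_string())].into();

    // Uniquely owned Arc: `make_mut` mutates the Field in place, no clone.
    let mut unique: FieldRef = Arc::new(Field::new("a", DataType::Int32, true));
    Arc::make_mut(&mut unique).set_metadata(meta.clone());
    assert_eq!(unique.metadata().get("key").map(String::as_str), Some("value"));

    // Shared Arc: `make_mut` clones the Field first, so the other handle
    // keeps its original (empty) metadata.
    let mut shared: FieldRef = Arc::new(Field::new("b", DataType::Int32, true));
    let other = Arc::clone(&shared);
    Arc::make_mut(&mut shared).set_metadata(meta);
    assert!(other.metadata().is_empty());
    assert!(!shared.metadata().is_empty());
}
```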
use crate::error::{_plan_datafusion_err, _plan_err}; -use crate::metadata::{check_metadata_with_storage_equal, ScalarAndMetadata}; +use crate::metadata::{ScalarAndMetadata, check_metadata_with_storage_equal}; use crate::{Result, ScalarValue}; use arrow::datatypes::{DataType, Field, FieldRef}; use std::collections::HashMap; diff --git a/datafusion/common/src/pruning.rs b/datafusion/common/src/pruning.rs index 48750e3c995c4..5a7598ea1f299 100644 --- a/datafusion/common/src/pruning.rs +++ b/datafusion/common/src/pruning.rs @@ -135,6 +135,10 @@ pub trait PruningStatistics { /// This feeds into [`CompositePruningStatistics`] to allow pruning /// with filters that depend both on partition columns and data columns /// (e.g. `WHERE partition_col = data_col`). +#[deprecated( + since = "52.0.0", + note = "This struct is no longer used internally. Use `replace_columns_with_literals` from `datafusion-physical-expr-adapter` to substitute partition column values before pruning. It will be removed in 58.0.0 or 6 months after 52.0.0 is released, whichever comes first." +)] #[derive(Clone)] pub struct PartitionPruningStatistics { /// Values for each column for each container. @@ -156,6 +160,7 @@ pub struct PartitionPruningStatistics { partition_schema: SchemaRef, } +#[expect(deprecated)] impl PartitionPruningStatistics { /// Create a new instance of [`PartitionPruningStatistics`]. /// @@ -169,6 +174,36 @@ impl PartitionPruningStatistics { /// This must **not** be the schema of the entire file or table: /// instead it must only be the schema of the partition columns, /// in the same order as the values in `partition_values`. + /// + /// # Example + /// + /// To create [`PartitionPruningStatistics`] for two partition columns `a` and `b`, + /// for three containers like this: + /// + /// | a | b | + /// | - | - | + /// | 1 | 2 | + /// | 3 | 4 | + /// | 5 | 6 | + /// + /// ``` + /// # use std::sync::Arc; + /// # use datafusion_common::ScalarValue; + /// # use arrow::datatypes::{DataType, Field}; + /// # use datafusion_common::pruning::PartitionPruningStatistics; + /// + /// let partition_values = vec![ + /// vec![ScalarValue::from(1i32), ScalarValue::from(2i32)], + /// vec![ScalarValue::from(3i32), ScalarValue::from(4i32)], + /// vec![ScalarValue::from(5i32), ScalarValue::from(6i32)], + /// ]; + /// let partition_fields = vec![ + /// Arc::new(Field::new("a", DataType::Int32, false)), + /// Arc::new(Field::new("b", DataType::Int32, false)), + /// ]; + /// let partition_stats = + /// PartitionPruningStatistics::try_new(partition_values, partition_fields).unwrap(); + /// ``` pub fn try_new( partition_values: Vec>, partition_fields: Vec, @@ -202,6 +237,7 @@ impl PartitionPruningStatistics { } } +#[expect(deprecated)] impl PruningStatistics for PartitionPruningStatistics { fn min_values(&self, column: &Column) -> Option { let index = self.partition_schema.index_of(column.name()).ok()?; @@ -245,7 +281,7 @@ impl PruningStatistics for PartitionPruningStatistics { match acc { None => Some(Some(eq_result)), Some(acc_array) => { - arrow::compute::kernels::boolean::and(&acc_array, &eq_result) + arrow::compute::kernels::boolean::or_kleene(&acc_array, &eq_result) .map(Some) .ok() } @@ -409,10 +445,15 @@ impl PruningStatistics for PrunableStatistics { /// the first one is returned without any regard for completeness or accuracy. /// That is: if the first statistics has information for a column, even if it is incomplete, /// that is returned even if a later statistics has more complete information. 
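The switch from `and` to `or_kleene` in `PartitionPruningStatistics::contained` above changes how per-value equality results are folded: a container is contained if its single partition value matches *any* candidate value, and a NULL comparison must stay "unknown" rather than suppressing a known `true`. A small illustration of the two kernels (arrow crate only; the interpretation in the comments is mine):

```rust
use arrow::array::BooleanArray;
use arrow::compute::kernels::boolean::{and, or_kleene};

fn main() {
    // Per-container "does the partition value equal candidate X?" results for
    // two candidates; NULL means "unknown" (e.g. a NULL literal in the set).
    let eq_x = BooleanArray::from(vec![Some(true), Some(false), None]);
    let eq_y = BooleanArray::from(vec![Some(false), Some(true), None]);

    // Kleene OR: `true` wins over unknown, so a container matching *any*
    // candidate stays true; only an all-unknown container stays NULL.
    let any_match = or_kleene(&eq_x, &eq_y).unwrap();
    assert_eq!(any_match, BooleanArray::from(vec![Some(true), Some(true), None]));

    // The previous `and` accumulation required *every* candidate to match,
    // which is the wrong question for set containment.
    let all_match = and(&eq_x, &eq_y).unwrap();
    assert_eq!(all_match, BooleanArray::from(vec![Some(false), Some(false), None]));
}
```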
+#[deprecated( + since = "52.0.0", + note = "This struct is no longer used internally. It may be removed in 58.0.0 or 6 months after 52.0.0 is released, whichever comes first. Please open an issue if you have a use case for it." +)] pub struct CompositePruningStatistics { pub statistics: Vec>, } +#[expect(deprecated)] impl CompositePruningStatistics { /// Create a new instance of [`CompositePruningStatistics`] from /// a vector of [`PruningStatistics`]. @@ -427,6 +468,7 @@ impl CompositePruningStatistics { } } +#[expect(deprecated)] impl PruningStatistics for CompositePruningStatistics { fn min_values(&self, column: &Column) -> Option { for stats in &self.statistics { @@ -483,18 +525,25 @@ impl PruningStatistics for CompositePruningStatistics { } #[cfg(test)] +#[expect(deprecated)] mod tests { use crate::{ - cast::{as_int32_array, as_uint64_array}, ColumnStatistics, + cast::{as_int32_array, as_uint64_array}, }; use super::*; use arrow::datatypes::{DataType, Field}; use std::sync::Arc; - #[test] - fn test_partition_pruning_statistics() { + /// return a PartitionPruningStatistics for two columns 'a' and 'b' + /// and the following stats + /// + /// | a | b | + /// | - | - | + /// | 1 | 2 | + /// | 3 | 4 | + fn partition_pruning_statistics_setup() -> PartitionPruningStatistics { let partition_values = vec![ vec![ScalarValue::from(1i32), ScalarValue::from(2i32)], vec![ScalarValue::from(3i32), ScalarValue::from(4i32)], @@ -503,9 +552,12 @@ mod tests { Arc::new(Field::new("a", DataType::Int32, false)), Arc::new(Field::new("b", DataType::Int32, false)), ]; - let partition_stats = - PartitionPruningStatistics::try_new(partition_values, partition_fields) - .unwrap(); + PartitionPruningStatistics::try_new(partition_values, partition_fields).unwrap() + } + + #[test] + fn test_partition_pruning_statistics() { + let partition_stats = partition_pruning_statistics_setup(); let column_a = Column::new_unqualified("a"); let column_b = Column::new_unqualified("b"); @@ -560,6 +612,85 @@ mod tests { assert_eq!(partition_stats.num_containers(), 2); } + #[test] + fn test_partition_pruning_statistics_multiple_positive_values() { + let partition_stats = partition_pruning_statistics_setup(); + + let column_a = Column::new_unqualified("a"); + + // The two containers have `a` values 1 and 3, so they both only contain values from 1 and 3 + let values = HashSet::from([ScalarValue::from(1i32), ScalarValue::from(3i32)]); + let contained_a = partition_stats.contained(&column_a, &values).unwrap(); + let expected_contained_a = BooleanArray::from(vec![true, true]); + assert_eq!(contained_a, expected_contained_a); + } + + #[test] + fn test_partition_pruning_statistics_multiple_negative_values() { + let partition_stats = partition_pruning_statistics_setup(); + + let column_a = Column::new_unqualified("a"); + + // The two containers have `a` values 1 and 3, + // so the first contains ONLY values from 1,2 + // but the second does not + let values = HashSet::from([ScalarValue::from(1i32), ScalarValue::from(2i32)]); + let contained_a = partition_stats.contained(&column_a, &values).unwrap(); + let expected_contained_a = BooleanArray::from(vec![true, false]); + assert_eq!(contained_a, expected_contained_a); + } + + #[test] + fn test_partition_pruning_statistics_null_in_values() { + let partition_values = vec![ + vec![ + ScalarValue::from(1i32), + ScalarValue::from(2i32), + ScalarValue::from(3i32), + ], + vec![ + ScalarValue::from(4i32), + ScalarValue::from(5i32), + ScalarValue::from(6i32), + ], + ]; + let partition_fields = vec![ + 
Arc::new(Field::new("a", DataType::Int32, false)), + Arc::new(Field::new("b", DataType::Int32, false)), + Arc::new(Field::new("c", DataType::Int32, false)), + ]; + let partition_stats = + PartitionPruningStatistics::try_new(partition_values, partition_fields) + .unwrap(); + + let column_a = Column::new_unqualified("a"); + let column_b = Column::new_unqualified("b"); + let column_c = Column::new_unqualified("c"); + + let values_a = HashSet::from([ScalarValue::from(1i32), ScalarValue::Int32(None)]); + let contained_a = partition_stats.contained(&column_a, &values_a).unwrap(); + let mut builder = BooleanArray::builder(2); + builder.append_value(true); + builder.append_null(); + let expected_contained_a = builder.finish(); + assert_eq!(contained_a, expected_contained_a); + + // First match creates a NULL boolean array + // The accumulator should update the value to true for the second value + let values_b = HashSet::from([ScalarValue::Int32(None), ScalarValue::from(5i32)]); + let contained_b = partition_stats.contained(&column_b, &values_b).unwrap(); + let mut builder = BooleanArray::builder(2); + builder.append_null(); + builder.append_value(true); + let expected_contained_b = builder.finish(); + assert_eq!(contained_b, expected_contained_b); + + // All matches are null, contained should return None + let values_c = HashSet::from([ScalarValue::Int32(None)]); + let contained_c = partition_stats.contained(&column_c, &values_c); + assert!(contained_c.is_none()); + } + #[test] fn test_partition_pruning_statistics_empty() { let partition_values = vec![]; diff --git a/datafusion/common/src/pyarrow.rs b/datafusion/common/src/pyarrow.rs deleted file mode 100644 index 18c6739735ff7..0000000000000 --- a/datafusion/common/src/pyarrow.rs +++ /dev/null @@ -1,169 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
Conversions between PyArrow and DataFusion types - -use arrow::array::{Array, ArrayData}; -use arrow::pyarrow::{FromPyArrow, ToPyArrow}; -use pyo3::exceptions::PyException; -use pyo3::prelude::PyErr; -use pyo3::types::{PyAnyMethods, PyList}; -use pyo3::{Bound, FromPyObject, IntoPyObject, PyAny, PyResult, Python}; - -use crate::{DataFusionError, ScalarValue}; - -impl From for PyErr { - fn from(err: DataFusionError) -> PyErr { - PyException::new_err(err.to_string()) - } -} - -impl FromPyArrow for ScalarValue { - fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult { - let py = value.py(); - let typ = value.getattr("type")?; - let val = value.call_method0("as_py")?; - - // construct pyarrow array from the python value and pyarrow type - let factory = py.import("pyarrow")?.getattr("array")?; - let args = PyList::new(py, [val])?; - let array = factory.call1((args, typ))?; - - // convert the pyarrow array to rust array using C data interface - let array = arrow::array::make_array(ArrayData::from_pyarrow_bound(&array)?); - let scalar = ScalarValue::try_from_array(&array, 0)?; - - Ok(scalar) - } -} - -impl ToPyArrow for ScalarValue { - fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult> { - let array = self.to_array()?; - // convert to pyarrow array using C data interface - let pyarray = array.to_data().to_pyarrow(py)?; - let pyscalar = pyarray.call_method1("__getitem__", (0,))?; - - Ok(pyscalar) - } -} - -impl<'source> FromPyObject<'source> for ScalarValue { - fn extract_bound(value: &Bound<'source, PyAny>) -> PyResult { - Self::from_pyarrow_bound(value) - } -} - -impl<'source> IntoPyObject<'source> for ScalarValue { - type Target = PyAny; - - type Output = Bound<'source, Self::Target>; - - type Error = PyErr; - - fn into_pyobject(self, py: Python<'source>) -> Result { - let array = self.to_array()?; - // convert to pyarrow array using C data interface - let pyarray = array.to_data().to_pyarrow(py)?; - pyarray.call_method1("__getitem__", (0,)) - } -} - -#[cfg(test)] -mod tests { - use pyo3::ffi::c_str; - use pyo3::py_run; - use pyo3::types::PyDict; - use pyo3::Python; - - use super::*; - - fn init_python() { - Python::initialize(); - Python::attach(|py| { - if py.run(c_str!("import pyarrow"), None, None).is_err() { - let locals = PyDict::new(py); - py.run( - c_str!( - "import sys; executable = sys.executable; python_path = sys.path" - ), - None, - Some(&locals), - ) - .expect("Couldn't get python info"); - let executable = locals.get_item("executable").unwrap(); - let executable: String = executable.extract().unwrap(); - - let python_path = locals.get_item("python_path").unwrap(); - let python_path: Vec = python_path.extract().unwrap(); - - panic!("pyarrow not found\nExecutable: {executable}\nPython path: {python_path:?}\n\ - HINT: try `pip install pyarrow`\n\ - NOTE: On Mac OS, you must compile against a Framework Python \ - (default in python.org installers and brew, but not pyenv)\n\ - NOTE: On Mac OS, PYO3 might point to incorrect Python library \ - path when using virtual environments. 
Try \ - `export PYTHONPATH=$(python -c \"import sys; print(sys.path[-1])\")`\n") - } - }) - } - - #[test] - fn test_roundtrip() { - init_python(); - - let example_scalars = [ - ScalarValue::Boolean(Some(true)), - ScalarValue::Int32(Some(23)), - ScalarValue::Float64(Some(12.34)), - ScalarValue::from("Hello!"), - ScalarValue::Date32(Some(1234)), - ]; - - Python::attach(|py| { - for scalar in example_scalars.iter() { - let result = - ScalarValue::from_pyarrow_bound(&scalar.to_pyarrow(py).unwrap()) - .unwrap(); - assert_eq!(scalar, &result); - } - }); - } - - #[test] - fn test_py_scalar() -> PyResult<()> { - init_python(); - - Python::attach(|py| -> PyResult<()> { - let scalar_float = ScalarValue::Float64(Some(12.34)); - let py_float = scalar_float - .into_pyobject(py)? - .call_method0("as_py") - .unwrap(); - py_run!(py, py_float, "assert py_float == 12.34"); - - let scalar_string = ScalarValue::Utf8(Some("Hello!".to_string())); - let py_string = scalar_string - .into_pyobject(py)? - .call_method0("as_py") - .unwrap(); - py_run!(py, py_string, "assert py_string == 'Hello!'"); - - Ok(()) - }) - } -} diff --git a/datafusion/common/src/rounding.rs b/datafusion/common/src/rounding.rs index 95eefd3235b5f..1796143d7cf1a 100644 --- a/datafusion/common/src/rounding.rs +++ b/datafusion/common/src/rounding.rs @@ -47,7 +47,7 @@ extern crate libc; any(target_arch = "x86_64", target_arch = "aarch64"), not(target_os = "windows") ))] -extern "C" { +unsafe extern "C" { fn fesetround(round: i32); fn fegetround() -> i32; } diff --git a/datafusion/common/src/scalar/cache.rs b/datafusion/common/src/scalar/cache.rs index f1476a518774b..5b1ad4e4ede01 100644 --- a/datafusion/common/src/scalar/cache.rs +++ b/datafusion/common/src/scalar/cache.rs @@ -20,10 +20,10 @@ use std::iter::repeat_n; use std::sync::{Arc, LazyLock, Mutex}; -use arrow::array::{new_null_array, Array, ArrayRef, PrimitiveArray}; +use arrow::array::{Array, ArrayRef, PrimitiveArray, new_null_array}; use arrow::datatypes::{ - ArrowDictionaryKeyType, DataType, Int16Type, Int32Type, Int64Type, Int8Type, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ArrowDictionaryKeyType, DataType, Int8Type, Int16Type, Int32Type, Int64Type, + UInt8Type, UInt16Type, UInt32Type, UInt64Type, }; /// Maximum number of rows to cache to be conservative on memory usage diff --git a/datafusion/common/src/scalar/consts.rs b/datafusion/common/src/scalar/consts.rs index 8cb446b1c9211..599c2523cd2c7 100644 --- a/datafusion/common/src/scalar/consts.rs +++ b/datafusion/common/src/scalar/consts.rs @@ -17,24 +17,36 @@ // Constants defined for scalar construction. 
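The consts.rs hunk just below adds f16 bounds for π and π/2 as raw bit patterns rather than expressions, since `half::f16` has no const `next_up`. A quick, hedged check of what those bit values decode to (assuming only the `half` crate, which this crate already depends on):

```rust
use half::f16;

fn main() {
    // f16 has 10 mantissa bits, so π rounds to 0x4248 (= 3.140625); the next
    // representable value above π is therefore 0x4249.
    assert!(f16::from_bits(0x4249).to_f64() > std::f64::consts::PI);
    assert!(f16::from_bits(0x4248).to_f64() < std::f64::consts::PI);

    // Same idea for π/2: 0x3E48 is just below, 0x3E49 just above.
    assert!(f16::from_bits(0x3E49).to_f64() > std::f64::consts::FRAC_PI_2);
    assert!(f16::from_bits(0x3E48).to_f64() < std::f64::consts::FRAC_PI_2);

    // The negative bounds only flip the sign bit (0xC249 = -3.142578125 < -π).
    assert!(f16::from_bits(0xC249).to_f64() < -std::f64::consts::PI);
}
```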
+// Next F16 value above π (upper bound) +pub(super) const PI_UPPER_F16: half::f16 = half::f16::from_bits(0x4249); + // Next f32 value above π (upper bound) pub(super) const PI_UPPER_F32: f32 = std::f32::consts::PI.next_up(); // Next f64 value above π (upper bound) pub(super) const PI_UPPER_F64: f64 = std::f64::consts::PI.next_up(); +// Next f16 value below -π (lower bound) +pub(super) const NEGATIVE_PI_LOWER_F16: half::f16 = half::f16::from_bits(0xC249); + // Next f32 value below -π (lower bound) pub(super) const NEGATIVE_PI_LOWER_F32: f32 = (-std::f32::consts::PI).next_down(); // Next f64 value below -π (lower bound) pub(super) const NEGATIVE_PI_LOWER_F64: f64 = (-std::f64::consts::PI).next_down(); +// Next f16 value above π/2 (upper bound) +pub(super) const FRAC_PI_2_UPPER_F16: half::f16 = half::f16::from_bits(0x3E49); + // Next f32 value above π/2 (upper bound) pub(super) const FRAC_PI_2_UPPER_F32: f32 = std::f32::consts::FRAC_PI_2.next_up(); // Next f64 value above π/2 (upper bound) pub(super) const FRAC_PI_2_UPPER_F64: f64 = std::f64::consts::FRAC_PI_2.next_up(); +// Next f32 value below -π/2 (lower bound) +pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F16: half::f16 = half::f16::from_bits(0xBE49); + // Next f32 value below -π/2 (lower bound) pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F32: f32 = (-std::f32::consts::FRAC_PI_2).next_down(); diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index fadd2e41eaba4..e4e048ad3c0d8 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -33,64 +33,162 @@ use std::mem::{size_of, size_of_val}; use std::str::FromStr; use std::sync::Arc; +use crate::assert_or_internal_err; use crate::cast::{ as_binary_array, as_binary_view_array, as_boolean_array, as_date32_array, - as_date64_array, as_decimal128_array, as_decimal256_array, as_decimal32_array, - as_decimal64_array, as_dictionary_array, as_duration_microsecond_array, + as_date64_array, as_decimal32_array, as_decimal64_array, as_decimal128_array, + as_decimal256_array, as_dictionary_array, as_duration_microsecond_array, as_duration_millisecond_array, as_duration_nanosecond_array, as_duration_second_array, as_fixed_size_binary_array, as_fixed_size_list_array, - as_float16_array, as_float32_array, as_float64_array, as_int16_array, as_int32_array, - as_int64_array, as_int8_array, as_interval_dt_array, as_interval_mdn_array, + as_float16_array, as_float32_array, as_float64_array, as_int8_array, as_int16_array, + as_int32_array, as_int64_array, as_interval_dt_array, as_interval_mdn_array, as_interval_ym_array, as_large_binary_array, as_large_list_array, as_large_string_array, as_string_array, as_string_view_array, as_time32_millisecond_array, as_time32_second_array, as_time64_microsecond_array, as_time64_nanosecond_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, - as_timestamp_second_array, as_uint16_array, as_uint32_array, as_uint64_array, - as_uint8_array, as_union_array, + as_timestamp_second_array, as_uint8_array, as_uint16_array, as_uint32_array, + as_uint64_array, as_union_array, }; -use crate::error::{DataFusionError, Result, _exec_err, _internal_err, _not_impl_err}; +use crate::error::{_exec_err, _internal_err, _not_impl_err, DataFusionError, Result}; use crate::format::DEFAULT_CAST_OPTIONS; use crate::hash_utils::create_hashes; use crate::utils::SingleRowListArrayBuilder; use crate::{_internal_datafusion_err, arrow_datafusion_err}; use arrow::array::{ - new_empty_array, 
new_null_array, Array, ArrayData, ArrayRef, ArrowNativeTypeOp, - ArrowPrimitiveType, AsArray, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, - Date64Array, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array, + Array, ArrayData, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray, + BinaryArray, BinaryViewArray, BinaryViewBuilder, BooleanArray, Date32Array, + Date64Array, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, DictionaryArray, DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, FixedSizeBinaryArray, - FixedSizeListArray, Float16Array, Float32Array, Float64Array, GenericListArray, - Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, - IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, LargeListArray, - LargeStringArray, ListArray, MapArray, MutableArrayData, OffsetSizeTrait, - PrimitiveArray, Scalar, StringArray, StringViewArray, StructArray, - Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, - Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, UnionArray, + FixedSizeBinaryBuilder, FixedSizeListArray, Float16Array, Float32Array, Float64Array, + GenericListArray, Int8Array, Int16Array, Int32Array, Int64Array, + IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, + LargeBinaryArray, LargeListArray, LargeStringArray, ListArray, MapArray, + MutableArrayData, OffsetSizeTrait, PrimitiveArray, Scalar, StringArray, + StringViewArray, StringViewBuilder, StructArray, Time32MillisecondArray, + Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, UInt8Array, UInt16Array, UInt32Array, UInt64Array, UnionArray, + new_empty_array, new_null_array, }; use arrow::buffer::{BooleanBuffer, ScalarBuffer}; -use arrow::compute::kernels::cast::{cast_with_options, CastOptions}; +use arrow::compute::kernels::cast::{CastOptions, cast_with_options}; use arrow::compute::kernels::numeric::{ add, add_wrapping, div, mul, mul_wrapping, rem, sub, sub_wrapping, }; use arrow::datatypes::{ - i256, validate_decimal_precision_and_scale, ArrowDictionaryKeyType, ArrowNativeType, - ArrowTimestampType, DataType, Date32Type, Decimal128Type, Decimal256Type, - Decimal32Type, Decimal64Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, - Int8Type, IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano, - IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, TimeUnit, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, UnionFields, - UnionMode, DECIMAL128_MAX_PRECISION, + ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType, Date32Type, + Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType, Field, + Float32Type, Int8Type, Int16Type, Int32Type, Int64Type, IntervalDayTime, + IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalUnit, + IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, + UInt64Type, UnionFields, UnionMode, i256, validate_decimal_precision_and_scale, }; -use arrow::util::display::{array_value_to_string, ArrayFormatter, 
FormatOptions}; +use arrow::util::display::{ArrayFormatter, FormatOptions, array_value_to_string}; use cache::{get_or_create_cached_key_array, get_or_create_cached_null_array}; use chrono::{Duration, NaiveDate}; use half::f16; pub use struct_builder::ScalarStructBuilder; +const SECONDS_PER_DAY: i64 = 86_400; +const MILLIS_PER_DAY: i64 = SECONDS_PER_DAY * 1_000; +const MICROS_PER_DAY: i64 = MILLIS_PER_DAY * 1_000; +const NANOS_PER_DAY: i64 = MICROS_PER_DAY * 1_000; +const MICROS_PER_MILLISECOND: i64 = 1_000; +const NANOS_PER_MILLISECOND: i64 = 1_000_000; + +/// Returns the multiplier that converts the input date representation into the +/// desired timestamp unit, if the conversion requires a multiplication that can +/// overflow an `i64`. +pub fn date_to_timestamp_multiplier( + source_type: &DataType, + target_type: &DataType, +) -> Option { + let DataType::Timestamp(target_unit, _) = target_type else { + return None; + }; + + // Only `Timestamp` target types have a time unit; otherwise no + // multiplier applies (handled above). The function returns `Some(m)` + // when converting the `source_type` to `target_type` requires a + // multiplication that could overflow `i64`. It returns `None` when + // the conversion is a division or otherwise doesn't require a + // multiplication (e.g. Date64 -> Second). + match source_type { + // Date32 stores days since epoch. Converting to any timestamp + // unit requires multiplying by the per-day factor (seconds, + // milliseconds, microseconds, nanoseconds). + DataType::Date32 => Some(match target_unit { + TimeUnit::Second => SECONDS_PER_DAY, + TimeUnit::Millisecond => MILLIS_PER_DAY, + TimeUnit::Microsecond => MICROS_PER_DAY, + TimeUnit::Nanosecond => NANOS_PER_DAY, + }), + + // Date64 stores milliseconds since epoch. Converting to + // seconds is a division (no multiplication), so return `None`. + // Converting to milliseconds is 1:1 (multiplier 1). Converting + // to micro/nano requires multiplying by 1_000 / 1_000_000. + DataType::Date64 => match target_unit { + TimeUnit::Second => None, + // Converting Date64 (ms since epoch) to millisecond timestamps + // is an identity conversion and does not require multiplication. + // Returning `None` indicates no multiplication-based overflow + // check is necessary. + TimeUnit::Millisecond => None, + TimeUnit::Microsecond => Some(MICROS_PER_MILLISECOND), + TimeUnit::Nanosecond => Some(NANOS_PER_MILLISECOND), + }, + + _ => None, + } +} + +/// Ensures the provided value can be represented as a timestamp with the given +/// multiplier. Returns an [`DataFusionError::Execution`] when the converted +/// value would overflow the timestamp range. +pub fn ensure_timestamp_in_bounds( + value: i64, + multiplier: i64, + source_type: &DataType, + target_type: &DataType, +) -> Result<()> { + if multiplier <= 1 { + return Ok(()); + } + + if value.checked_mul(multiplier).is_none() { + let target = format_timestamp_type_for_error(target_type); + _exec_err!( + "Cannot cast {} value {} to {}: converted value exceeds the representable i64 range", + source_type, + value, + target + ) + } else { + Ok(()) + } +} + +/// Format a `DataType::Timestamp` into a short, stable string used in +/// user-facing error messages. 
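The two helpers above exist because widening a date to a fine-grained timestamp is a multiplication that can overflow `i64`. A standalone sketch of the arithmetic behind the `checked_mul` guard (the example day count and the ~year-2262 observation are mine, not from the source):

```rust
const NANOS_PER_DAY: i64 = 86_400_000_000_000;
const MICROS_PER_DAY: i64 = 86_400_000_000;

fn main() {
    // A Date32 value is a day count since the UNIX epoch. Day 200_000 is
    // roughly the year 2517: fine as microseconds, out of range as nanoseconds.
    let days: i64 = 200_000;

    // This is the check `ensure_timestamp_in_bounds` performs: a plain
    // `checked_mul` against the per-unit multiplier.
    assert!(days.checked_mul(MICROS_PER_DAY).is_some()); // representable
    assert!(days.checked_mul(NANOS_PER_DAY).is_none());  // would overflow i64

    // i64::MAX nanoseconds is only ~106_751 days (around the year 2262),
    // which is why the nanosecond target in particular needs the guard.
    assert_eq!(i64::MAX / NANOS_PER_DAY, 106_751);
}
```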
+pub(crate) fn format_timestamp_type_for_error(target_type: &DataType) -> String { + match target_type { + DataType::Timestamp(unit, _) => { + let s = match unit { + TimeUnit::Second => "s", + TimeUnit::Millisecond => "ms", + TimeUnit::Microsecond => "us", + TimeUnit::Nanosecond => "ns", + }; + format!("Timestamp({s})") + } + other => format!("{other}"), + } +} + /// A dynamically typed, nullable single value. /// /// While an arrow [`Array`]) stores one or more values of the same type, in a @@ -622,11 +720,7 @@ impl PartialOrd for ScalarValue { (Union(_, _, _), _) => None, (Dictionary(k1, v1), Dictionary(k2, v2)) => { // Don't compare if the key types don't match (it is effectively a different datatype) - if k1 == k2 { - v1.partial_cmp(v2) - } else { - None - } + if k1 == k2 { v1.partial_cmp(v2) } else { None } } (Dictionary(_, _), _) => None, (Null, Null) => Some(Ordering::Equal), @@ -646,7 +740,9 @@ fn first_array_for_list(arr: &dyn Array) -> ArrayRef { } else if let Some(arr) = arr.as_fixed_size_list_opt() { arr.value(0) } else { - unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen") + unreachable!( + "Since only List / LargeList / FixedSizeList are supported, this should never happen" + ) } } @@ -1055,13 +1151,8 @@ impl ScalarValue { /// Create a decimal Scalar from value/precision and scale. pub fn try_new_decimal128(value: i128, precision: u8, scale: i8) -> Result { - // make sure the precision and scale is valid - if precision <= DECIMAL128_MAX_PRECISION && scale.unsigned_abs() <= precision { - return Ok(ScalarValue::Decimal128(Some(value), precision, scale)); - } - _internal_err!( - "Can not new a decimal type ScalarValue for precision {precision} and scale {scale}" - ) + Self::validate_decimal_or_internal_err::(precision, scale)?; + Ok(ScalarValue::Decimal128(Some(value), precision, scale)) } /// Create a Null instance of ScalarValue for this datatype @@ -1153,7 +1244,7 @@ impl ScalarValue { index_type.clone(), Box::new(value_type.as_ref().try_into()?), ), - // `ScalaValue::List` contains single element `ListArray`. + // `ScalarValue::List` contains single element `ListArray`. DataType::List(field_ref) => ScalarValue::List(Arc::new( GenericListArray::new_null(Arc::clone(field_ref), 1), )), @@ -1161,7 +1252,7 @@ impl ScalarValue { DataType::LargeList(field_ref) => ScalarValue::LargeList(Arc::new( GenericListArray::new_null(Arc::clone(field_ref), 1), )), - // `ScalaValue::FixedSizeList` contains single element `FixedSizeList`. + // `ScalarValue::FixedSizeList` contains single element `FixedSizeList`. 
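`try_new_decimal128` above now funnels through the shared `validate_decimal_or_internal_err` helper instead of open-coding the precision/scale check. A hedged sketch of the kind of validation involved, using arrow's public `validate_decimal_precision_and_scale` (which the helper presumably wraps; the specific limits shown are Decimal128's):

```rust
use arrow::datatypes::{validate_decimal_precision_and_scale, Decimal128Type};

fn main() {
    // Valid: precision within Decimal128's 38-digit limit, scale <= precision.
    assert!(validate_decimal_precision_and_scale::<Decimal128Type>(10, 2).is_ok());

    // Invalid: precision above 38 digits.
    assert!(validate_decimal_precision_and_scale::<Decimal128Type>(39, 2).is_err());

    // Invalid: positive scale larger than the precision.
    assert!(validate_decimal_precision_and_scale::<Decimal128Type>(5, 7).is_err());
}
```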
DataType::FixedSizeList(field_ref, fixed_length) => { ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::new_null( Arc::clone(field_ref), @@ -1241,6 +1332,7 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing PI pub fn new_pi(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(f16::PI)), DataType::Float32 => Ok(ScalarValue::from(std::f32::consts::PI)), DataType::Float64 => Ok(ScalarValue::from(std::f64::consts::PI)), _ => _internal_err!("PI is not supported for data type: {}", datatype), @@ -1250,6 +1342,7 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing PI's upper bound pub fn new_pi_upper(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::Float16(Some(consts::PI_UPPER_F16))), DataType::Float32 => Ok(ScalarValue::from(consts::PI_UPPER_F32)), DataType::Float64 => Ok(ScalarValue::from(consts::PI_UPPER_F64)), _ => { @@ -1261,6 +1354,9 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing -PI's lower bound pub fn new_negative_pi_lower(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => { + Ok(ScalarValue::Float16(Some(consts::NEGATIVE_PI_LOWER_F16))) + } DataType::Float32 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F32)), DataType::Float64 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F64)), _ => { @@ -1272,6 +1368,9 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing FRAC_PI_2's upper bound pub fn new_frac_pi_2_upper(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => { + Ok(ScalarValue::Float16(Some(consts::FRAC_PI_2_UPPER_F16))) + } DataType::Float32 => Ok(ScalarValue::from(consts::FRAC_PI_2_UPPER_F32)), DataType::Float64 => Ok(ScalarValue::from(consts::FRAC_PI_2_UPPER_F64)), _ => { @@ -1283,6 +1382,9 @@ impl ScalarValue { // Returns a [`ScalarValue`] representing FRAC_PI_2's lower bound pub fn new_neg_frac_pi_2_lower(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::Float16(Some( + consts::NEGATIVE_FRAC_PI_2_LOWER_F16, + ))), DataType::Float32 => { Ok(ScalarValue::from(consts::NEGATIVE_FRAC_PI_2_LOWER_F32)) } @@ -1298,6 +1400,7 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing -PI pub fn new_negative_pi(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(-f16::PI)), DataType::Float32 => Ok(ScalarValue::from(-std::f32::consts::PI)), DataType::Float64 => Ok(ScalarValue::from(-std::f64::consts::PI)), _ => _internal_err!("-PI is not supported for data type: {}", datatype), @@ -1307,6 +1410,7 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing PI/2 pub fn new_frac_pi_2(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(f16::FRAC_PI_2)), DataType::Float32 => Ok(ScalarValue::from(std::f32::consts::FRAC_PI_2)), DataType::Float64 => Ok(ScalarValue::from(std::f64::consts::FRAC_PI_2)), _ => _internal_err!("PI/2 is not supported for data type: {}", datatype), @@ -1316,6 +1420,7 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing -PI/2 pub fn new_neg_frac_pi_2(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(-f16::FRAC_PI_2)), DataType::Float32 => Ok(ScalarValue::from(-std::f32::consts::FRAC_PI_2)), DataType::Float64 => Ok(ScalarValue::from(-std::f64::consts::FRAC_PI_2)), _ => _internal_err!("-PI/2 is not supported for data type: {}", datatype), @@ -1325,6 +1430,7 @@ impl ScalarValue { /// Returns a 
[`ScalarValue`] representing infinity pub fn new_infinity(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(f16::INFINITY)), DataType::Float32 => Ok(ScalarValue::from(f32::INFINITY)), DataType::Float64 => Ok(ScalarValue::from(f64::INFINITY)), _ => { @@ -1336,6 +1442,7 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing negative infinity pub fn new_neg_infinity(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(f16::NEG_INFINITY)), DataType::Float32 => Ok(ScalarValue::from(f32::NEG_INFINITY)), DataType::Float64 => Ok(ScalarValue::from(f64::NEG_INFINITY)), _ => { @@ -1359,7 +1466,7 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(Some(0)), DataType::UInt32 => ScalarValue::UInt32(Some(0)), DataType::UInt64 => ScalarValue::UInt64(Some(0)), - DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))), + DataType::Float16 => ScalarValue::Float16(Some(f16::ZERO)), DataType::Float32 => ScalarValue::Float32(Some(0.0)), DataType::Float64 => ScalarValue::Float64(Some(0.0)), DataType::Decimal32(precision, scale) => { @@ -1574,16 +1681,14 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(Some(1)), DataType::UInt32 => ScalarValue::UInt32(Some(1)), DataType::UInt64 => ScalarValue::UInt64(Some(1)), - DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(1.0))), + DataType::Float16 => ScalarValue::Float16(Some(f16::ONE)), DataType::Float32 => ScalarValue::Float32(Some(1.0)), DataType::Float64 => ScalarValue::Float64(Some(1.0)), DataType::Decimal32(precision, scale) => { - validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, )?; - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match 10_i32.checked_pow(*scale as u32) { Some(value) => { ScalarValue::Decimal32(Some(value), *precision, *scale) @@ -1592,12 +1697,10 @@ impl ScalarValue { } } DataType::Decimal64(precision, scale) => { - validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, )?; - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i64::from(10).checked_pow(*scale as u32) { Some(value) => { ScalarValue::Decimal64(Some(value), *precision, *scale) @@ -1606,12 +1709,10 @@ impl ScalarValue { } } DataType::Decimal128(precision, scale) => { - validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, )?; - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i128::from(10).checked_pow(*scale as u32) { Some(value) => { ScalarValue::Decimal128(Some(value), *precision, *scale) @@ -1620,12 +1721,10 @@ impl ScalarValue { } } DataType::Decimal256(precision, scale) => { - validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, )?; - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i256::from(10).checked_pow(*scale as u32) { Some(value) => { ScalarValue::Decimal256(Some(value), *precision, *scale) @@ -1648,16 +1747,14 @@ impl ScalarValue { DataType::Int16 | DataType::UInt16 => ScalarValue::Int16(Some(-1)), 
DataType::Int32 | DataType::UInt32 => ScalarValue::Int32(Some(-1)), DataType::Int64 | DataType::UInt64 => ScalarValue::Int64(Some(-1)), - DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(-1.0))), + DataType::Float16 => ScalarValue::Float16(Some(f16::NEG_ONE)), DataType::Float32 => ScalarValue::Float32(Some(-1.0)), DataType::Float64 => ScalarValue::Float64(Some(-1.0)), DataType::Decimal32(precision, scale) => { - validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, )?; - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match 10_i32.checked_pow(*scale as u32) { Some(value) => { ScalarValue::Decimal32(Some(-value), *precision, *scale) @@ -1666,12 +1763,10 @@ impl ScalarValue { } } DataType::Decimal64(precision, scale) => { - validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, )?; - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i64::from(10).checked_pow(*scale as u32) { Some(value) => { ScalarValue::Decimal64(Some(-value), *precision, *scale) @@ -1680,12 +1775,10 @@ impl ScalarValue { } } DataType::Decimal128(precision, scale) => { - validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, )?; - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i128::from(10).checked_pow(*scale as u32) { Some(value) => { ScalarValue::Decimal128(Some(-value), *precision, *scale) @@ -1694,12 +1787,10 @@ impl ScalarValue { } } DataType::Decimal256(precision, scale) => { - validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, )?; - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i256::from(10).checked_pow(*scale as u32) { Some(value) => { ScalarValue::Decimal256(Some(-value), *precision, *scale) @@ -1729,14 +1820,10 @@ impl ScalarValue { DataType::Float32 => ScalarValue::Float32(Some(10.0)), DataType::Float64 => ScalarValue::Float64(Some(10.0)), DataType::Decimal32(precision, scale) => { - if let Err(err) = validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, - ) { - return _internal_err!("Invalid precision and scale {err}"); - } - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match 10_i32.checked_pow((*scale + 1) as u32) { Some(value) => { ScalarValue::Decimal32(Some(value), *precision, *scale) @@ -1745,14 +1832,10 @@ impl ScalarValue { } } DataType::Decimal64(precision, scale) => { - if let Err(err) = validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, - ) { - return _internal_err!("Invalid precision and scale {err}"); - } - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i64::from(10).checked_pow((*scale + 1) as u32) { Some(value) => { ScalarValue::Decimal64(Some(value), *precision, *scale) @@ -1761,14 +1844,10 @@ impl ScalarValue { } } 
DataType::Decimal128(precision, scale) => { - if let Err(err) = validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, - ) { - return _internal_err!("Invalid precision and scale {err}"); - } - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i128::from(10).checked_pow((*scale + 1) as u32) { Some(value) => { ScalarValue::Decimal128(Some(value), *precision, *scale) @@ -1777,14 +1856,10 @@ impl ScalarValue { } } DataType::Decimal256(precision, scale) => { - if let Err(err) = validate_decimal_precision_and_scale::( + Self::validate_decimal_or_internal_err::( *precision, *scale, - ) { - return _internal_err!("Invalid precision and scale {err}"); - } - if *scale < 0 { - return _internal_err!("Negative scale is not supported"); - } + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); match i256::from(10).checked_pow((*scale + 1) as u32) { Some(value) => { ScalarValue::Decimal256(Some(value), *precision, *scale) @@ -1899,9 +1974,7 @@ impl ScalarValue { | ScalarValue::Float16(None) | ScalarValue::Float32(None) | ScalarValue::Float64(None) => Ok(self.clone()), - ScalarValue::Float16(Some(v)) => { - Ok(ScalarValue::Float16(Some(f16::from_f32(-v.to_f32())))) - } + ScalarValue::Float16(Some(v)) => Ok(ScalarValue::Float16(Some(-v))), ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))), ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))), ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(v.neg_checked()?))), @@ -2022,6 +2095,7 @@ impl ScalarValue { let r = add_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?; Self::try_from_array(r.as_ref(), 0) } + /// Checked addition of `ScalarValue` /// /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code @@ -2293,18 +2367,20 @@ impl ScalarValue { macro_rules! build_array_primitive { ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ { - let array = scalars.map(|sv| { - if let ScalarValue::$SCALAR_TY(v) = sv { - Ok(v) - } else { - _exec_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v) = sv { + Ok(v) + } else { + _exec_err!( + "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", - data_type, sv - ) - } - }) - .collect::>()?; + data_type, + sv + ) + } + }) + .collect::>()?; Arc::new(array) } }}; @@ -2313,18 +2389,20 @@ impl ScalarValue { macro_rules! build_array_primitive_tz { ($ARRAY_TY:ident, $SCALAR_TY:ident, $TZ:expr) => {{ { - let array = scalars.map(|sv| { - if let ScalarValue::$SCALAR_TY(v, _) = sv { - Ok(v) - } else { - _exec_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v, _) = sv { + Ok(v) + } else { + _exec_err!( + "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", - data_type, sv - ) - } - }) - .collect::>()?; + data_type, + sv + ) + } + }) + .collect::>()?; Arc::new(array.with_timezone_opt($TZ.clone())) } }}; @@ -2335,18 +2413,20 @@ impl ScalarValue { macro_rules! build_array_string { ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ { - let array = scalars.map(|sv| { - if let ScalarValue::$SCALAR_TY(v) = sv { - Ok(v) - } else { - _exec_err!( - "Inconsistent types in ScalarValue::iter_to_array. 
\ + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v) = sv { + Ok(v) + } else { + _exec_err!( + "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", - data_type, sv - ) - } - }) - .collect::>()?; + data_type, + sv + ) + } + }) + .collect::>()?; Arc::new(array) } }}; @@ -2648,71 +2728,6 @@ impl ScalarValue { Ok(array) } - fn build_decimal32_array( - value: Option, - precision: u8, - scale: i8, - size: usize, - ) -> Result { - Ok(match value { - Some(val) => Decimal32Array::from(vec![val; size]) - .with_precision_and_scale(precision, scale)?, - None => { - let mut builder = Decimal32Array::builder(size) - .with_precision_and_scale(precision, scale)?; - builder.append_nulls(size); - builder.finish() - } - }) - } - - fn build_decimal64_array( - value: Option, - precision: u8, - scale: i8, - size: usize, - ) -> Result { - Ok(match value { - Some(val) => Decimal64Array::from(vec![val; size]) - .with_precision_and_scale(precision, scale)?, - None => { - let mut builder = Decimal64Array::builder(size) - .with_precision_and_scale(precision, scale)?; - builder.append_nulls(size); - builder.finish() - } - }) - } - - fn build_decimal128_array( - value: Option, - precision: u8, - scale: i8, - size: usize, - ) -> Result { - Ok(match value { - Some(val) => Decimal128Array::from(vec![val; size]) - .with_precision_and_scale(precision, scale)?, - None => { - let mut builder = Decimal128Array::builder(size) - .with_precision_and_scale(precision, scale)?; - builder.append_nulls(size); - builder.finish() - } - }) - } - - fn build_decimal256_array( - value: Option, - precision: u8, - scale: i8, - size: usize, - ) -> Result { - Ok(repeat_n(value, size) - .collect::() - .with_precision_and_scale(precision, scale)?) - } - /// Converts `Vec` where each element has type corresponding to /// `data_type`, to a single element [`ListArray`]. 
/// @@ -2868,18 +2883,35 @@ impl ScalarValue { /// - a `Dictionary` that fails be converted to a dictionary array of size pub fn to_array_of_size(&self, size: usize) -> Result { Ok(match self { - ScalarValue::Decimal32(e, precision, scale) => Arc::new( - ScalarValue::build_decimal32_array(*e, *precision, *scale, size)?, + ScalarValue::Decimal32(Some(e), precision, scale) => Arc::new( + Decimal32Array::from_value(*e, size) + .with_precision_and_scale(*precision, *scale)?, ), - ScalarValue::Decimal64(e, precision, scale) => Arc::new( - ScalarValue::build_decimal64_array(*e, *precision, *scale, size)?, + ScalarValue::Decimal32(None, precision, scale) => { + new_null_array(&DataType::Decimal32(*precision, *scale), size) + } + ScalarValue::Decimal64(Some(e), precision, scale) => Arc::new( + Decimal64Array::from_value(*e, size) + .with_precision_and_scale(*precision, *scale)?, ), - ScalarValue::Decimal128(e, precision, scale) => Arc::new( - ScalarValue::build_decimal128_array(*e, *precision, *scale, size)?, + ScalarValue::Decimal64(None, precision, scale) => { + new_null_array(&DataType::Decimal64(*precision, *scale), size) + } + ScalarValue::Decimal128(Some(e), precision, scale) => Arc::new( + Decimal128Array::from_value(*e, size) + .with_precision_and_scale(*precision, *scale)?, ), - ScalarValue::Decimal256(e, precision, scale) => Arc::new( - ScalarValue::build_decimal256_array(*e, *precision, *scale, size)?, + ScalarValue::Decimal128(None, precision, scale) => { + new_null_array(&DataType::Decimal128(*precision, *scale), size) + } + ScalarValue::Decimal256(Some(e), precision, scale) => Arc::new( + Decimal256Array::from_value(*e, size) + .with_precision_and_scale(*precision, *scale)?, ), + ScalarValue::Decimal256(None, precision, scale) => { + new_null_array(&DataType::Decimal256(*precision, *scale), size) + } + ScalarValue::Boolean(e) => match e { None => new_null_array(&DataType::Boolean, size), Some(true) => { @@ -2952,33 +2984,43 @@ impl ScalarValue { ) } ScalarValue::Utf8(e) => match e { - Some(value) => { - Arc::new(StringArray::from_iter_values(repeat_n(value, size))) - } + Some(value) => Arc::new(StringArray::new_repeated(value, size)), None => new_null_array(&DataType::Utf8, size), }, ScalarValue::Utf8View(e) => match e { Some(value) => { - Arc::new(StringViewArray::from_iter_values(repeat_n(value, size))) + let mut builder = + StringViewBuilder::with_capacity(size).with_deduplicate_strings(); + // Replace with upstream arrow-rs code when available: + // https://github.com/apache/arrow-rs/issues/9034 + for _ in 0..size { + builder.append_value(value); + } + let array = builder.finish(); + Arc::new(array) } None => new_null_array(&DataType::Utf8View, size), }, ScalarValue::LargeUtf8(e) => match e { - Some(value) => { - Arc::new(LargeStringArray::from_iter_values(repeat_n(value, size))) - } + Some(value) => Arc::new(LargeStringArray::new_repeated(value, size)), None => new_null_array(&DataType::LargeUtf8, size), }, ScalarValue::Binary(e) => match e { - Some(value) => Arc::new( - repeat_n(Some(value.as_slice()), size).collect::(), - ), + Some(value) => { + Arc::new(BinaryArray::new_repeated(value.as_slice(), size)) + } None => new_null_array(&DataType::Binary, size), }, ScalarValue::BinaryView(e) => match e { - Some(value) => Arc::new( - repeat_n(Some(value.as_slice()), size).collect::(), - ), + Some(value) => { + let mut builder = + BinaryViewBuilder::with_capacity(size).with_deduplicate_strings(); + for _ in 0..size { + builder.append_value(value); + } + let array = builder.finish(); + 
Arc::new(array) + } None => new_null_array(&DataType::BinaryView, size), }, ScalarValue::FixedSizeBinary(s, e) => match e { @@ -2989,12 +3031,19 @@ impl ScalarValue { ) .unwrap(), ), - None => Arc::new(FixedSizeBinaryArray::new_null(*s, size)), + None => { + // TODO: Replace with FixedSizeBinaryArray::new_null once a fix for + // https://github.com/apache/arrow-rs/issues/8900 is in the used arrow-rs + // version. + let mut builder = FixedSizeBinaryBuilder::new(*s); + builder.append_nulls(size); + Arc::new(builder.finish()) + } }, ScalarValue::LargeBinary(e) => match e { - Some(value) => Arc::new( - repeat_n(Some(value.as_slice()), size).collect::(), - ), + Some(value) => { + Arc::new(LargeBinaryArray::new_repeated(value.as_slice(), size)) + } None => new_null_array(&DataType::LargeBinary, size), }, ScalarValue::List(arr) => { @@ -3153,10 +3202,7 @@ impl ScalarValue { .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; Arc::new(ar) } - None => { - let dt = self.data_type(); - new_null_array(&dt, size) - } + None => new_null_array(&DataType::Union(fields.clone(), *mode), size), }, ScalarValue::Dictionary(key_type, v) => { // values array is one element long (the value) @@ -3650,11 +3696,26 @@ impl ScalarValue { target_type: &DataType, cast_options: &CastOptions<'static>, ) -> Result { + let source_type = self.data_type(); + if let Some(multiplier) = date_to_timestamp_multiplier(&source_type, target_type) + && let Some(value) = self.date_scalar_value_as_i64() + { + ensure_timestamp_in_bounds(value, multiplier, &source_type, target_type)?; + } + let scalar_array = self.to_array()?; let cast_arr = cast_with_options(&scalar_array, target_type, cast_options)?; ScalarValue::try_from_array(&cast_arr, 0) } + fn date_scalar_value_as_i64(&self) -> Option { + match self { + ScalarValue::Date32(Some(value)) => Some(i64::from(*value)), + ScalarValue::Date64(Some(value)) => Some(*value), + _ => None, + } + } + fn eq_array_decimal32( array: &ArrayRef, index: usize, @@ -4354,6 +4415,20 @@ impl ScalarValue { _ => None, } } + + /// A thin wrapper on Arrow's validation that throws internal error if validation + /// fails. + fn validate_decimal_or_internal_err( + precision: u8, + scale: i8, + ) -> Result<()> { + validate_decimal_precision_and_scale::(precision, scale).map_err(|err| { + _internal_datafusion_err!( + "Decimal precision/scale invariant violated \ + (precision={precision}, scale={scale}): {err}" + ) + }) + } } /// Compacts the data of an `ArrayData` into a new `ArrayData`. 
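The date-to-timestamp bounds check introduced above is user visible: casts that previously wrapped around now fail. A minimal sketch of the resulting behavior, mirroring the `cast_date_to_timestamp_overflow_returns_error` test added further down (illustrative only, not part of the patch; `ScalarValue::cast_to` is the existing public entry point):

use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::ScalarValue;

fn date_cast_overflow_demo() {
    // i32::MAX days since the epoch cannot be represented as nanoseconds in an
    // i64, so the cast now reports an execution error instead of wrapping.
    let day = ScalarValue::Date32(Some(i32::MAX));
    let err = day
        .cast_to(&DataType::Timestamp(TimeUnit::Nanosecond, None))
        .unwrap_err();
    assert!(
        err.to_string()
            .contains("converted value exceeds the representable i64 range")
    );
}
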
@@ -5008,7 +5083,8 @@ mod tests { use arrow::buffer::{Buffer, NullBuffer, OffsetBuffer}; use arrow::compute::{is_null, kernels}; use arrow::datatypes::{ - ArrowNumericType, Fields, Float64Type, DECIMAL256_MAX_PRECISION, + ArrowNumericType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, Fields, + Float64Type, TimeUnit, }; use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_columns; @@ -5041,6 +5117,52 @@ mod tests { assert_eq!(actual, &expected); } + #[test] + fn test_format_timestamp_type_for_error_and_bounds() { + // format helper + let ts_ns = format_timestamp_type_for_error(&DataType::Timestamp( + TimeUnit::Nanosecond, + None, + )); + assert_eq!(ts_ns, "Timestamp(ns)"); + + let ts_us = format_timestamp_type_for_error(&DataType::Timestamp( + TimeUnit::Microsecond, + None, + )); + assert_eq!(ts_us, "Timestamp(us)"); + + // ensure_timestamp_in_bounds: Date32 non-overflow + let ok = ensure_timestamp_in_bounds( + 1000, + NANOS_PER_DAY, + &DataType::Date32, + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ); + assert!(ok.is_ok()); + + // Date32 overflow -- known large day value (9999-12-31 -> 2932896) + let err = ensure_timestamp_in_bounds( + 2932896, + NANOS_PER_DAY, + &DataType::Date32, + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ); + assert!(err.is_err()); + let msg = err.unwrap_err().to_string(); + assert!(msg.contains("Cannot cast Date32 value 2932896 to Timestamp(ns): converted value exceeds the representable i64 range")); + + // Date64 overflow for ns (millis * 1_000_000) + let overflow_millis: i64 = (i64::MAX / NANOS_PER_MILLISECOND) + 1; + let err2 = ensure_timestamp_in_bounds( + overflow_millis, + NANOS_PER_MILLISECOND, + &DataType::Date64, + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ); + assert!(err2.is_err()); + } + #[test] fn test_scalar_value_from_for_struct() { let boolean = Arc::new(BooleanArray::from(vec![false])); @@ -5172,6 +5294,18 @@ mod tests { assert_eq!(empty_array.len(), 0); } + /// See https://github.com/apache/datafusion/issues/18870 + #[test] + fn test_to_array_of_size_for_none_fsb() { + let sv = ScalarValue::FixedSizeBinary(5, None); + let result = sv + .to_array_of_size(2) + .expect("Failed to convert to array of size"); + assert_eq!(result.len(), 2); + assert_eq!(result.null_count(), 2); + assert_eq!(result.as_fixed_size_binary().values().len(), 10); + } + #[test] fn test_list_to_array_string() { let scalars = vec![ @@ -5527,7 +5661,10 @@ mod tests { .sub_checked(&int_value_2) .unwrap_err() .strip_backtrace(); - assert_eq!(err, "Arrow error: Arithmetic overflow: Overflow happened on: 9223372036854775807 - -9223372036854775808") + assert_eq!( + err, + "Arrow error: Arithmetic overflow: Overflow happened on: 9223372036854775807 - -9223372036854775808" + ) } #[test] @@ -5675,12 +5812,16 @@ mod tests { assert_eq!(123i128, array_decimal.value(0)); assert_eq!(123i128, array_decimal.value(9)); // test eq array - assert!(decimal_value - .eq_array(&array, 1) - .expect("Failed to compare arrays")); - assert!(decimal_value - .eq_array(&array, 5) - .expect("Failed to compare arrays")); + assert!( + decimal_value + .eq_array(&array, 1) + .expect("Failed to compare arrays") + ); + assert!( + decimal_value + .eq_array(&array, 5) + .expect("Failed to compare arrays") + ); // test try from array assert_eq!( decimal_value, @@ -5725,18 +5866,24 @@ mod tests { assert_eq!(4, array.len()); assert_eq!(DataType::Decimal128(10, 2), array.data_type().clone()); - assert!(ScalarValue::try_new_decimal128(1, 10, 2) - .unwrap() - 
.eq_array(&array, 0) - .expect("Failed to compare arrays")); - assert!(ScalarValue::try_new_decimal128(2, 10, 2) - .unwrap() - .eq_array(&array, 1) - .expect("Failed to compare arrays")); - assert!(ScalarValue::try_new_decimal128(3, 10, 2) - .unwrap() - .eq_array(&array, 2) - .expect("Failed to compare arrays")); + assert!( + ScalarValue::try_new_decimal128(1, 10, 2) + .unwrap() + .eq_array(&array, 0) + .expect("Failed to compare arrays") + ); + assert!( + ScalarValue::try_new_decimal128(2, 10, 2) + .unwrap() + .eq_array(&array, 1) + .expect("Failed to compare arrays") + ); + assert!( + ScalarValue::try_new_decimal128(3, 10, 2) + .unwrap() + .eq_array(&array, 2) + .expect("Failed to compare arrays") + ); assert_eq!( ScalarValue::Decimal128(None, 10, 2), ScalarValue::try_from_array(&array, 3).unwrap() @@ -6172,8 +6319,6 @@ mod tests { } #[test] - // despite clippy claiming they are useless, the code doesn't compile otherwise. - #[allow(clippy::useless_vec)] fn scalar_iter_to_array_boolean() { check_scalar_iter!(Boolean, BooleanArray, vec![Some(true), None, Some(false)]); check_scalar_iter!(Float32, Float32Array, vec![Some(1.9), None, Some(-2.1)]); @@ -6223,12 +6368,12 @@ mod tests { check_scalar_iter_binary!( Binary, BinaryArray, - vec![Some(b"foo"), None, Some(b"bar")] + [Some(b"foo"), None, Some(b"bar")] ); check_scalar_iter_binary!( LargeBinary, LargeBinaryArray, - vec![Some(b"foo"), None, Some(b"bar")] + [Some(b"foo"), None, Some(b"bar")] ); } @@ -6681,7 +6826,9 @@ mod tests { for other_index in 0..array.len() { if index != other_index { assert!( - !scalar.eq_array(&array, other_index).expect("Failed to compare arrays"), + !scalar + .eq_array(&array, other_index) + .expect("Failed to compare arrays"), "Expected {scalar:?} to be NOT equal to {array:?} at index {other_index}" ); } @@ -7606,7 +7753,6 @@ mod tests { } #[test] - #[allow(arithmetic_overflow)] // we want to test them fn test_scalar_negative_overflows() -> Result<()> { macro_rules! test_overflow_on_value { ($($val:expr),* $(,)?) 
=> {$( @@ -8622,6 +8768,19 @@ mod tests { assert!(dense_scalar.is_null()); } + #[test] + fn cast_date_to_timestamp_overflow_returns_error() { + let scalar = ScalarValue::Date32(Some(i32::MAX)); + let err = scalar + .cast_to(&DataType::Timestamp(TimeUnit::Nanosecond, None)) + .expect_err("expected cast to fail"); + assert!( + err.to_string() + .contains("converted value exceeds the representable i64 range"), + "unexpected error: {err}" + ); + } + #[test] fn null_dictionary_scalar_produces_null_dictionary_array() { let dictionary_scalar = ScalarValue::Dictionary( @@ -9047,6 +9206,27 @@ mod tests { } } + #[test] + fn test_views_minimize_memory() { + let value = "this string is longer than 12 bytes".to_string(); + + let scalar = ScalarValue::Utf8View(Some(value.clone())); + let array = scalar.to_array_of_size(10).unwrap(); + let array = array.as_string_view(); + let buffers = array.data_buffers(); + assert_eq!(1, buffers.len()); + // Ensure we only have a single copy of the value string + assert_eq!(value.len(), buffers[0].len()); + + // Same but for BinaryView + let scalar = ScalarValue::BinaryView(Some(value.bytes().collect())); + let array = scalar.to_array_of_size(10).unwrap(); + let array = array.as_binary_view(); + let buffers = array.data_buffers(); + assert_eq!(1, buffers.len()); + assert_eq!(value.len(), buffers[0].len()); + } + #[test] fn test_convert_array_to_scalar_vec() { // 1: Regular ListArray diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index da298c20ebcb4..ba13ef392d912 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -283,9 +283,13 @@ impl From> for Precision { /// and the transformations output are not always predictable. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Statistics { - /// The number of table rows. + /// The number of rows estimated to be scanned. pub num_rows: Precision, - /// Total bytes of the table rows. + /// The total bytes of the output data. + /// Note that this is not the same as the total bytes that may be scanned, + /// processed, etc. + /// E.g. we may read 1GB of data from a Parquet file but the Arrow data + /// the node produces may be 2GB; it's this 2GB that is tracked here. pub total_byte_size: Precision, /// Statistics on a column level. /// @@ -317,6 +321,31 @@ impl Statistics { } } + /// Calculates `total_byte_size` based on the schema and `num_rows`. + /// If any of the columns has non-primitive width, `total_byte_size` is set to inexact. + pub fn calculate_total_byte_size(&mut self, schema: &Schema) { + let mut row_size = Some(0); + for field in schema.fields() { + match field.data_type().primitive_width() { + Some(width) => { + row_size = row_size.map(|s| s + width); + } + None => { + row_size = None; + break; + } + } + } + match row_size { + None => { + self.total_byte_size = self.total_byte_size.to_inexact(); + } + Some(size) => { + self.total_byte_size = self.num_rows.multiply(&Precision::Exact(size)); + } + } + } + /// Returns an unbounded `ColumnStatistics` for each field in the schema. 
pub fn unknown_column(schema: &Schema) -> Vec { schema @@ -367,7 +396,7 @@ impl Statistics { return self; }; - #[allow(clippy::large_enum_variant)] + #[expect(clippy::large_enum_variant)] enum Slot { /// The column is taken and put into the specified statistics location Taken(usize), @@ -477,15 +506,38 @@ impl Statistics { self.column_statistics = self .column_statistics .into_iter() - .map(ColumnStatistics::to_inexact) + .map(|cs| { + let mut cs = cs.to_inexact(); + // Scale byte_size by the row ratio + cs.byte_size = match cs.byte_size { + Precision::Exact(n) | Precision::Inexact(n) => { + Precision::Inexact((n as f64 * ratio) as usize) + } + Precision::Absent => Precision::Absent, + }; + cs + }) .collect(); - // Adjust the total_byte_size for the ratio of rows before and after, also marking it as inexact - self.total_byte_size = match &self.total_byte_size { - Precision::Exact(n) | Precision::Inexact(n) => { - let adjusted = (*n as f64 * ratio) as usize; - Precision::Inexact(adjusted) + + // Compute total_byte_size as sum of column byte_size values if all are present, + // otherwise fall back to scaling the original total_byte_size + let sum_scan_bytes: Option = self + .column_statistics + .iter() + .map(|cs| cs.byte_size.get_value().copied()) + .try_fold(0usize, |acc, val| val.map(|v| acc + v)); + + self.total_byte_size = match sum_scan_bytes { + Some(sum) => Precision::Inexact(sum), + None => { + // Fall back to scaling original total_byte_size if not all columns have byte_size + match &self.total_byte_size { + Precision::Exact(n) | Precision::Inexact(n) => { + Precision::Inexact((*n as f64 * ratio) as usize) + } + Precision::Absent => Precision::Absent, + } } - Precision::Absent => Precision::Absent, }; Ok(self) } @@ -581,6 +633,7 @@ impl Statistics { col_stats.min_value = col_stats.min_value.min(&item_col_stats.min_value); col_stats.sum_value = col_stats.sum_value.add(&item_col_stats.sum_value); col_stats.distinct_count = Precision::Absent; + col_stats.byte_size = col_stats.byte_size.add(&item_col_stats.byte_size); } Ok(Statistics { @@ -642,6 +695,11 @@ impl Display for Statistics { } else { s }; + let s = if cs.byte_size != Precision::Absent { + format!("{} ScanBytes={}", s, cs.byte_size) + } else { + s + }; s + ")" }) @@ -671,6 +729,21 @@ pub struct ColumnStatistics { pub sum_value: Precision, /// Number of distinct values pub distinct_count: Precision, + /// Estimated size of this column's data in bytes for the output. + /// + /// Note that this is not the same as the total bytes that may be scanned, + /// processed, etc. + /// + /// E.g. we may read 1GB of data from a Parquet file but the Arrow data + /// the node produces may be 2GB; it's this 2GB that is tracked here. + /// + /// Currently this is accurately calculated for primitive types only. + /// For complex types (like Utf8, List, Struct, etc), this value may be + /// absent or inexact (e.g. estimated from the size of the data in the source Parquet files). + /// + /// This value is automatically scaled when operations like limits or + /// filters reduce the number of rows (see [`Statistics::with_fetch`]). + pub byte_size: Precision, } impl ColumnStatistics { @@ -693,6 +766,7 @@ impl ColumnStatistics { min_value: Precision::Absent, sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, } } @@ -726,6 +800,13 @@ impl ColumnStatistics { self } + /// Set the scan byte size + /// This should initially be set to the total size of the column. 
+ pub fn with_byte_size(mut self, byte_size: Precision) -> Self { + self.byte_size = byte_size; + self + } + /// If the exactness of a [`ColumnStatistics`] instance is lost, this /// function relaxes the exactness of all information by converting them /// [`Precision::Inexact`]. @@ -735,6 +816,7 @@ impl ColumnStatistics { self.min_value = self.min_value.to_inexact(); self.sum_value = self.sum_value.to_inexact(); self.distinct_count = self.distinct_count.to_inexact(); + self.byte_size = self.byte_size.to_inexact(); self } } @@ -961,9 +1043,11 @@ mod tests { Precision::Exact(ScalarValue::Int64(None)), ); // Overflow returns error - assert!(Precision::Exact(ScalarValue::Int32(Some(256))) - .cast_to(&DataType::Int8) - .is_err()); + assert!( + Precision::Exact(ScalarValue::Int32(Some(256))) + .cast_to(&DataType::Int8) + .is_err() + ); } #[test] @@ -976,8 +1060,6 @@ mod tests { // Precision is not copy (requires .clone()) let precision: Precision = Precision::Exact(ScalarValue::Int64(Some(42))); - // Clippy would complain about this if it were Copy - #[allow(clippy::redundant_clone)] let p2 = precision.clone(); assert_eq!(precision, p2); } @@ -1026,6 +1108,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int64(Some(64))), sum_value: Precision::Exact(ScalarValue::Int64(Some(4600))), distinct_count: Precision::Exact(100), + byte_size: Precision::Exact(800), } } @@ -1048,6 +1131,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(1))), sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), }, ColumnStatistics { null_count: Precision::Exact(2), @@ -1055,6 +1139,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(10))), sum_value: Precision::Exact(ScalarValue::Int32(Some(1000))), distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), }, ], }; @@ -1069,6 +1154,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(-10))), sum_value: Precision::Exact(ScalarValue::Int32(Some(600))), distinct_count: Precision::Absent, + byte_size: Precision::Exact(60), }, ColumnStatistics { null_count: Precision::Exact(3), @@ -1076,6 +1162,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(5))), sum_value: Precision::Exact(ScalarValue::Int32(Some(1200))), distinct_count: Precision::Absent, + byte_size: Precision::Exact(60), }, ], }; @@ -1139,6 +1226,7 @@ mod tests { min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), }], }; @@ -1151,6 +1239,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(-10))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Inexact(60), }], }; @@ -1215,7 +1304,10 @@ mod tests { let items = vec![stats1, stats2]; let e = Statistics::try_merge_iter(&items, &schema).unwrap_err(); - assert_contains!(e.to_string(), "Error during planning: Cannot merge statistics with different number of columns: 0 vs 1"); + assert_contains!( + e.to_string(), + "Error during planning: Cannot merge statistics with different number of columns: 0 vs 1" + ); } #[test] @@ -1277,6 +1369,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(0))), sum_value: Precision::Exact(ScalarValue::Int32(Some(5050))), distinct_count: Precision::Exact(50), + byte_size: Precision::Exact(4000), }, ColumnStatistics { null_count: Precision::Exact(20), @@ -1284,6 +1377,7 @@ mod 
tests { min_value: Precision::Exact(ScalarValue::Int64(Some(10))), sum_value: Precision::Exact(ScalarValue::Int64(Some(10100))), distinct_count: Precision::Exact(75), + byte_size: Precision::Exact(8000), }, ], }; @@ -1294,9 +1388,9 @@ mod tests { // Check num_rows assert_eq!(result.num_rows, Precision::Exact(100)); - // Check total_byte_size is scaled proportionally and marked as inexact - // 100/1000 = 0.1, so 8000 * 0.1 = 800 - assert_eq!(result.total_byte_size, Precision::Inexact(800)); + // Check total_byte_size is computed as sum of scaled column byte_size values + // Column 1: 4000 * 0.1 = 400, Column 2: 8000 * 0.1 = 800, Sum = 1200 + assert_eq!(result.total_byte_size, Precision::Inexact(1200)); // Check column statistics are preserved but marked as inexact assert_eq!(result.column_statistics.len(), 2); @@ -1358,6 +1452,7 @@ mod tests { min_value: Precision::Inexact(ScalarValue::Int32(Some(0))), sum_value: Precision::Inexact(ScalarValue::Int32(Some(5050))), distinct_count: Precision::Inexact(50), + byte_size: Precision::Inexact(4000), }], }; @@ -1366,9 +1461,9 @@ mod tests { // Check num_rows is inexact assert_eq!(result.num_rows, Precision::Inexact(500)); - // Check total_byte_size is scaled and inexact - // 500/1000 = 0.5, so 8000 * 0.5 = 4000 - assert_eq!(result.total_byte_size, Precision::Inexact(4000)); + // Check total_byte_size is computed as sum of scaled column byte_size values + // Column 1: 4000 * 0.5 = 2000, Sum = 2000 + assert_eq!(result.total_byte_size, Precision::Inexact(2000)); // Column stats remain inexact assert_eq!( @@ -1425,8 +1520,8 @@ mod tests { .unwrap(); assert_eq!(result.num_rows, Precision::Exact(300)); - // 300/1000 = 0.3, so 8000 * 0.3 = 2400 - assert_eq!(result.total_byte_size, Precision::Inexact(2400)); + // Column 1: byte_size 800 * (300/500) = 240, Sum = 240 + assert_eq!(result.total_byte_size, Precision::Inexact(240)); } #[test] @@ -1442,8 +1537,8 @@ mod tests { let result = original_stats.clone().with_fetch(Some(100), 0, 4).unwrap(); assert_eq!(result.num_rows, Precision::Exact(400)); - // 400/1000 = 0.4, so 8000 * 0.4 = 3200 - assert_eq!(result.total_byte_size, Precision::Inexact(3200)); + // Column 1: byte_size 800 * 0.4 = 320, Sum = 320 + assert_eq!(result.total_byte_size, Precision::Inexact(320)); } #[test] @@ -1458,6 +1553,7 @@ mod tests { min_value: Precision::Absent, sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, }], }; @@ -1496,6 +1592,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(-100))), sum_value: Precision::Exact(ScalarValue::Int32(Some(123456))), distinct_count: Precision::Exact(789), + byte_size: Precision::Exact(4000), }; let original_stats = Statistics { @@ -1524,4 +1621,140 @@ mod tests { ); assert_eq!(result_col_stats.distinct_count, Precision::Inexact(789)); } + + #[test] + fn test_byte_size_try_merge() { + // Test that byte_size is summed correctly in try_merge + let col_stats1 = ColumnStatistics { + null_count: Precision::Exact(10), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(1000), + }; + let col_stats2 = ColumnStatistics { + null_count: Precision::Exact(20), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(2000), + }; + + let stats1 = Statistics { + num_rows: Precision::Exact(50), + total_byte_size: Precision::Exact(1000), + 
column_statistics: vec![col_stats1], + }; + let stats2 = Statistics { + num_rows: Precision::Exact(100), + total_byte_size: Precision::Exact(2000), + column_statistics: vec![col_stats2], + }; + + let merged = stats1.try_merge(&stats2).unwrap(); + assert_eq!( + merged.column_statistics[0].byte_size, + Precision::Exact(3000) // 1000 + 2000 + ); + } + + #[test] + fn test_byte_size_to_inexact() { + let col_stats = ColumnStatistics { + null_count: Precision::Exact(10), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(5000), + }; + + let inexact = col_stats.to_inexact(); + assert_eq!(inexact.byte_size, Precision::Inexact(5000)); + } + + #[test] + fn test_with_byte_size_builder() { + let col_stats = + ColumnStatistics::new_unknown().with_byte_size(Precision::Exact(8192)); + assert_eq!(col_stats.byte_size, Precision::Exact(8192)); + } + + #[test] + fn test_with_fetch_scales_byte_size() { + // Test that byte_size is scaled by the row ratio in with_fetch + let original_stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(10), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(4000), + }, + ColumnStatistics { + null_count: Precision::Exact(20), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(8000), + }, + ], + }; + + // Apply fetch of 100 rows (10% of original) + let result = original_stats.with_fetch(Some(100), 0, 1).unwrap(); + + // byte_size should be scaled: 4000 * 0.1 = 400, 8000 * 0.1 = 800 + assert_eq!( + result.column_statistics[0].byte_size, + Precision::Inexact(400) + ); + assert_eq!( + result.column_statistics[1].byte_size, + Precision::Inexact(800) + ); + + // total_byte_size should be computed as sum of byte_size values: 400 + 800 = 1200 + assert_eq!(result.total_byte_size, Precision::Inexact(1200)); + } + + #[test] + fn test_with_fetch_total_byte_size_fallback() { + // Test that total_byte_size falls back to scaling when not all columns have byte_size + let original_stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(10), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(4000), + }, + ColumnStatistics { + null_count: Precision::Exact(20), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, // One column has no byte_size + }, + ], + }; + + // Apply fetch of 100 rows (10% of original) + let result = original_stats.with_fetch(Some(100), 0, 1).unwrap(); + + // total_byte_size should fall back to scaling: 8000 * 0.1 = 800 + assert_eq!(result.total_byte_size, Precision::Inexact(800)); + } } diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs index c51dea1c4de04..f060704944233 100644 --- a/datafusion/common/src/test_util.rs +++ b/datafusion/common/src/test_util.rs @@ -735,32 +735,34 @@ mod tests { let non_existing = 
cwd.join("non-existing-dir").display().to_string(); let non_existing_str = non_existing.as_str(); - env::set_var(udf_env, non_existing_str); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_err()); - - env::set_var(udf_env, ""); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), existing_pb); - - env::set_var(udf_env, " "); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), existing_pb); - - env::set_var(udf_env, existing_str); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), existing_pb); - - env::remove_var(udf_env); - let res = get_data_dir(udf_env, non_existing_str); - assert!(res.is_err()); - - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), existing_pb); + unsafe { + env::set_var(udf_env, non_existing_str); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_err()); + + env::set_var(udf_env, ""); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), existing_pb); + + env::set_var(udf_env, " "); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), existing_pb); + + env::set_var(udf_env, existing_str); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), existing_pb); + + env::remove_var(udf_env); + let res = get_data_dir(udf_env, non_existing_str); + assert!(res.is_err()); + + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), existing_pb); + } } #[test] diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 9b36266eec2e9..1e7c02e424256 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -956,12 +956,12 @@ impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>> } impl< - 'a, - T: 'a, - C0: TreeNodeContainer<'a, T>, - C1: TreeNodeContainer<'a, T>, - C2: TreeNodeContainer<'a, T>, - > TreeNodeContainer<'a, T> for (C0, C1, C2) + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, +> TreeNodeContainer<'a, T> for (C0, C1, C2) { fn apply_elements Result>( &'a self, @@ -992,13 +992,13 @@ impl< } impl< - 'a, - T: 'a, - C0: TreeNodeContainer<'a, T>, - C1: TreeNodeContainer<'a, T>, - C2: TreeNodeContainer<'a, T>, - C3: TreeNodeContainer<'a, T>, - > TreeNodeContainer<'a, T> for (C0, C1, C2, C3) + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, + C3: TreeNodeContainer<'a, T>, +> TreeNodeContainer<'a, T> for (C0, C1, C2, C3) { fn apply_elements Result>( &'a self, @@ -1090,12 +1090,12 @@ impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>> } impl< - 'a, - T: 'a, - C0: TreeNodeContainer<'a, T>, - C1: TreeNodeContainer<'a, T>, - C2: TreeNodeContainer<'a, T>, - > TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2) + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, +> TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2) { fn apply_ref_elements Result>( &self, @@ -1109,13 +1109,13 @@ impl< } impl< - 'a, - T: 'a, - C0: TreeNodeContainer<'a, T>, - C1: TreeNodeContainer<'a, T>, - C2: TreeNodeContainer<'a, T>, - C3: TreeNodeContainer<'a, T>, - > TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a 
C2, &'a C3) + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, + C3: TreeNodeContainer<'a, T>, +> TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2, &'a C3) { fn apply_ref_elements Result>( &self, @@ -1336,11 +1336,11 @@ pub(crate) mod tests { use std::collections::HashMap; use std::fmt::Display; + use crate::Result; use crate::tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; - use crate::Result; #[derive(Debug, Eq, Hash, PartialEq, Clone)] pub struct TestTreeNode { diff --git a/datafusion/common/src/types/builtin.rs b/datafusion/common/src/types/builtin.rs index 314529b99a342..dfd2cc4cf2d8b 100644 --- a/datafusion/common/src/types/builtin.rs +++ b/datafusion/common/src/types/builtin.rs @@ -16,6 +16,7 @@ // under the License. use arrow::datatypes::IntervalUnit::*; +use arrow::datatypes::TimeUnit::*; use crate::types::{LogicalTypeRef, NativeType}; use std::sync::{Arc, LazyLock}; @@ -82,3 +83,17 @@ singleton_variant!( Interval, MonthDayNano ); + +singleton_variant!( + LOGICAL_INTERVAL_YEAR_MONTH, + logical_interval_year_month, + Interval, + YearMonth +); + +singleton_variant!( + LOGICAL_DURATION_MICROSECOND, + logical_duration_microsecond, + Duration, + Microsecond +); diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index a1495b779ac97..766c50441613b 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -19,11 +19,11 @@ use super::{ LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, TypeSignature, }; -use crate::error::{Result, _internal_err}; +use crate::error::{_internal_err, Result}; use arrow::compute::can_cast_types; use arrow::datatypes::{ - DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields, - DECIMAL128_MAX_PRECISION, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, + DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, DECIMAL128_MAX_PRECISION, DataType, + Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields, }; use std::{fmt::Display, sync::Arc}; @@ -241,9 +241,7 @@ impl LogicalType for NativeType { (Self::Decimal(p, s), _) => Decimal256(*p, *s), (Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()), // If given type is Date, return the same type - (Self::Date, origin) if matches!(origin, Date32 | Date64) => { - origin.to_owned() - } + (Self::Date, Date32 | Date64) => origin.to_owned(), (Self::Date, _) => Date32, (Self::Time(tu), _) => match tu { TimeUnit::Second | TimeUnit::Millisecond => Time32(*tu), @@ -253,6 +251,8 @@ impl LogicalType for NativeType { (Self::Interval(iu), _) => Interval(*iu), (Self::Binary, LargeUtf8) => LargeBinary, (Self::Binary, Utf8View) => BinaryView, + // We don't cast to another kind of binary type if the origin one is already a binary type + (Self::Binary, Binary | LargeBinary | BinaryView) => origin.to_owned(), (Self::Binary, data_type) if can_cast_types(data_type, &BinaryView) => { BinaryView } @@ -364,7 +364,7 @@ impl LogicalType for NativeType { "Unavailable default cast for native type {} from physical type {}", self, origin - ) + ); } }) } diff --git a/datafusion/common/src/utils/memory.rs b/datafusion/common/src/utils/memory.rs index a56b940fab666..78ec434d2b577 100644 --- a/datafusion/common/src/utils/memory.rs +++ b/datafusion/common/src/utils/memory.rs @@ -18,8 +18,10 @@ //! 
This module provides a function to estimate the memory size of a HashTable prior to allocation use crate::error::_exec_datafusion_err; -use crate::Result; -use std::mem::size_of; +use crate::{HashSet, Result}; +use arrow::array::ArrayData; +use arrow::record_batch::RecordBatch; +use std::{mem::size_of, ptr::NonNull}; /// Estimates the memory size required for a hash table prior to allocation. /// @@ -99,6 +101,74 @@ pub fn estimate_memory_size(num_elements: usize, fixed_size: usize) -> Result }) } +/// Calculate total used memory of this batch. +/// +/// This function is used to estimate the physical memory usage of the `RecordBatch`. +/// It only counts the memory of large data `Buffer`s, and ignores metadata like +/// types and pointers. +/// The implementation will add up all unique `Buffer`'s memory +/// size, due to: +/// - The data pointer inside `Buffer` are memory regions returned by global memory +/// allocator, those regions can't have overlap. +/// - The actual used range of `ArrayRef`s inside `RecordBatch` can have overlap +/// or reuse the same `Buffer`. For example: taking a slice from `Array`. +/// +/// Example: +/// For a `RecordBatch` with two columns: `col1` and `col2`, two columns are pointing +/// to a sub-region of the same buffer. +/// +/// {xxxxxxxxxxxxxxxxxxx} <--- buffer +/// ^ ^ ^ ^ +/// | | | | +/// col1->{ } | | +/// col2--------->{ } +/// +/// In the above case, `get_record_batch_memory_size` will return the size of +/// the buffer, instead of the sum of `col1` and `col2`'s actual memory size. +/// +/// Note: Current `RecordBatch`.get_array_memory_size()` will double count the +/// buffer memory size if multiple arrays within the batch are sharing the same +/// `Buffer`. This method provides temporary fix until the issue is resolved: +/// +pub fn get_record_batch_memory_size(batch: &RecordBatch) -> usize { + // Store pointers to `Buffer`'s start memory address (instead of actual + // used data region's pointer represented by current `Array`) + let mut counted_buffers: HashSet> = HashSet::new(); + let mut total_size = 0; + + for array in batch.columns() { + let array_data = array.to_data(); + count_array_data_memory_size(&array_data, &mut counted_buffers, &mut total_size); + } + + total_size +} + +/// Count the memory usage of `array_data` and its children recursively. 
+fn count_array_data_memory_size( + array_data: &ArrayData, + counted_buffers: &mut HashSet>, + total_size: &mut usize, +) { + // Count memory usage for `array_data` + for buffer in array_data.buffers() { + if counted_buffers.insert(buffer.data_ptr()) { + *total_size += buffer.capacity(); + } // Otherwise the buffer's memory is already counted + } + + if let Some(null_buffer) = array_data.nulls() + && counted_buffers.insert(null_buffer.inner().inner().data_ptr()) + { + *total_size += null_buffer.inner().inner().capacity(); + } + + // Count all children `ArrayData` recursively + for child in array_data.child_data() { + count_array_data_memory_size(child, counted_buffers, total_size); + } +} + #[cfg(test)] mod tests { use std::{collections::HashSet, mem::size_of}; @@ -132,3 +202,129 @@ mod tests { assert!(estimated.is_err()); } } + +#[cfg(test)] +mod record_batch_tests { + use super::*; + use arrow::array::{Float64Array, Int32Array, ListArray}; + use arrow::datatypes::{DataType, Field, Int32Type, Schema}; + use std::sync::Arc; + + #[test] + fn test_get_record_batch_memory_size() { + let schema = Arc::new(Schema::new(vec![ + Field::new("ints", DataType::Int32, true), + Field::new("float64", DataType::Float64, false), + ])); + + let int_array = + Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]); + let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(int_array), Arc::new(float64_array)], + ) + .unwrap(); + + let size = get_record_batch_memory_size(&batch); + assert_eq!(size, 60); + } + + #[test] + fn test_get_record_batch_memory_size_with_null() { + let schema = Arc::new(Schema::new(vec![ + Field::new("ints", DataType::Int32, true), + Field::new("float64", DataType::Float64, false), + ])); + + let int_array = Int32Array::from(vec![None, Some(2), Some(3)]); + let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0]); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(int_array), Arc::new(float64_array)], + ) + .unwrap(); + + let size = get_record_batch_memory_size(&batch); + assert_eq!(size, 100); + } + + #[test] + fn test_get_record_batch_memory_size_empty() { + let schema = Arc::new(Schema::new(vec![Field::new( + "ints", + DataType::Int32, + false, + )])); + + let int_array: Int32Array = Int32Array::from(vec![] as Vec); + let batch = RecordBatch::try_new(schema, vec![Arc::new(int_array)]).unwrap(); + + let size = get_record_batch_memory_size(&batch); + assert_eq!(size, 0, "Empty batch should have 0 memory size"); + } + + #[test] + fn test_get_record_batch_memory_size_shared_buffer() { + let original = Int32Array::from(vec![1, 2, 3, 4, 5]); + let slice1 = original.slice(0, 3); + let slice2 = original.slice(2, 3); + + let schema_origin = Arc::new(Schema::new(vec![Field::new( + "origin_col", + DataType::Int32, + false, + )])); + let batch_origin = + RecordBatch::try_new(schema_origin, vec![Arc::new(original)]).unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("slice1", DataType::Int32, false), + Field::new("slice2", DataType::Int32, false), + ])); + + let batch_sliced = + RecordBatch::try_new(schema, vec![Arc::new(slice1), Arc::new(slice2)]) + .unwrap(); + + let size_origin = get_record_batch_memory_size(&batch_origin); + let size_sliced = get_record_batch_memory_size(&batch_sliced); + + assert_eq!(size_origin, size_sliced); + } + + #[test] + fn test_get_record_batch_memory_size_nested_array() { + let schema = Arc::new(Schema::new(vec![ + Field::new( + "nested_int", 
+ DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), + false, + ), + Field::new( + "nested_int2", + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), + false, + ), + ])); + + let int_list_array = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + ]); + + let int_list_array2 = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(4), Some(5), Some(6)]), + ]); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(int_list_array), Arc::new(int_list_array2)], + ) + .unwrap(); + + let size = get_record_batch_memory_size(&batch); + assert_eq!(size, 8208); + } +} diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 7b145ac3ae21d..e061f852637ca 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -22,19 +22,20 @@ pub mod memory; pub mod proxy; pub mod string_utils; -use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err}; +use crate::assert_or_internal_err; +use crate::error::{_exec_datafusion_err, _internal_datafusion_err}; use crate::{Result, ScalarValue}; use arrow::array::{ - cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, - OffsetSizeTrait, + Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, + cast::AsArray, }; use arrow::buffer::OffsetBuffer; -use arrow::compute::{partition, SortColumn, SortOptions}; +use arrow::compute::{SortColumn, SortOptions, partition}; use arrow::datatypes::{DataType, Field, SchemaRef}; #[cfg(feature = "sql")] use sqlparser::{ast::Ident, dialect::GenericDialect, parser::Parser}; use std::borrow::{Borrow, Cow}; -use std::cmp::{min, Ordering}; +use std::cmp::{Ordering, min}; use std::collections::HashSet; use std::num::NonZero; use std::ops::Range; @@ -265,10 +266,10 @@ fn needs_quotes(s: &str) -> bool { let mut chars = s.chars(); // first char can not be a number unless escaped - if let Some(first_char) = chars.next() { - if !(first_char.is_ascii_lowercase() || first_char == '_') { - return true; - } + if let Some(first_char) = chars.next() + && !(first_char.is_ascii_lowercase() || first_char == '_') + { + return true; } !chars.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_') @@ -519,9 +520,7 @@ pub fn arrays_into_list_array( arr: impl IntoIterator, ) -> Result { let arr = arr.into_iter().collect::>(); - if arr.is_empty() { - return _internal_err!("Cannot wrap empty array into list array"); - } + assert_or_internal_err!(!arr.is_empty(), "Cannot wrap empty array into list array"); let lens = arr.iter().map(|x| x.len()).collect::>(); // Assume data type is consistent @@ -944,8 +943,6 @@ mod tests { use super::*; use crate::ScalarValue::Null; use arrow::array::Float64Array; - use sqlparser::ast::Ident; - use sqlparser::tokenizer::Span; #[test] fn test_bisect_linear_left_and_right() -> Result<()> { @@ -1174,7 +1171,7 @@ mod tests { let expected_parsed = vec![Ident { value: identifier.to_string(), quote_style, - span: Span::empty(), + span: sqlparser::tokenizer::Span::empty(), }]; assert_eq!( diff --git a/datafusion/common/src/utils/proxy.rs b/datafusion/common/src/utils/proxy.rs index fb951aa3b0289..fddf834912544 100644 --- a/datafusion/common/src/utils/proxy.rs +++ b/datafusion/common/src/utils/proxy.rs @@ -15,12 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! [`VecAllocExt`] and [`RawTableAllocExt`] to help tracking of memory allocations +//! 
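On the `arrays_into_list_array` change in the `utils/mod.rs` hunk above: `assert_or_internal_err!` folds the manual `is_empty` check and `_internal_err!` early return into one line. A small sketch of the pattern, assuming the macro is exported from the `datafusion_common` crate root (the function itself is made up for illustration):

```rust
use datafusion_common::{assert_or_internal_err, Result};

/// Returns an internal DataFusion error instead of panicking when the
/// precondition does not hold.
fn head(values: &[i32]) -> Result<i32> {
    assert_or_internal_err!(!values.is_empty(), "Cannot take the head of an empty slice");
    Ok(values[0])
}

fn main() {
    assert_eq!(head(&[1, 2, 3]).unwrap(), 1);
    assert!(head(&[]).is_err());
}
```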
[`VecAllocExt`] to help tracking of memory allocations -use hashbrown::{ - hash_table::HashTable, - raw::{Bucket, RawTable}, -}; +use hashbrown::hash_table::HashTable; use std::mem::size_of; /// Extension trait for [`Vec`] to account for allocations. @@ -114,75 +111,6 @@ impl VecAllocExt for Vec { } } -/// Extension trait for hash browns [`RawTable`] to account for allocations. -pub trait RawTableAllocExt { - /// Item type. - type T; - - /// [Insert](RawTable::insert) new element into table and increase - /// `accounting` by any newly allocated bytes. - /// - /// Returns the bucket where the element was inserted. - /// Note that allocation counts capacity, not size. - /// - /// # Example: - /// ``` - /// # use datafusion_common::utils::proxy::RawTableAllocExt; - /// # use hashbrown::raw::RawTable; - /// let mut table = RawTable::new(); - /// let mut allocated = 0; - /// let hash_fn = |x: &u32| (*x as u64) % 1000; - /// // pretend 0x3117 is the hash value for 1 - /// table.insert_accounted(1, hash_fn, &mut allocated); - /// assert_eq!(allocated, 64); - /// - /// // insert more values - /// for i in 0..100 { - /// table.insert_accounted(i, hash_fn, &mut allocated); - /// } - /// assert_eq!(allocated, 400); - /// ``` - fn insert_accounted( - &mut self, - x: Self::T, - hasher: impl Fn(&Self::T) -> u64, - accounting: &mut usize, - ) -> Bucket; -} - -impl RawTableAllocExt for RawTable { - type T = T; - - fn insert_accounted( - &mut self, - x: Self::T, - hasher: impl Fn(&Self::T) -> u64, - accounting: &mut usize, - ) -> Bucket { - let hash = hasher(&x); - - match self.try_insert_no_grow(hash, x) { - Ok(bucket) => bucket, - Err(x) => { - // need to request more memory - - let bump_elements = self.capacity().max(16); - let bump_size = bump_elements * size_of::(); - *accounting = (*accounting).checked_add(bump_size).expect("overflow"); - - self.reserve(bump_elements, hasher); - - // still need to insert the element since first try failed - // Note: cannot use `.expect` here because `T` may not implement `Debug` - match self.try_insert_no_grow(hash, x) { - Ok(bucket) => bucket, - Err(_) => panic!("just grew the container"), - } - } - } - } -} - /// Extension trait for hash browns [`HashTable`] to account for allocations. pub trait HashTableAllocExt { /// Item type. 
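With `RawTableAllocExt` removed, capacity-based accounting is left to `HashTableAllocExt`, whose declaration continues below. A rough usage sketch, assuming its `insert_accounted` keeps the same hasher-and-accounting arguments as the removed `RawTable` version:

```rust
use datafusion_common::utils::proxy::HashTableAllocExt;
use hashbrown::hash_table::HashTable;

fn main() {
    let mut table: HashTable<u32> = HashTable::new();
    let mut allocated = 0usize;
    let hash_fn = |x: &u32| u64::from(*x);

    for i in 0..100u32 {
        // Grows `allocated` by the bytes newly reserved by the table,
        // based on capacity rather than element count.
        table.insert_accounted(i, hash_fn, &mut allocated);
    }

    println!("len = {}, accounted bytes = {}", table.len(), allocated);
}
```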
diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 67a73ac6f6693..bd88ed3b9ca1e 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -46,7 +46,7 @@ array_expressions = ["nested_expressions"] avro = ["datafusion-common/avro", "datafusion-datasource-avro"] backtrace = ["datafusion-common/backtrace"] compression = [ - "xz2", + "liblzma", "bzip2", "flate2", "zstd", @@ -79,7 +79,6 @@ parquet_encryption = [ "datafusion-common/parquet_encryption", "datafusion-datasource-parquet/parquet_encryption", ] -pyarrow = ["datafusion-common/pyarrow", "parquet"] regex_expressions = [ "datafusion-functions/regex_expressions", ] @@ -88,6 +87,7 @@ recursive_protection = [ "datafusion-expr/recursive_protection", "datafusion-optimizer/recursive_protection", "datafusion-physical-optimizer/recursive_protection", + "datafusion-physical-expr/recursive_protection", "datafusion-sql/recursive_protection", "sqlparser/recursive-protection", ] @@ -115,7 +115,7 @@ arrow = { workspace = true } arrow-schema = { workspace = true, features = ["canonical_extension_types"] } async-trait = { workspace = true } bytes = { workspace = true } -bzip2 = { version = "0.6.1", optional = true } +bzip2 = { workspace = true, optional = true } chrono = { workspace = true } datafusion-catalog = { workspace = true } datafusion-catalog-listing = { workspace = true } @@ -143,24 +143,23 @@ datafusion-physical-optimizer = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-session = { workspace = true } datafusion-sql = { workspace = true, optional = true } -flate2 = { version = "1.1.4", optional = true } +flate2 = { workspace = true, optional = true } futures = { workspace = true } itertools = { workspace = true } +liblzma = { workspace = true, optional = true } log = { workspace = true } object_store = { workspace = true } parking_lot = { workspace = true } parquet = { workspace = true, optional = true, default-features = true } rand = { workspace = true } regex = { workspace = true } -rstest = { workspace = true } serde = { version = "1.0", default-features = false, features = ["derive"], optional = true } sqlparser = { workspace = true, optional = true } tempfile = { workspace = true } tokio = { workspace = true } url = { workspace = true } -uuid = { version = "1.18", features = ["v4", "js"] } -xz2 = { version = "0.1", optional = true, features = ["static"] } -zstd = { version = "0.13", optional = true, default-features = false } +uuid = { version = "1.19", features = ["v4", "js"] } +zstd = { workspace = true, optional = true } [dev-dependencies] async-trait = { workspace = true } @@ -173,9 +172,9 @@ datafusion-macros = { workspace = true } datafusion-physical-optimizer = { workspace = true } doc-comment = { workspace = true } env_logger = { workspace = true } -glob = { version = "0.3.0" } +glob = { workspace = true } insta = { workspace = true } -paste = "^1.0" +paste = { workspace = true } rand = { workspace = true, features = ["small_rng"] } rand_distr = "0.5" regex = { workspace = true } @@ -240,6 +239,10 @@ harness = false name = "parquet_query_sql" required-features = ["parquet"] +[[bench]] +harness = false +name = "range_and_generate_series" + [[bench]] harness = false name = "sql_planner" @@ -272,3 +275,8 @@ name = "dataframe" [[bench]] harness = false name = "spm" + +[[bench]] +harness = false +name = "preserve_file_partitioning" +required-features = ["parquet"] diff --git a/datafusion/core/benches/aggregate_query_sql.rs 
b/datafusion/core/benches/aggregate_query_sql.rs index 87aeed49337eb..4aa667504e459 100644 --- a/datafusion/core/benches/aggregate_query_sql.rs +++ b/datafusion/core/benches/aggregate_query_sql.rs @@ -31,6 +31,7 @@ use std::hint::black_box; use std::sync::Arc; use tokio::runtime::Runtime; +#[expect(clippy::needless_pass_by_value)] fn query(ctx: Arc>, rt: &Runtime, sql: &str) { let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); black_box(rt.block_on(df.collect()).unwrap()); diff --git a/datafusion/core/benches/csv_load.rs b/datafusion/core/benches/csv_load.rs index de0f0d8250572..228457947fd5a 100644 --- a/datafusion/core/benches/csv_load.rs +++ b/datafusion/core/benches/csv_load.rs @@ -34,6 +34,7 @@ use std::time::Duration; use test_utils::AccessLogGenerator; use tokio::runtime::Runtime; +#[expect(clippy::needless_pass_by_value)] fn load_csv( ctx: Arc>, rt: &Runtime, diff --git a/datafusion/core/benches/data_utils/mod.rs b/datafusion/core/benches/data_utils/mod.rs index fffe2e2d17522..630bc056600b4 100644 --- a/datafusion/core/benches/data_utils/mod.rs +++ b/datafusion/core/benches/data_utils/mod.rs @@ -18,9 +18,9 @@ //! This module provides the in-memory table for more realistic benchmarking. use arrow::array::{ - builder::{Int64Builder, StringBuilder}, ArrayRef, Float32Array, Float64Array, RecordBatch, StringArray, StringViewBuilder, UInt64Array, + builder::{Int64Builder, StringBuilder}, }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::MemTable; @@ -139,6 +139,7 @@ fn create_record_batch( /// Create record batches of `partitions_len` partitions and `batch_size` for each batch, /// with a total number of `array_len` records +#[expect(clippy::needless_pass_by_value)] pub fn create_record_batches( schema: SchemaRef, array_len: usize, diff --git a/datafusion/core/benches/dataframe.rs b/datafusion/core/benches/dataframe.rs index 00fa85918347a..726187ab5e922 100644 --- a/datafusion/core/benches/dataframe.rs +++ b/datafusion/core/benches/dataframe.rs @@ -45,6 +45,7 @@ fn create_context(field_count: u32) -> datafusion_common::Result, rt: &Runtime) { black_box(rt.block_on(async { let mut data_frame = ctx.table("t").await.unwrap(); diff --git a/datafusion/core/benches/distinct_query_sql.rs b/datafusion/core/benches/distinct_query_sql.rs index d05e8b13b2af3..0e638e293d8cf 100644 --- a/datafusion/core/benches/distinct_query_sql.rs +++ b/datafusion/core/benches/distinct_query_sql.rs @@ -24,16 +24,17 @@ mod data_utils; use crate::criterion::Criterion; use data_utils::{create_table_provider, make_data}; use datafusion::execution::context::SessionContext; -use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::physical_plan::{ExecutionPlan, collect}; use datafusion::{datasource::MemTable, error::Result}; -use datafusion_execution::config::SessionConfig; use datafusion_execution::TaskContext; +use datafusion_execution::config::SessionConfig; use parking_lot::Mutex; use std::hint::black_box; use std::{sync::Arc, time::Duration}; use tokio::runtime::Runtime; +#[expect(clippy::needless_pass_by_value)] fn query(ctx: Arc>, rt: &Runtime, sql: &str) { let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); black_box(rt.block_on(df.collect()).unwrap()); @@ -124,6 +125,7 @@ async fn distinct_with_limit( Ok(()) } +#[expect(clippy::needless_pass_by_value)] fn run(rt: &Runtime, plan: Arc, ctx: Arc) { black_box(rt.block_on(distinct_with_limit(plan.clone(), ctx.clone()))).unwrap(); } diff --git a/datafusion/core/benches/filter_query_sql.rs 
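The `#[expect(clippy::needless_pass_by_value)]` annotations added across these benches all sit on the same idiom: a synchronous Criterion body driving async DataFusion calls on an explicit Tokio runtime. A pared-down version of that pattern (bench name and query are placeholders; the real helpers take `Arc<Mutex<SessionContext>>`, which is what triggers the lint):

```rust
use criterion::{criterion_group, criterion_main, Criterion};
use datafusion::prelude::SessionContext;
use std::hint::black_box;
use tokio::runtime::Runtime;

fn run_query(ctx: &SessionContext, rt: &Runtime, sql: &str) {
    // Criterion iteration bodies are synchronous, so each async step is
    // driven to completion with `block_on`.
    let df = rt.block_on(ctx.sql(sql)).unwrap();
    black_box(rt.block_on(df.collect()).unwrap());
}

fn bench(c: &mut Criterion) {
    let ctx = SessionContext::new();
    let rt = Runtime::new().unwrap();
    c.bench_function("select_one", |b| {
        b.iter(|| run_query(&ctx, &rt, "SELECT 1"))
    });
}

criterion_group!(benches, bench);
criterion_main!(benches);
```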
b/datafusion/core/benches/filter_query_sql.rs index 16905e0f96605..3b80518d32dcd 100644 --- a/datafusion/core/benches/filter_query_sql.rs +++ b/datafusion/core/benches/filter_query_sql.rs @@ -20,7 +20,7 @@ use arrow::{ datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::prelude::SessionContext; use datafusion::{datasource::MemTable, error::Result}; use futures::executor::block_on; diff --git a/datafusion/core/benches/map_query_sql.rs b/datafusion/core/benches/map_query_sql.rs index 09234546b2dfe..67904197bc257 100644 --- a/datafusion/core/benches/map_query_sql.rs +++ b/datafusion/core/benches/map_query_sql.rs @@ -15,14 +15,15 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashSet; use std::hint::black_box; use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array, RecordBatch}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use parking_lot::Mutex; -use rand::prelude::ThreadRng; use rand::Rng; +use rand::prelude::ThreadRng; use tokio::runtime::Runtime; use datafusion::prelude::SessionContext; @@ -33,11 +34,12 @@ use datafusion_functions_nested::map::map; mod data_utils; fn build_keys(rng: &mut ThreadRng) -> Vec { - let mut keys = vec![]; - for _ in 0..1000 { - keys.push(rng.random_range(0..9999).to_string()); + let mut keys = HashSet::with_capacity(1000); + while keys.len() < 1000 { + let key = rng.random_range(0..9999).to_string(); + keys.insert(key); } - keys + keys.into_iter().collect() } fn build_values(rng: &mut ThreadRng) -> Vec { diff --git a/datafusion/core/benches/math_query_sql.rs b/datafusion/core/benches/math_query_sql.rs index 76824850c114c..4d1d4abb6783c 100644 --- a/datafusion/core/benches/math_query_sql.rs +++ b/datafusion/core/benches/math_query_sql.rs @@ -36,6 +36,7 @@ use datafusion::datasource::MemTable; use datafusion::error::Result; use datafusion::execution::context::SessionContext; +#[expect(clippy::needless_pass_by_value)] fn query(ctx: Arc>, rt: &Runtime, sql: &str) { // execute the query let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); diff --git a/datafusion/core/benches/parquet_query_sql.rs b/datafusion/core/benches/parquet_query_sql.rs index e2b3810480130..e44524127bf18 100644 --- a/datafusion/core/benches/parquet_query_sql.rs +++ b/datafusion/core/benches/parquet_query_sql.rs @@ -23,14 +23,14 @@ use arrow::datatypes::{ SchemaRef, }; use arrow::record_batch::RecordBatch; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::instant::Instant; use futures::stream::StreamExt; use parquet::arrow::ArrowWriter; use parquet::file::properties::{WriterProperties, WriterVersion}; -use rand::distr::uniform::SampleUniform; use rand::distr::Alphanumeric; +use rand::distr::uniform::SampleUniform; use rand::prelude::*; use rand::rng; use std::fs::File; diff --git a/datafusion/core/benches/physical_plan.rs b/datafusion/core/benches/physical_plan.rs index e4838572f60fb..e6763b4761c2a 100644 --- a/datafusion/core/benches/physical_plan.rs +++ b/datafusion/core/benches/physical_plan.rs @@ -32,7 +32,7 @@ use tokio::runtime::Runtime; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use 
datafusion::physical_plan::{ collect, - expressions::{col, PhysicalSortExpr}, + expressions::{PhysicalSortExpr, col}, }; use datafusion::prelude::SessionContext; use datafusion_datasource::memory::MemorySourceConfig; @@ -40,6 +40,7 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering; // Initialize the operator using the provided record batches and the sort key // as inputs. All record batches must have the same schema. +#[expect(clippy::needless_pass_by_value)] fn sort_preserving_merge_operator( session_ctx: Arc, rt: &Runtime, diff --git a/datafusion/core/benches/preserve_file_partitioning.rs b/datafusion/core/benches/preserve_file_partitioning.rs new file mode 100644 index 0000000000000..17ebca52cd1d2 --- /dev/null +++ b/datafusion/core/benches/preserve_file_partitioning.rs @@ -0,0 +1,838 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmark for the `preserve_file_partitions` optimization. +//! +//! When enabled, this optimization declares Hive-partitioned tables as +//! `Hash([partition_col])` partitioned, allowing the query optimizer to +//! skip unnecessary repartitioning and sorting operations. +//! +//! When This Optimization Helps +//! - Window functions: PARTITION BY on partition column eliminates RepartitionExec and SortExec +//! - Aggregates with ORDER BY: GROUP BY partition column and ORDER BY eliminates the post-aggregate sort +//! +//! When This Optimization Does NOT Help +//! - GROUP BY non-partition columns: Required Hash distribution doesn't match declared partitioning +//! - When the number of distinct file partitioning groups < the number of CPUs available: Reduces +//! parallelism, which may outweigh the benefit of fewer shuffles +//! +//! Usage +//! - BENCH_SIZE=small|medium|large cargo bench -p datafusion --bench preserve_file_partitioning +//! - SAVE_PLANS=1 cargo bench ...
# Save query plans to files + +use arrow::array::{ArrayRef, Float64Array, StringArray, TimestampMillisecondArray}; +use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use arrow::record_batch::RecordBatch; +use arrow::util::pretty::pretty_format_batches; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext, col}; +use datafusion_expr::SortExpr; +use parquet::arrow::ArrowWriter; +use parquet::file::properties::WriterProperties; +use std::fs::{self, File}; +use std::io::Write; +use std::path::Path; +use std::sync::Arc; +use tempfile::TempDir; +use tokio::runtime::Runtime; + +#[derive(Debug, Clone, Copy)] +struct BenchConfig { + fact_partitions: usize, + rows_per_partition: usize, + target_partitions: usize, + measurement_time_secs: u64, +} + +impl BenchConfig { + fn small() -> Self { + Self { + fact_partitions: 10, + rows_per_partition: 1_000_000, + target_partitions: 10, + measurement_time_secs: 15, + } + } + + fn medium() -> Self { + Self { + fact_partitions: 30, + rows_per_partition: 3_000_000, + target_partitions: 30, + measurement_time_secs: 30, + } + } + + fn large() -> Self { + Self { + fact_partitions: 50, + rows_per_partition: 6_000_000, + target_partitions: 50, + measurement_time_secs: 90, + } + } + + fn from_env() -> Self { + match std::env::var("BENCH_SIZE").as_deref() { + Ok("small") | Ok("SMALL") => Self::small(), + Ok("medium") | Ok("MEDIUM") => Self::medium(), + Ok("large") | Ok("LARGE") => Self::large(), + _ => { + println!("Using SMALL dataset (set BENCH_SIZE=small|medium|large)"); + Self::small() + } + } + } + + fn total_rows(&self) -> usize { + self.fact_partitions * self.rows_per_partition + } + + fn high_cardinality(base: &Self) -> Self { + Self { + fact_partitions: (base.fact_partitions as f64 * 2.5) as usize, + rows_per_partition: base.rows_per_partition / 2, + target_partitions: base.target_partitions, + measurement_time_secs: base.measurement_time_secs, + } + } +} + +fn dkey_names(count: usize) -> Vec { + (0..count) + .map(|i| { + if i < 26 { + ((b'A' + i as u8) as char).to_string() + } else { + format!( + "{}{}", + (b'A' + ((i / 26) - 1) as u8) as char, + (b'A' + (i % 26) as u8) as char + ) + } + }) + .collect() +} + +/// Hive-partitioned fact table, sorted by timestamp within each partition. 
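Before the data generators (the `generate_fact_table` described by the doc line above continues right after this aside), it may help to see what the benchmark toggles from a user's point of view: a Hive-partitioned Parquet table registered with its partition column and file sort order while `preserve_file_partitions` is enabled. The config key is the one this benchmark sets; paths, partition counts and the query are placeholders:

```rust
use arrow::datatypes::DataType;
use datafusion::error::Result;
use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext, col};

#[tokio::main]
async fn main() -> Result<()> {
    // Enable the optimization and register a Hive-partitioned table whose
    // directory layout is fact/f_dkey=<value>/data.parquet.
    let config = SessionConfig::new()
        .with_target_partitions(8)
        .set_usize("datafusion.optimizer.preserve_file_partitions", 1);
    let ctx = SessionContext::new_with_config(config);

    let options = ParquetReadOptions {
        // The partition column taken from the directory names, plus the
        // per-file sort order, let the source report Hash([f_dkey]).
        table_partition_cols: vec![("f_dkey".to_string(), DataType::Utf8)],
        file_sort_order: vec![vec![col("f_dkey").sort(true, false)]],
        ..Default::default()
    };
    ctx.register_parquet("fact", "/path/to/fact", options).await?;

    ctx.sql("SELECT f_dkey, COUNT(*) FROM fact GROUP BY f_dkey ORDER BY f_dkey")
        .await?
        .show()
        .await?;
    Ok(())
}
```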
+fn generate_fact_table( + base_dir: &Path, + num_partitions: usize, + rows_per_partition: usize, +) { + let fact_dir = base_dir.join("fact"); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "timestamp", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new("value", DataType::Float64, false), + ])); + + let props = WriterProperties::builder() + .set_compression(parquet::basic::Compression::SNAPPY) + .build(); + + let dkeys = dkey_names(num_partitions); + + for dkey in &dkeys { + let part_dir = fact_dir.join(format!("f_dkey={dkey}")); + fs::create_dir_all(&part_dir).unwrap(); + let file_path = part_dir.join("data.parquet"); + let file = File::create(file_path).unwrap(); + + let mut writer = + ArrowWriter::try_new(file, schema.clone(), Some(props.clone())).unwrap(); + + let base_ts = 1672567200000i64; // 2023-01-01T09:00:00 + let timestamps: Vec = (0..rows_per_partition) + .map(|i| base_ts + (i as i64 * 10000)) + .collect(); + + let values: Vec = (0..rows_per_partition) + .map(|i| 50.0 + (i % 100) as f64 + ((i % 7) as f64 * 10.0)) + .collect(); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(TimestampMillisecondArray::from(timestamps)) as ArrayRef, + Arc::new(Float64Array::from(values)), + ], + ) + .unwrap(); + + writer.write(&batch).unwrap(); + writer.close().unwrap(); + } +} + +/// Single-file dimension table for CollectLeft joins. +fn generate_dimension_table(base_dir: &Path, num_partitions: usize) { + let dim_dir = base_dir.join("dimension"); + fs::create_dir_all(&dim_dir).unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("d_dkey", DataType::Utf8, false), + Field::new("env", DataType::Utf8, false), + Field::new("service", DataType::Utf8, false), + Field::new("host", DataType::Utf8, false), + ])); + + let props = WriterProperties::builder() + .set_compression(parquet::basic::Compression::SNAPPY) + .build(); + + let file_path = dim_dir.join("data.parquet"); + let file = File::create(file_path).unwrap(); + let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(props)).unwrap(); + + let dkeys = dkey_names(num_partitions); + let envs = ["dev", "prod", "staging", "test"]; + let services = ["log", "trace", "metric"]; + let hosts = ["ma", "vim", "nano", "emacs"]; + + let d_dkey_vals: Vec = dkeys.clone(); + let env_vals: Vec = dkeys + .iter() + .enumerate() + .map(|(i, _)| envs[i % envs.len()].to_string()) + .collect(); + let service_vals: Vec = dkeys + .iter() + .enumerate() + .map(|(i, _)| services[i % services.len()].to_string()) + .collect(); + let host_vals: Vec = dkeys + .iter() + .enumerate() + .map(|(i, _)| hosts[i % hosts.len()].to_string()) + .collect(); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(d_dkey_vals)) as ArrayRef, + Arc::new(StringArray::from(env_vals)), + Arc::new(StringArray::from(service_vals)), + Arc::new(StringArray::from(host_vals)), + ], + ) + .unwrap(); + + writer.write(&batch).unwrap(); + writer.close().unwrap(); +} + +struct BenchVariant { + name: &'static str, + preserve_file_partitions: usize, + prefer_existing_sort: bool, +} + +const BENCH_VARIANTS: [BenchVariant; 3] = [ + BenchVariant { + name: "with_optimization", + preserve_file_partitions: 1, + prefer_existing_sort: false, + }, + BenchVariant { + name: "prefer_existing_sort", + preserve_file_partitions: 0, + prefer_existing_sort: true, + }, + BenchVariant { + name: "without_optimization", + preserve_file_partitions: 0, + prefer_existing_sort: false, + }, +]; + +async fn 
save_plans( + output_file: &Path, + fact_path: &str, + dim_path: Option<&str>, + target_partitions: usize, + query: &str, + file_sort_order: Option>>, +) { + let mut file = File::create(output_file).unwrap(); + writeln!(file, "Query: {query}\n").unwrap(); + + for variant in &BENCH_VARIANTS { + let session_config = SessionConfig::new() + .with_target_partitions(target_partitions) + .set_usize( + "datafusion.optimizer.preserve_file_partitions", + variant.preserve_file_partitions, + ) + .set_bool( + "datafusion.optimizer.prefer_existing_sort", + variant.prefer_existing_sort, + ); + let ctx = SessionContext::new_with_config(session_config); + + let mut fact_options = ParquetReadOptions { + table_partition_cols: vec![("f_dkey".to_string(), DataType::Utf8)], + ..Default::default() + }; + if let Some(ref order) = file_sort_order { + fact_options.file_sort_order = order.clone(); + } + ctx.register_parquet("fact", fact_path, fact_options) + .await + .unwrap(); + + if let Some(dim) = dim_path { + let dim_schema = Arc::new(Schema::new(vec![ + Field::new("d_dkey", DataType::Utf8, false), + Field::new("env", DataType::Utf8, false), + Field::new("service", DataType::Utf8, false), + Field::new("host", DataType::Utf8, false), + ])); + let dim_options = ParquetReadOptions { + schema: Some(&dim_schema), + ..Default::default() + }; + ctx.register_parquet("dimension", dim, dim_options) + .await + .unwrap(); + } + + let df = ctx.sql(query).await.unwrap(); + let plan = df.explain(false, false).unwrap().collect().await.unwrap(); + writeln!(file, "=== {} ===", variant.name).unwrap(); + writeln!(file, "{}\n", pretty_format_batches(&plan).unwrap()).unwrap(); + } +} + +#[allow(clippy::too_many_arguments)] +fn run_benchmark( + c: &mut Criterion, + rt: &Runtime, + name: &str, + fact_path: &str, + dim_path: Option<&str>, + target_partitions: usize, + query: &str, + file_sort_order: &Option>>, +) { + if std::env::var("SAVE_PLANS").is_ok() { + let output_path = format!("{name}_plans.txt"); + rt.block_on(save_plans( + Path::new(&output_path), + fact_path, + dim_path, + target_partitions, + query, + file_sort_order.clone(), + )); + println!("Plans saved to {output_path}"); + } + + let mut group = c.benchmark_group(name); + + for variant in &BENCH_VARIANTS { + let fact_path_owned = fact_path.to_string(); + let dim_path_owned = dim_path.map(|s| s.to_string()); + let sort_order = file_sort_order.clone(); + let query_owned = query.to_string(); + let preserve_file_partitions = variant.preserve_file_partitions; + let prefer_existing_sort = variant.prefer_existing_sort; + + group.bench_function(variant.name, |b| { + b.to_async(rt).iter(|| { + let fact_path = fact_path_owned.clone(); + let dim_path = dim_path_owned.clone(); + let sort_order = sort_order.clone(); + let query = query_owned.clone(); + async move { + let session_config = SessionConfig::new() + .with_target_partitions(target_partitions) + .set_usize( + "datafusion.optimizer.preserve_file_partitions", + preserve_file_partitions, + ) + .set_bool( + "datafusion.optimizer.prefer_existing_sort", + prefer_existing_sort, + ); + let ctx = SessionContext::new_with_config(session_config); + + let mut fact_options = ParquetReadOptions { + table_partition_cols: vec![( + "f_dkey".to_string(), + DataType::Utf8, + )], + ..Default::default() + }; + if let Some(ref order) = sort_order { + fact_options.file_sort_order = order.clone(); + } + ctx.register_parquet("fact", &fact_path, fact_options) + .await + .unwrap(); + + if let Some(ref dim) = dim_path { + let dim_schema = 
Arc::new(Schema::new(vec![ + Field::new("d_dkey", DataType::Utf8, false), + Field::new("env", DataType::Utf8, false), + Field::new("service", DataType::Utf8, false), + Field::new("host", DataType::Utf8, false), + ])); + let dim_options = ParquetReadOptions { + schema: Some(&dim_schema), + ..Default::default() + }; + ctx.register_parquet("dimension", dim, dim_options) + .await + .unwrap(); + } + + let df = ctx.sql(&query).await.unwrap(); + df.collect().await.unwrap() + } + }) + }); + } + + group.finish(); +} + +/// Aggregate on high-cardinality partitions which eliminates repartition and sort. +/// +/// Query: SELECT f_dkey, COUNT(*), SUM(value) FROM fact GROUP BY f_dkey ORDER BY f_dkey +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ with_optimization │ +/// │ (preserve_file_partitions=enabled) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ Sort Preserved │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ No repartitioning needed │ +/// │ │ (SinglePartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ DataSourceExec │ partitioning=Hash([f_dkey]) │ +/// │ │ file_groups={N groups} │ │ +/// │ └───────────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ prefer_existing_sort │ +/// │ (preserve_file_partitions=disabled, prefer_existing_sort=true) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ Sort Preserved │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (FinalPartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ RepartitionExec │ Hash shuffle with order preservation │ +/// │ │ Hash([f_dkey], N) │ Uses k-way merge to maintain sort, has overhead │ +/// │ │ preserve_order=true │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (Partial) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ DataSourceExec │ partitioning=UnknownPartitioning │ +/// │ └───────────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ without_optimization │ +/// │ (preserve_file_partitions=disabled, prefer_existing_sort=false) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ FinalPartitioned │ +/// │ │ (FinalPartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ SortExec │ Must sort after shuffle │ +/// │ │ [f_dkey ASC] │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ RepartitionExec │ Hash shuffle destroys ordering │ +/// │ │ Hash([f_dkey], N) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ 
┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (Partial) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ DataSourceExec │ partitioning=UnknownPartitioning │ +/// │ └───────────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +fn preserve_order_bench( + c: &mut Criterion, + rt: &Runtime, + hc_fact_path: &str, + target_partitions: usize, +) { + let query = "SELECT f_dkey, COUNT(*) as cnt, SUM(value) as total \ + FROM fact \ + GROUP BY f_dkey \ + ORDER BY f_dkey"; + + let file_sort_order = vec![vec![col("f_dkey").sort(true, false)]]; + + run_benchmark( + c, + rt, + "preserve_order", + hc_fact_path, + None, + target_partitions, + query, + &Some(file_sort_order), + ); +} + +/// Join and aggregate on partition column which demonstrates propagation through join. +/// +/// Query: SELECT f.f_dkey, MAX(d.env), ... FROM fact f JOIN dimension d ON f.f_dkey = d.d_dkey +/// WHERE d.service = 'log' GROUP BY f.f_dkey ORDER BY f.f_dkey +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ with_optimization │ +/// │ (preserve_file_partitions=enabled) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ Hash partitioning propagates through join │ +/// │ │ (SinglePartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ HashJoinExec │ Hash partitioning preserved on probe side │ +/// │ │ (CollectLeft) │ │ +/// │ └──────────┬────────────────┘ │ +/// │ │ │ +/// │ ┌──────┴──────┐ │ +/// │ │ │ │ +/// │ ┌───▼───┐ ┌────▼────────────────┐ │ +/// │ │ Dim │ │ DataSourceExec │ partitioning=Hash([f_dkey]), output_ordering=[f_dkey] │ +/// │ │ Table │ │ (fact, N groups) │ │ +/// │ └───────┘ └─────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ prefer_existing_sort │ +/// │ (preserve_file_partitions=disabled, prefer_existing_sort=true) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (FinalPartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ RepartitionExec │ Hash shuffle with order preservation │ +/// │ │ preserve_order=true │ Uses k-way merge to maintain sort, has overhead │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (Partial) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ HashJoinExec │ │ +/// │ │ (CollectLeft) │ │ +/// │ └──────────┬────────────────┘ │ +/// │ │ │ +/// │ ┌──────┴──────┐ │ +/// │ │ │ │ +/// │ ┌───▼───┐ ┌────▼────────────────┐ │ +/// │ │ Dim │ │ DataSourceExec │ partitioning=UnknownPartitioning, output_ordering=[f_dkey] │ +/// │ │ Table │ │ (fact) │ │ +/// │ └───────┘ └─────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +/// +/// 
┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ without_optimization │ +/// │ (preserve_file_partitions=disabled, prefer_existing_sort=false) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (FinalPartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ SortExec │ Must sort after shuffle │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ RepartitionExec │ Hash shuffle destroys ordering │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (Partial) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ HashJoinExec │ │ +/// │ │ (CollectLeft) │ │ +/// │ └──────────┬────────────────┘ │ +/// │ │ │ +/// │ ┌──────┴──────┐ │ +/// │ │ │ │ +/// │ ┌───▼───┐ ┌────▼────────────────┐ │ +/// │ │ Dim │ │ DataSourceExec │ partitioning=UnknownPartitioning, output_ordering=[f_dkey] │ +/// │ │ Table │ │ (fact) │ │ +/// │ └───────┘ └─────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +fn preserve_order_join_bench( + c: &mut Criterion, + rt: &Runtime, + hc_fact_path: &str, + dim_path: &str, + target_partitions: usize, +) { + let query = "SELECT f.f_dkey, MAX(d.env), MAX(d.service), COUNT(*), SUM(f.value) \ + FROM fact f \ + INNER JOIN dimension d ON f.f_dkey = d.d_dkey \ + WHERE d.service = 'log' \ + GROUP BY f.f_dkey \ + ORDER BY f.f_dkey"; + + let file_sort_order = vec![vec![col("f_dkey").sort(true, false)]]; + + run_benchmark( + c, + rt, + "preserve_order_join", + hc_fact_path, + Some(dim_path), + target_partitions, + query, + &Some(file_sort_order), + ); +} + +/// Window function with LIMIT which demonstrates partition and sort elimination. 
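The plan diagrams in these doc comments (including the window-function case whose description continues below) can be reproduced for any query by collecting `EXPLAIN` output, which is what `save_plans` does under `SAVE_PLANS=1`. A minimal standalone version with a throwaway table:

```rust
use arrow::util::pretty::pretty_format_batches;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)")
        .await?
        .collect()
        .await?;

    // explain(verbose, analyze): the collected batches contain the logical
    // and physical plans, so scanning them for RepartitionExec or SortExec
    // shows whether a shuffle or sort was actually planned.
    let plan = ctx
        .sql("SELECT column1, COUNT(*) FROM t GROUP BY column1")
        .await?
        .explain(false, false)?
        .collect()
        .await?;
    println!("{}", pretty_format_batches(&plan)?);
    Ok(())
}
```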
+/// +/// Query: SELECT f_dkey, timestamp, value, +/// ROW_NUMBER() OVER (PARTITION BY f_dkey ORDER BY timestamp) as rn +/// FROM fact LIMIT 1000 +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ with_optimization │ +/// │ (preserve_file_partitions=enabled) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ GlobalLimitExec │ │ +/// │ │ (LIMIT 1000) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ BoundedWindowAggExec │ No repaartition needed │ +/// │ │ PARTITION BY f_dkey │ │ +/// │ │ ORDER BY timestamp │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ DataSourceExec │ partitioning=Hash([f_dkey]), output_ordering=[f_dkey, timestamp] │ +/// │ │ file_groups={N groups} │ │ +/// │ └───────────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ prefer_existing_sort │ +/// │ (preserve_file_partitions=disabled, prefer_existing_sort=true) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ GlobalLimitExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ BoundedWindowAggExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ RepartitionExec │ Hash shuffle with order preservation │ +/// │ │ Hash([f_dkey], N) │ Uses k-way merge to maintain sort, has overhead │ +/// │ │ preserve_order=true │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ DataSourceExec │ partitioning=UnknownPartitioning, output_ordering=[f_dkey, timestamp] │ +/// │ └───────────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ without_optimization │ +/// │ (preserve_file_partitions=disabled, prefer_existing_sort=false) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ GlobalLimitExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ BoundedWindowAggExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ SortExec │ Must sort after shuffle │ +/// │ │ [f_dkey, timestamp] │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ RepartitionExec │ Hash shuffle destroys ordering │ +/// │ │ Hash([f_dkey], N) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ DataSourceExec │ partitioning=UnknownPartitioning, output_ordering=[f_dkey, timestamp] │ +/// │ └───────────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +fn preserve_order_window_bench( + c: &mut Criterion, + rt: &Runtime, + fact_path: &str, + target_partitions: usize, +) { + let query = "SELECT f_dkey, timestamp, value, \ + ROW_NUMBER() OVER (PARTITION BY f_dkey ORDER BY timestamp) as rn \ + FROM fact \ + LIMIT 1000"; + + let file_sort_order = vec![vec![ + col("f_dkey").sort(true, false), + col("timestamp").sort(true, false), + ]]; + + 
run_benchmark( + c, + rt, + "preserve_order_window", + fact_path, + None, + target_partitions, + query, + &Some(file_sort_order), + ); +} + +fn benchmark_main(c: &mut Criterion) { + let config = BenchConfig::from_env(); + let hc_config = BenchConfig::high_cardinality(&config); + + println!("\n=== Preserve File Partitioning Benchmark ==="); + println!( + "Normal config: {} partitions × {} rows = {} total rows", + config.fact_partitions, + config.rows_per_partition, + config.total_rows() + ); + println!( + "High-cardinality config: {} partitions × {} rows = {} total rows", + hc_config.fact_partitions, + hc_config.rows_per_partition, + hc_config.total_rows() + ); + println!("Target partitions: {}\n", config.target_partitions); + + let tmp_dir = TempDir::new().unwrap(); + println!("Generating data..."); + + // High-cardinality fact table + generate_fact_table( + tmp_dir.path(), + hc_config.fact_partitions, + hc_config.rows_per_partition, + ); + let hc_fact_dir = tmp_dir.path().join("fact_hc"); + fs::rename(tmp_dir.path().join("fact"), &hc_fact_dir).unwrap(); + let hc_fact_path = hc_fact_dir.to_str().unwrap().to_string(); + + // Normal fact table + generate_fact_table( + tmp_dir.path(), + config.fact_partitions, + config.rows_per_partition, + ); + let fact_path = tmp_dir.path().join("fact").to_str().unwrap().to_string(); + + // Dimension table (for join) + generate_dimension_table(tmp_dir.path(), hc_config.fact_partitions); + let dim_path = tmp_dir + .path() + .join("dimension") + .to_str() + .unwrap() + .to_string(); + + println!("Done.\n"); + + let rt = Runtime::new().unwrap(); + + preserve_order_bench(c, &rt, &hc_fact_path, hc_config.target_partitions); + preserve_order_join_bench( + c, + &rt, + &hc_fact_path, + &dim_path, + hc_config.target_partitions, + ); + preserve_order_window_bench(c, &rt, &fact_path, config.target_partitions); +} + +criterion_group! 
{ + name = benches; + config = { + let config = BenchConfig::from_env(); + Criterion::default() + .measurement_time(std::time::Duration::from_secs(config.measurement_time_secs)) + .sample_size(10) + }; + targets = benchmark_main +} +criterion_main!(benches); diff --git a/datafusion/core/benches/push_down_filter.rs b/datafusion/core/benches/push_down_filter.rs index 139fb12c30947..3c2199c708de6 100644 --- a/datafusion/core/benches/push_down_filter.rs +++ b/datafusion/core/benches/push_down_filter.rs @@ -18,16 +18,16 @@ use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Field, Schema}; use bytes::{BufMut, BytesMut}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::config::ConfigOptions; use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; use datafusion_physical_plan::ExecutionPlan; +use object_store::ObjectStore; use object_store::memory::InMemory; use object_store::path::Path; -use object_store::ObjectStore; use parquet::arrow::ArrowWriter; use std::sync::Arc; diff --git a/datafusion/core/benches/range_and_generate_series.rs b/datafusion/core/benches/range_and_generate_series.rs new file mode 100644 index 0000000000000..2b1463a21062a --- /dev/null +++ b/datafusion/core/benches/range_and_generate_series.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#[macro_use] +extern crate criterion; +extern crate datafusion; + +mod data_utils; + +use crate::criterion::Criterion; +use datafusion::execution::context::SessionContext; +use parking_lot::Mutex; +use std::hint::black_box; +use std::sync::Arc; +use tokio::runtime::Runtime; + +#[expect(clippy::needless_pass_by_value)] +fn query(ctx: Arc>, rt: &Runtime, sql: &str) { + let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); + black_box(rt.block_on(df.collect()).unwrap()); +} + +fn create_context() -> Arc> { + let ctx = SessionContext::new(); + Arc::new(Mutex::new(ctx)) +} + +fn criterion_benchmark(c: &mut Criterion) { + let ctx = create_context(); + let rt = Runtime::new().unwrap(); + + c.bench_function("range(1000000)", |b| { + b.iter(|| query(ctx.clone(), &rt, "SELECT value from range(1000000)")) + }); + + c.bench_function("generate_series(1000000)", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT value from generate_series(1000000)", + ) + }) + }); + + c.bench_function("range(0, 1000000, 5)", |b| { + b.iter(|| query(ctx.clone(), &rt, "SELECT value from range(0, 1000000, 5)")) + }); + + c.bench_function("generate_series(0, 1000000, 5)", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT value from generate_series(0, 1000000, 5)", + ) + }) + }); + + c.bench_function("range(1000000, 0, -5)", |b| { + b.iter(|| query(ctx.clone(), &rt, "SELECT value from range(1000000, 0, -5)")) + }); + + c.bench_function("generate_series(1000000, 0, -5)", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT value from generate_series(1000000, 0, -5)", + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/benches/scalar.rs b/datafusion/core/benches/scalar.rs index 540f7212e96e9..d06ed3f28b743 100644 --- a/datafusion/core/benches/scalar.rs +++ b/datafusion/core/benches/scalar.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
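As a quick note on what the new `range_and_generate_series` benchmark above compares: as far as I'm aware the two table functions differ only in bound handling (`range` excludes the upper bound, `generate_series` includes it), so each benched pair scans essentially the same number of rows. An illustrative check:

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    let n_range = ctx
        .sql("SELECT value FROM range(0, 1000000, 5)")
        .await?
        .count()
        .await?;
    let n_series = ctx
        .sql("SELECT value FROM generate_series(0, 1000000, 5)")
        .await?
        .count()
        .await?;

    // Expect the counts to differ by at most one row (the upper bound).
    println!("range: {n_range} rows, generate_series: {n_series} rows");
    Ok(())
}
```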
-use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::scalar::ScalarValue; fn criterion_benchmark(c: &mut Criterion) { diff --git a/datafusion/core/benches/sort.rs b/datafusion/core/benches/sort.rs index 276151e253f7e..4ba57a1530e81 100644 --- a/datafusion/core/benches/sort.rs +++ b/datafusion/core/benches/sort.rs @@ -78,18 +78,18 @@ use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::{ execution::context::TaskContext, physical_plan::{ + ExecutionPlan, ExecutionPlanProperties, coalesce_partitions::CoalescePartitionsExec, - sorts::sort_preserving_merge::SortPreservingMergeExec, ExecutionPlan, - ExecutionPlanProperties, + sorts::sort_preserving_merge::SortPreservingMergeExec, }, prelude::SessionContext, }; use datafusion_datasource::memory::MemorySourceConfig; -use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; +use datafusion_physical_expr::{PhysicalSortExpr, expressions::col}; use datafusion_physical_expr_common::sort_expr::LexOrdering; /// Benchmarks for SortPreservingMerge stream -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use futures::StreamExt; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -355,14 +355,14 @@ fn utf8_high_cardinality_streams(sorted: bool) -> PartitionedBatches { /// Create a batch of (utf8_low, utf8_low, utf8_high) fn utf8_tuple_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); + let mut data_gen = DataGenerator::new(); // need to sort by the combined key, so combine them together - let mut tuples: Vec<_> = gen + let mut tuples: Vec<_> = data_gen .utf8_low_cardinality_values() .into_iter() - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.utf8_high_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_high_cardinality_values()) .collect(); if sorted { @@ -388,14 +388,14 @@ fn utf8_tuple_streams(sorted: bool) -> PartitionedBatches { /// Create a batch of (utf8_view_low, utf8_view_low, utf8_view_high) fn utf8_view_tuple_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); + let mut data_gen = DataGenerator::new(); // need to sort by the combined key, so combine them together - let mut tuples: Vec<_> = gen + let mut tuples: Vec<_> = data_gen .utf8_low_cardinality_values() .into_iter() - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.utf8_high_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_high_cardinality_values()) .collect(); if sorted { @@ -421,15 +421,15 @@ fn utf8_view_tuple_streams(sorted: bool) -> PartitionedBatches { /// Create a batch of (f64, utf8_low, utf8_low, i64) fn mixed_tuple_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); + let mut data_gen = DataGenerator::new(); // need to sort by the combined key, so combine them together - let mut tuples: Vec<_> = gen + let mut tuples: Vec<_> = data_gen .i64_values() .into_iter() - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.i64_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.i64_values()) .collect(); if sorted { @@ -459,15 +459,15 @@ fn mixed_tuple_streams(sorted: bool) -> PartitionedBatches { /// Create a batch of (f64, utf8_view_low, utf8_view_low, i64) fn mixed_tuple_with_utf8_view_streams(sorted: bool) -> 
PartitionedBatches { - let mut gen = DataGenerator::new(); + let mut data_gen = DataGenerator::new(); // need to sort by the combined key, so combine them together - let mut tuples: Vec<_> = gen + let mut tuples: Vec<_> = data_gen .i64_values() .into_iter() - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.i64_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.i64_values()) .collect(); if sorted { @@ -497,8 +497,8 @@ fn mixed_tuple_with_utf8_view_streams(sorted: bool) -> PartitionedBatches { /// Create a batch of (utf8_dict) fn dictionary_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); - let mut values = gen.utf8_low_cardinality_values(); + let mut data_gen = DataGenerator::new(); + let mut values = data_gen.utf8_low_cardinality_values(); if sorted { values.sort_unstable(); } @@ -512,12 +512,12 @@ fn dictionary_streams(sorted: bool) -> PartitionedBatches { /// Create a batch of (utf8_dict, utf8_dict, utf8_dict) fn dictionary_tuple_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); - let mut tuples: Vec<_> = gen + let mut data_gen = DataGenerator::new(); + let mut tuples: Vec<_> = data_gen .utf8_low_cardinality_values() .into_iter() - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) .collect(); if sorted { @@ -543,13 +543,13 @@ fn dictionary_tuple_streams(sorted: bool) -> PartitionedBatches { /// Create a batch of (utf8_dict, utf8_dict, utf8_dict, i64) fn mixed_dictionary_tuple_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); - let mut tuples: Vec<_> = gen + let mut data_gen = DataGenerator::new(); + let mut tuples: Vec<_> = data_gen .utf8_low_cardinality_values() .into_iter() - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.i64_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.i64_values()) .collect(); if sorted { diff --git a/datafusion/core/benches/sort_limit_query_sql.rs b/datafusion/core/benches/sort_limit_query_sql.rs index e535a018161f1..c18070fb7725e 100644 --- a/datafusion/core/benches/sort_limit_query_sql.rs +++ b/datafusion/core/benches/sort_limit_query_sql.rs @@ -37,6 +37,7 @@ use datafusion::execution::context::SessionContext; use tokio::runtime::Runtime; +#[expect(clippy::needless_pass_by_value)] fn query(ctx: Arc>, rt: &Runtime, sql: &str) { // execute the query let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); @@ -97,8 +98,7 @@ fn create_context() -> Arc> { ctx_holder.lock().push(Arc::new(Mutex::new(ctx))) }); - let ctx = ctx_holder.lock().first().unwrap().clone(); - ctx + ctx_holder.lock().first().unwrap().clone() } fn criterion_benchmark(c: &mut Criterion) { diff --git a/datafusion/core/benches/spm.rs b/datafusion/core/benches/spm.rs index ecc3f908d4b15..9db1306d2bd19 100644 --- a/datafusion/core/benches/spm.rs +++ b/datafusion/core/benches/spm.rs @@ -20,13 +20,13 @@ use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_expr::expressions::col; use 
datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion_physical_plan::{collect, ExecutionPlan}; +use datafusion_physical_plan::{ExecutionPlan, collect}; use criterion::async_executor::FuturesExecutor; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_datasource::memory::MemorySourceConfig; fn generate_spm_for_round_robin_tie_breaker( diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 6266a7184cf51..7cce7e0bd7db7 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -23,19 +23,23 @@ extern crate datafusion; mod data_utils; use crate::criterion::Criterion; +use arrow::array::PrimitiveArray; use arrow::array::{ArrayRef, RecordBatch}; +use arrow::datatypes::ArrowNativeTypeOp; +use arrow::datatypes::ArrowPrimitiveType; use arrow::datatypes::{DataType, Field, Fields, Schema}; use criterion::Bencher; use datafusion::datasource::MemTable; use datafusion::execution::context::SessionContext; -use datafusion_common::{config::Dialect, ScalarValue}; +use datafusion_common::{ScalarValue, config::Dialect}; use datafusion_expr::col; +use rand_distr::num_traits::NumCast; use std::hint::black_box; use std::path::PathBuf; use std::sync::Arc; +use test_utils::TableDef; use test_utils::tpcds::tpcds_schemas; use test_utils::tpch::tpch_schemas; -use test_utils::TableDef; use tokio::runtime::Runtime; const BENCHMARKS_PATH_1: &str = "../../benchmarks/"; @@ -89,6 +93,7 @@ fn create_context() -> SessionContext { /// Register the table definitions as a MemTable with the context and return the /// context +#[expect(clippy::needless_pass_by_value)] fn register_defs(ctx: SessionContext, defs: Vec) -> SessionContext { defs.iter().for_each(|TableDef { name, schema }| { ctx.register_table( @@ -155,18 +160,30 @@ fn benchmark_with_param_values_many_columns( /// 0,100...9900 /// 0,200...19800 /// 0,300...29700 -fn register_union_order_table(ctx: &SessionContext, num_columns: usize, num_rows: usize) { - // ("c0", [0, 0, ...]) - // ("c1": [100, 200, ...]) - // etc - let iter = (0..num_columns).map(|i| i as u64).map(|i| { - let array: ArrayRef = Arc::new(arrow::array::UInt64Array::from_iter_values( - (0..num_rows) - .map(|j| j as u64 * 100 + i) - .collect::>(), - )); +fn register_union_order_table_generic( + ctx: &SessionContext, + num_columns: usize, + num_rows: usize, +) where + T: ArrowPrimitiveType, + T::Native: ArrowNativeTypeOp + NumCast, +{ + let iter = (0..num_columns).map(|i| { + let array_data: Vec = (0..num_rows) + .map(|j| { + let value = (j as u64) * 100 + (i as u64); + ::from(value).unwrap_or_else(|| { + panic!("Failed to cast numeric value to Native type") + }) + }) + .collect(); + + // Use PrimitiveArray which is generic over the ArrowPrimitiveType T + let array: ArrayRef = Arc::new(PrimitiveArray::::from_iter_values(array_data)); + (format!("c{i}"), array) }); + let batch = RecordBatch::try_from_iter(iter).unwrap(); let schema = batch.schema(); let partitions = vec![vec![batch]]; @@ -183,7 +200,6 @@ fn register_union_order_table(ctx: &SessionContext, num_columns: usize, num_rows ctx.register_table("t", Arc::new(table)).unwrap(); } - /// return a query like /// ```sql /// select c1, 2 as c2, ... 
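The generic type parameters of `register_union_order_table_generic` appear to have been stripped in rendering here (the `fn` name is followed directly by `(`, and the later call sites read `::(`); the intent is a helper that is generic over the Arrow primitive type and is instantiated once for a signed and once for an unsigned type. A compact sketch of that pattern with hypothetical names:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, PrimitiveArray};
use arrow::datatypes::{ArrowPrimitiveType, Int64Type, UInt64Type};
use rand_distr::num_traits::NumCast;

/// Builds one column of `num_rows` values for any Arrow primitive type whose
/// native type can be produced from a u64, mirroring the generic helper above.
fn make_column<T>(num_rows: usize, column_index: u64) -> ArrayRef
where
    T: ArrowPrimitiveType,
    T::Native: NumCast,
{
    let values: Vec<T::Native> = (0..num_rows as u64)
        .map(|j| {
            <T::Native as NumCast>::from(j * 100 + column_index)
                .expect("value fits in the target native type")
        })
        .collect();
    Arc::new(PrimitiveArray::<T>::from_iter_values(values))
}

fn main() {
    // The same helper serves both the signed and unsigned benchmarks.
    let signed = make_column::<Int64Type>(5, 0);
    let unsigned = make_column::<UInt64Type>(5, 1);
    println!("{signed:?}");
    println!("{unsigned:?}");
}
```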
n as cn from t ORDER BY c1 @@ -226,8 +242,10 @@ fn criterion_benchmark(c: &mut Criterion) { if !PathBuf::from(format!("{BENCHMARKS_PATH_1}{CLICKBENCH_DATA_PATH}")).exists() && !PathBuf::from(format!("{BENCHMARKS_PATH_2}{CLICKBENCH_DATA_PATH}")).exists() { - panic!("benchmarks/data/hits_partitioned/ could not be loaded. Please run \ - 'benchmarks/bench.sh data clickbench_partitioned' prior to running this benchmark") + panic!( + "benchmarks/data/hits_partitioned/ could not be loaded. Please run \ + 'benchmarks/bench.sh data clickbench_partitioned' prior to running this benchmark" + ) } let ctx = create_context(); @@ -403,13 +421,40 @@ fn criterion_benchmark(c: &mut Criterion) { // -- Sorted Queries -- // 100, 200 && 300 is taking too long - https://github.com/apache/datafusion/issues/18366 + // Logical plans for the Int64 and UInt64 datatypes differ: the UInt64 plan's Unions are wrapped + // in Projections, so the EliminateNestedUnion OptimizerRule is not applied, leading to significantly + // longer execution time. + // https://github.com/apache/datafusion/issues/17261 + for column_count in [10, 50 /* 100, 200, 300 */] { - register_union_order_table(&ctx, column_count, 1000); + register_union_order_table_generic::( + &ctx, + column_count, + 1000, + ); // this query has many expressions in its sort order so stresses // order equivalence validation c.bench_function( - &format!("physical_sorted_union_order_by_{column_count}"), + &format!("physical_sorted_union_order_by_{column_count}_int64"), + |b| { + // SELECT ... UNION ALL ... + let query = union_orderby_query(column_count); + b.iter(|| physical_plan(&ctx, &rt, &query)) + }, + ); + + let _ = ctx.deregister_table("t"); + } + + for column_count in [10, 50 /* 100, 200, 300 */] { + register_union_order_table_generic::( + &ctx, + column_count, + 1000, + ); + c.bench_function( + &format!("physical_sorted_union_order_by_{column_count}_uint64"), + |b| { + // SELECT ... UNION ALL ...
let query = union_orderby_query(column_count); @@ -477,9 +522,6 @@ fn criterion_benchmark(c: &mut Criterion) { }; let raw_tpcds_sql_queries = (1..100) - // skip query 75 until it is fixed - // https://github.com/apache/datafusion/issues/17801 - .filter(|q| *q != 75) .map(|q| std::fs::read_to_string(format!("{tests_path}tpc-ds/{q}.sql")).unwrap()) .collect::>(); diff --git a/datafusion/core/benches/sql_planner_extended.rs b/datafusion/core/benches/sql_planner_extended.rs index aff7cb4d101d5..adaf3e5911e9b 100644 --- a/datafusion/core/benches/sql_planner_extended.rs +++ b/datafusion/core/benches/sql_planner_extended.rs @@ -18,7 +18,7 @@ use arrow::array::{ArrayRef, RecordBatch}; use arrow_schema::DataType; use arrow_schema::TimeUnit::Nanosecond; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::prelude::{DataFrame, SessionContext}; use datafusion_catalog::MemTable; use datafusion_common::ScalarValue; diff --git a/datafusion/core/benches/sql_query_with_io.rs b/datafusion/core/benches/sql_query_with_io.rs index 58797dfed6b67..0c188f7ba1047 100644 --- a/datafusion/core/benches/sql_query_with_io.rs +++ b/datafusion/core/benches/sql_query_with_io.rs @@ -20,7 +20,7 @@ use std::{fmt::Write, sync::Arc, time::Duration}; use arrow::array::{Int64Builder, RecordBatch, UInt64Builder}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use bytes::Bytes; -use criterion::{criterion_group, criterion_main, Criterion, SamplingMode}; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; use datafusion::{ datasource::{ file_format::parquet::ParquetFormat, @@ -31,13 +31,13 @@ use datafusion::{ use datafusion_execution::runtime_env::RuntimeEnv; use itertools::Itertools; use object_store::{ + ObjectStore, memory::InMemory, path::Path, throttle::{ThrottleConfig, ThrottledStore}, - ObjectStore, }; use parquet::arrow::ArrowWriter; -use rand::{rngs::StdRng, Rng, SeedableRng}; +use rand::{Rng, SeedableRng, rngs::StdRng}; use tokio::runtime::Runtime; use url::Url; diff --git a/datafusion/core/benches/struct_query_sql.rs b/datafusion/core/benches/struct_query_sql.rs index 5c7b427310827..96434fc379ea6 100644 --- a/datafusion/core/benches/struct_query_sql.rs +++ b/datafusion/core/benches/struct_query_sql.rs @@ -20,7 +20,7 @@ use arrow::{ datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::prelude::SessionContext; use datafusion::{datasource::MemTable, error::Result}; use futures::executor::block_on; diff --git a/datafusion/core/benches/topk_aggregate.rs b/datafusion/core/benches/topk_aggregate.rs index 9a5fb7163be5c..a4ae479de4d27 100644 --- a/datafusion/core/benches/topk_aggregate.rs +++ b/datafusion/core/benches/topk_aggregate.rs @@ -18,13 +18,13 @@ mod data_utils; use arrow::util::pretty::pretty_format_batches; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use data_utils::make_data; -use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; +use datafusion::physical_plan::{ExecutionPlan, collect, displayable}; use datafusion::prelude::SessionContext; use datafusion::{datasource::MemTable, error::Result}; -use datafusion_execution::config::SessionConfig; use datafusion_execution::TaskContext; +use datafusion_execution::config::SessionConfig; use 
std::hint::black_box; use std::sync::Arc; use tokio::runtime::Runtime; @@ -46,7 +46,9 @@ async fn create_context( opts.optimizer.enable_topk_aggregation = use_topk; let ctx = SessionContext::new_with_config(cfg); let _ = ctx.register_table("traces", mem_table)?; - let sql = format!("select trace_id, max(timestamp_ms) from traces group by trace_id order by max(timestamp_ms) desc limit {limit};"); + let sql = format!( + "select max(timestamp_ms) from traces group by trace_id order by max(timestamp_ms) desc limit {limit};" + ); let df = ctx.sql(sql.as_str()).await?; let physical_plan = df.create_physical_plan().await?; let actual_phys_plan = displayable(physical_plan.as_ref()).indent(true).to_string(); @@ -58,6 +60,7 @@ async fn create_context( Ok((physical_plan, ctx.task_ctx())) } +#[expect(clippy::needless_pass_by_value)] fn run(rt: &Runtime, plan: Arc, ctx: Arc, asc: bool) { black_box(rt.block_on(async { aggregate(plan.clone(), ctx.clone(), asc).await })) .unwrap(); @@ -75,20 +78,20 @@ async fn aggregate( let actual = format!("{}", pretty_format_batches(&batches)?).to_lowercase(); let expected_asc = r#" -+----------------------------------+--------------------------+ -| trace_id | max(traces.timestamp_ms) | -+----------------------------------+--------------------------+ -| 5868861a23ed31355efc5200eb80fe74 | 16909009999999 | -| 4040e64656804c3d77320d7a0e7eb1f0 | 16909009999998 | -| 02801bbe533190a9f8713d75222f445d | 16909009999997 | -| 9e31b3b5a620de32b68fefa5aeea57f1 | 16909009999996 | -| 2d88a860e9bd1cfaa632d8e7caeaa934 | 16909009999995 | -| a47edcef8364ab6f191dd9103e51c171 | 16909009999994 | -| 36a3fa2ccfbf8e00337f0b1254384db6 | 16909009999993 | -| 0756be84f57369012e10de18b57d8a2f | 16909009999992 | -| d4d6bf9845fa5897710e3a8db81d5907 | 16909009999991 | -| 3c2cc1abe728a66b61e14880b53482a0 | 16909009999990 | -+----------------------------------+--------------------------+ ++--------------------------+ +| max(traces.timestamp_ms) | ++--------------------------+ +| 16909009999999 | +| 16909009999998 | +| 16909009999997 | +| 16909009999996 | +| 16909009999995 | +| 16909009999994 | +| 16909009999993 | +| 16909009999992 | +| 16909009999991 | +| 16909009999990 | ++--------------------------+ "# .trim(); if asc { diff --git a/datafusion/core/benches/window_query_sql.rs b/datafusion/core/benches/window_query_sql.rs index 6d83959f7eb3c..e4643567a0f0c 100644 --- a/datafusion/core/benches/window_query_sql.rs +++ b/datafusion/core/benches/window_query_sql.rs @@ -31,6 +31,7 @@ use std::hint::black_box; use std::sync::Arc; use tokio::runtime::Runtime; +#[expect(clippy::needless_pass_by_value)] fn query(ctx: Arc>, rt: &Runtime, sql: &str) { let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); black_box(rt.block_on(df.collect()).unwrap()); diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index 63387c023b11a..74a10bf079e61 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -16,10 +16,10 @@ // under the License. 
use datafusion::execution::SessionStateDefaults; -use datafusion_common::{not_impl_err, HashSet, Result}; +use datafusion_common::{HashSet, Result, not_impl_err}; use datafusion_expr::{ - aggregate_doc_sections, scalar_doc_sections, window_doc_sections, AggregateUDF, - DocSection, Documentation, ScalarUDF, WindowUDF, + AggregateUDF, DocSection, Documentation, ScalarUDF, WindowUDF, + aggregate_doc_sections, scalar_doc_sections, window_doc_sections, }; use itertools::Itertools; use std::env::args; @@ -108,6 +108,7 @@ fn save_doc_code_text(documentation: &Documentation, name: &str) { file.write_all(attr_text.as_bytes()).unwrap(); } +#[expect(clippy::needless_pass_by_value)] fn print_docs( providers: Vec>, doc_sections: Vec, @@ -254,7 +255,9 @@ fn print_docs( for f in &providers_with_no_docs { eprintln!(" - {f}"); } - not_impl_err!("Some functions do not have documentation. Please implement `documentation` for: {providers_with_no_docs:?}") + not_impl_err!( + "Some functions do not have documentation. Please implement `documentation` for: {providers_with_no_docs:?}" + ) } else { Ok(docs) } diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 98804e424b407..0d060db3bf147 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -26,19 +26,19 @@ use crate::datasource::file_format::csv::CsvFormatFactory; use crate::datasource::file_format::format_as_file_type; use crate::datasource::file_format::json::JsonFormatFactory; use crate::datasource::{ - provider_as_source, DefaultTableSource, MemTable, TableProvider, + DefaultTableSource, MemTable, TableProvider, provider_as_source, }; use crate::error::Result; -use crate::execution::context::{SessionState, TaskContext}; use crate::execution::FunctionRegistry; +use crate::execution::context::{SessionState, TaskContext}; use crate::logical_expr::utils::find_window_exprs; use crate::logical_expr::{ - col, ident, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, - LogicalPlanBuilderOptions, Partitioning, TableType, + Expr, JoinType, LogicalPlan, LogicalPlanBuilder, LogicalPlanBuilderOptions, + Partitioning, TableType, col, ident, }; use crate::physical_plan::{ - collect, collect_partitioned, execute_stream, execute_stream_partitioned, - ExecutionPlan, SendableRecordBatchStream, + ExecutionPlan, SendableRecordBatchStream, collect, collect_partitioned, + execute_stream, execute_stream_partitioned, }; use crate::prelude::SessionContext; use std::any::Any; @@ -49,20 +49,20 @@ use std::sync::Arc; use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use arrow::compute::{cast, concat}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow_schema::FieldRef; use datafusion_common::config::{CsvOptions, JsonOptions}; use datafusion_common::{ - exec_err, internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, DataFusionError, ParamValues, ScalarValue, SchemaError, - TableReference, UnnestOptions, + TableReference, UnnestOptions, exec_err, internal_datafusion_err, not_impl_err, + plan_datafusion_err, plan_err, unqualified_field_not_found, }; use datafusion_expr::select_expr::SelectExpr; use datafusion_expr::{ - case, + ExplainOption, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, case, dml::InsertOp, expr::{Alias, ScalarFunction}, is_null, lit, utils::COUNT_STAR_EXPANSION, - ExplainOption, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, }; use datafusion_functions::core::coalesce; use 
datafusion_functions_aggregate::expr_fn::{ @@ -310,11 +310,20 @@ impl DataFrame { pub fn select_columns(self, columns: &[&str]) -> Result { let fields = columns .iter() - .flat_map(|name| { - self.plan + .map(|name| { + let fields = self + .plan .schema() - .qualified_fields_with_unqualified_name(name) + .qualified_fields_with_unqualified_name(name); + if fields.is_empty() { + Err(unqualified_field_not_found(name, self.plan.schema())) + } else { + Ok(fields) + } }) + .collect::, _>>()? + .into_iter() + .flatten() .collect::>(); let expr: Vec = fields .into_iter() @@ -1655,7 +1664,7 @@ impl DataFrame { pub fn into_view(self) -> Arc { Arc::new(DataFrameTableProvider { plan: self.plan, - table_type: TableType::Temporary, + table_type: TableType::View, }) } @@ -2232,7 +2241,7 @@ impl DataFrame { .schema() .iter() .map(|(qualifier, field)| { - if qualifier.eq(&qualifier_rename) && field.as_ref() == field_rename { + if qualifier.eq(&qualifier_rename) && field == field_rename { ( col(Column::from((qualifier, field))) .alias_qualified(qualifier.cloned(), new_name), @@ -2321,6 +2330,10 @@ impl DataFrame { /// Cache DataFrame as a memory table. /// + /// Default behavior could be changed using + /// a [`crate::execution::session_state::CacheFactory`] + /// configured via [`SessionState`]. + /// /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; @@ -2335,14 +2348,20 @@ impl DataFrame { /// # } /// ``` pub async fn cache(self) -> Result { - let context = SessionContext::new_with_state((*self.session_state).clone()); - // The schema is consistent with the output - let plan = self.clone().create_physical_plan().await?; - let schema = plan.schema(); - let task_ctx = Arc::new(self.task_ctx()); - let partitions = collect_partitioned(plan, task_ctx).await?; - let mem_table = MemTable::try_new(schema, partitions)?; - context.read_table(Arc::new(mem_table)) + if let Some(cache_factory) = self.session_state.cache_factory() { + let new_plan = + cache_factory.create(self.plan, self.session_state.as_ref())?; + Ok(Self::new(*self.session_state, new_plan)) + } else { + let context = SessionContext::new_with_state((*self.session_state).clone()); + // The schema is consistent with the output + let plan = self.clone().create_physical_plan().await?; + let schema = plan.schema(); + let task_ctx = Arc::new(self.task_ctx()); + let partitions = collect_partitioned(plan, task_ctx).await?; + let mem_table = MemTable::try_new(schema, partitions)?; + context.read_table(Arc::new(mem_table)) + } } /// Apply an alias to the DataFrame. @@ -2383,6 +2402,7 @@ impl DataFrame { /// # Ok(()) /// # } /// ``` + #[expect(clippy::needless_pass_by_value)] pub fn fill_null( &self, value: ScalarValue, @@ -2393,7 +2413,7 @@ impl DataFrame { .schema() .fields() .iter() - .map(|f| f.as_ref().clone()) + .map(Arc::clone) .collect() } else { self.find_columns(&columns)? 
@@ -2430,7 +2450,7 @@ impl DataFrame { } // Helper to find columns from names - fn find_columns(&self, names: &[String]) -> Result> { + fn find_columns(&self, names: &[String]) -> Result> { let schema = self.logical_plan().schema(); names .iter() diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index cb8a6cf29541b..6edf628e2d6d6 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -150,7 +150,7 @@ mod tests { let plan = df.explain(false, false)?.collect().await?; // Filters all the way to Parquet let formatted = pretty::pretty_format_batches(&plan)?.to_string(); - assert!(formatted.contains("FilterExec: id@0 = 1")); + assert!(formatted.contains("FilterExec: id@0 = 1"), "{formatted}"); Ok(()) } diff --git a/datafusion/core/src/datasource/dynamic_file.rs b/datafusion/core/src/datasource/dynamic_file.rs index 256a11ba693b5..50ee96da3dff0 100644 --- a/datafusion/core/src/datasource/dynamic_file.rs +++ b/datafusion/core/src/datasource/dynamic_file.rs @@ -20,9 +20,9 @@ use std::sync::Arc; +use crate::datasource::TableProvider; use crate::datasource::listing::ListingTableConfigExt; use crate::datasource::listing::{ListingTable, ListingTableConfig, ListingTableUrl}; -use crate::datasource::TableProvider; use crate::error::Result; use crate::execution::context::SessionState; diff --git a/datafusion/core/src/datasource/empty.rs b/datafusion/core/src/datasource/empty.rs index 77686c5eb7c27..5aeca92b1626d 100644 --- a/datafusion/core/src/datasource/empty.rs +++ b/datafusion/core/src/datasource/empty.rs @@ -28,8 +28,8 @@ use datafusion_common::project_schema; use crate::datasource::{TableProvider, TableType}; use crate::error::Result; use crate::logical_expr::Expr; -use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::empty::EmptyExec; /// An empty plan that is useful for testing and generating plans /// without mapping them to actual data. 
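A minimal usage sketch of the `DataFrame::select_columns` change above: with the added validation, an unknown column name now returns a field-not-found error instead of being silently dropped from the projection. This sketch is not part of the patch; the context `ctx`, table name `t`, and columns `a`/`b` are assumptions for illustration only.

use datafusion::error::Result;
use datafusion::prelude::*;

async fn select_columns_sketch(ctx: &SessionContext) -> Result<()> {
    // assumes a table "t" with columns "a" and "b" is already registered on `ctx`
    let df = ctx.table("t").await?;

    // both names exist, so the projection succeeds as before
    let _projected = df.clone().select_columns(&["a", "b"])?;

    // "c" is not in the schema: with the change above this is now an error
    // instead of the unknown name being silently ignored
    assert!(df.select_columns(&["a", "c"]).is_err());
    Ok(())
}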
diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index 3428d08a6ae52..cad35d43db486 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -26,20 +26,21 @@ mod tests { use crate::{ datasource::file_format::test_util::scan_format, prelude::SessionContext, }; - use arrow::array::{as_string_array, Array}; + use arrow::array::{Array, as_string_array}; use datafusion_catalog::Session; use datafusion_common::test_util::batches_to_string; use datafusion_common::{ + Result, cast::{ as_binary_array, as_boolean_array, as_float32_array, as_float64_array, as_int32_array, as_timestamp_microsecond_array, }, - test_util, Result, + test_util, }; use datafusion_datasource_avro::AvroFormat; use datafusion_execution::config::SessionConfig; - use datafusion_physical_plan::{collect, ExecutionPlan}; + use datafusion_physical_plan::{ExecutionPlan, collect}; use futures::StreamExt; use insta::assert_snapshot; @@ -116,20 +117,20 @@ mod tests { let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); - assert_snapshot!(batches_to_string(&batches),@r###" - +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+ - | id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col | - +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+ - | 4 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00 | - | 5 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30332f30312f3039 | 31 | 2009-03-01T00:01:00 | - | 6 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30342f30312f3039 | 30 | 2009-04-01T00:00:00 | - | 7 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30342f30312f3039 | 31 | 2009-04-01T00:01:00 | - | 2 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30322f30312f3039 | 30 | 2009-02-01T00:00:00 | - | 3 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30322f30312f3039 | 31 | 2009-02-01T00:01:00 | - | 0 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30312f30312f3039 | 30 | 2009-01-01T00:00:00 | - | 1 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30312f30312f3039 | 31 | 2009-01-01T00:01:00 | - +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+ - "###); + assert_snapshot!(batches_to_string(&batches),@r" + +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+ + | id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col | + +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+ + | 4 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00 | + | 5 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30332f30312f3039 | 31 | 2009-03-01T00:01:00 | + | 6 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30342f30312f3039 | 30 | 2009-04-01T00:00:00 | + | 7 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30342f30312f3039 | 31 | 2009-04-01T00:01:00 | + | 2 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30322f30312f3039 | 30 | 2009-02-01T00:00:00 | + | 3 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30322f30312f3039 | 31 | 
2009-02-01T00:01:00 | + | 0 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30312f30312f3039 | 30 | 2009-01-01T00:00:00 | + | 1 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30312f30312f3039 | 31 | 2009-01-01T00:01:00 | + +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+ + "); Ok(()) } @@ -245,7 +246,10 @@ mod tests { values.push(array.value(i)); } - assert_eq!("[1235865600000000, 1235865660000000, 1238544000000000, 1238544060000000, 1233446400000000, 1233446460000000, 1230768000000000, 1230768060000000]", format!("{values:?}")); + assert_eq!( + "[1235865600000000, 1235865660000000, 1238544000000000, 1238544060000000, 1233446400000000, 1233446460000000, 1230768000000000, 1230768060000000]", + format!("{values:?}") + ); Ok(()) } diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 52fb8ae904ebf..719bc4361ac91 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -32,12 +32,12 @@ mod tests { use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_catalog::Session; + use datafusion_common::Result; use datafusion_common::cast::as_string_array; use datafusion_common::config::CsvOptions; use datafusion_common::internal_err; use datafusion_common::stats::Precision; use datafusion_common::test_util::{arrow_test_data, batches_to_string}; - use datafusion_common::Result; use datafusion_datasource::decoder::{ BatchDeserializer, DecoderDeserializer, DeserializerOutput, }; @@ -45,7 +45,7 @@ mod tests { use datafusion_datasource::file_format::FileFormat; use datafusion_datasource::write::BatchSerializer; use datafusion_expr::{col, lit}; - use datafusion_physical_plan::{collect, ExecutionPlan}; + use datafusion_physical_plan::{ExecutionPlan, collect}; use arrow::array::{ Array, BooleanArray, Float64Array, Int32Array, RecordBatch, StringArray, @@ -57,8 +57,8 @@ mod tests { use bytes::Bytes; use chrono::DateTime; use datafusion_common::parsers::CompressionTypeVariant; - use futures::stream::BoxStream; use futures::StreamExt; + use futures::stream::BoxStream; use insta::assert_snapshot; use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; @@ -621,15 +621,15 @@ mod tests { .collect() .await?; - assert_snapshot!(batches_to_string(&record_batch), @r###" - +----+------+ - | c2 | c3 | - +----+------+ - | 5 | 36 | - | 5 | -31 | - | 5 | -101 | - +----+------+ - "###); + assert_snapshot!(batches_to_string(&record_batch), @r" + +----+------+ + | c2 | c3 | + +----+------+ + | 5 | 36 | + | 5 | -31 | + | 5 | -101 | + +----+------+ + "); Ok(()) } @@ -706,11 +706,11 @@ mod tests { let re = Regex::new(r"DataSourceExec: file_groups=\{(\d+) group").unwrap(); - if let Some(captures) = re.captures(&plan) { - if let Some(match_) = captures.get(1) { - let n_partitions = match_.as_str().parse::().unwrap(); - return Ok(n_partitions); - } + if let Some(captures) = re.captures(&plan) + && let Some(match_) = captures.get(1) + { + let n_partitions = match_.as_str().parse::().unwrap(); + return Ok(n_partitions); } internal_err!("query contains no DataSourceExec") @@ -736,13 +736,13 @@ mod tests { let query_result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_query_csv_partitions(&ctx, query).await?; - insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&query_result),@r###" + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r" +--------------+ | sum(aggr.c2) | +--------------+ | 285 | +--------------+ - "###); + "); } assert_eq!(n_partitions, actual_partitions); @@ -775,13 +775,13 @@ mod tests { let query_result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_query_csv_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r" +--------------+ | sum(aggr.c3) | +--------------+ | 781 | +--------------+ - "###); + "); } assert_eq!(1, actual_partitions); // Compressed csv won't be scanned in parallel @@ -812,13 +812,13 @@ mod tests { let query_result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_query_csv_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r" +--------------+ | sum(aggr.c3) | +--------------+ | 781 | +--------------+ - "###); + "); } assert_eq!(1, actual_partitions); // csv won't be scanned in parallel when newlines_in_values is set @@ -843,10 +843,10 @@ mod tests { let query = "select * from empty where random() > 0.5;"; let query_result = ctx.sql(query).await?.collect().await?; - assert_snapshot!(batches_to_string(&query_result),@r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&query_result),@r" + ++ + ++ + "); Ok(()) } @@ -868,10 +868,10 @@ mod tests { let query = "select * from empty where random() > 0.5;"; let query_result = ctx.sql(query).await?.collect().await?; - assert_snapshot!(batches_to_string(&query_result),@r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&query_result),@r" + ++ + ++ + "); Ok(()) } @@ -944,17 +944,19 @@ mod tests { let files: Vec<_> = std::fs::read_dir(&path).unwrap().collect(); assert_eq!(files.len(), 1); - assert!(files - .last() - .unwrap() - .as_ref() - .unwrap() - .path() - .file_name() - .unwrap() - .to_str() - .unwrap() - .ends_with(".csv.gz")); + assert!( + files + .last() + .unwrap() + .as_ref() + .unwrap() + .path() + .file_name() + .unwrap() + .to_str() + .unwrap() + .ends_with(".csv.gz") + ); Ok(()) } @@ -983,17 +985,19 @@ mod tests { let files: Vec<_> = std::fs::read_dir(&path).unwrap().collect(); assert_eq!(files.len(), 1); - assert!(files - .last() - .unwrap() - .as_ref() - .unwrap() - .path() - .file_name() - .unwrap() - .to_str() - .unwrap() - .ends_with(".csv")); + assert!( + files + .last() + .unwrap() + .as_ref() + .unwrap() + .path() + .file_name() + .unwrap() + .to_str() + .unwrap() + .ends_with(".csv") + ); Ok(()) } @@ -1032,10 +1036,10 @@ mod tests { let query = "select * from empty where random() > 0.5;"; let query_result = ctx.sql(query).await?.collect().await?; - assert_snapshot!(batches_to_string(&query_result),@r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&query_result),@r" + ++ + ++ + "); Ok(()) } @@ -1084,13 +1088,13 @@ mod tests { let query_result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_query_csv_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" - +---------------------+ - | sum(empty.column_1) | - +---------------------+ - | 10 | - +---------------------+ - "###);} + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&query_result),@r" + +---------------------+ + | sum(empty.column_1) | + +---------------------+ + | 10 | + +---------------------+ + ");} assert_eq!(n_partitions, actual_partitions); // Won't get partitioned if all files are empty @@ -1132,13 +1136,13 @@ mod tests { file_size }; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r" +-----------------------+ | sum(one_col.column_1) | +-----------------------+ | 50 | +-----------------------+ - "###); + "); } assert_eq!(expected_partitions, actual_partitions); @@ -1171,13 +1175,13 @@ mod tests { let query_result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_query_csv_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" - +---------------+ - | sum_of_5_cols | - +---------------+ - | 15 | - +---------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r" + +---------------+ + | sum_of_5_cols | + +---------------+ + | 15 | + +---------------+ + ");} assert_eq!(n_partitions, actual_partitions); @@ -1191,7 +1195,9 @@ mod tests { ) -> Result<()> { let schema = csv_schema(); let generator = CsvBatchGenerator::new(batch_size, line_count); - let mut deserializer = csv_deserializer(batch_size, &schema); + + let schema_clone = Arc::clone(&schema); + let mut deserializer = csv_deserializer(batch_size, &schema_clone); for data in generator { deserializer.digest(data); @@ -1230,7 +1236,8 @@ mod tests { ) -> Result<()> { let schema = csv_schema(); let generator = CsvBatchGenerator::new(batch_size, line_count); - let mut deserializer = csv_deserializer(batch_size, &schema); + let schema_clone = Arc::clone(&schema); + let mut deserializer = csv_deserializer(batch_size, &schema_clone); for data in generator { deserializer.digest(data); @@ -1499,7 +1506,7 @@ mod tests { // Create a temp file with a .csv suffix so the reader accepts it let mut tmp = tempfile::Builder::new().suffix(".csv").tempfile()?; // ensures path ends with .csv - // CSV has header "a,b,c". First data row is truncated (only "1,2"), second row is complete. + // CSV has header "a,b,c". First data row is truncated (only "1,2"), second row is complete. 
write!(tmp, "a,b,c\n1,2\n3,4,5\n")?; let path = tmp.path().to_str().unwrap().to_string(); @@ -1529,4 +1536,32 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_infer_schema_with_zero_max_records() -> Result<()> { + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + + let root = format!("{}/csv", arrow_test_data()); + let format = CsvFormat::default() + .with_has_header(true) + .with_schema_infer_max_rec(0); // Set to 0 to disable inference + let exec = scan_format( + &state, + &format, + None, + &root, + "aggregate_test_100.csv", + None, + None, + ) + .await?; + + // related to https://github.com/apache/datafusion/issues/19417 + for f in exec.schema().fields() { + assert_eq!(*f.data_type(), DataType::Utf8); + } + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 34d3d64f07fb2..4d5ed34399693 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -36,7 +36,7 @@ mod tests { BatchDeserializer, DecoderDeserializer, DeserializerOutput, }; use datafusion_datasource::file_format::FileFormat; - use datafusion_physical_plan::{collect, ExecutionPlan}; + use datafusion_physical_plan::{ExecutionPlan, collect}; use arrow::compute::concat_batches; use arrow::datatypes::{DataType, Field}; @@ -187,11 +187,11 @@ mod tests { let re = Regex::new(r"file_groups=\{(\d+) group").unwrap(); - if let Some(captures) = re.captures(&plan) { - if let Some(match_) = captures.get(1) { - let count = match_.as_str().parse::().unwrap(); - return Ok(count); - } + if let Some(captures) = re.captures(&plan) + && let Some(match_) = captures.get(1) + { + let count = match_.as_str().parse::().unwrap(); + return Ok(count); } internal_err!("Query contains no Exec: file_groups") @@ -218,13 +218,13 @@ mod tests { let result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_num_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&result),@r###" - +----------------------+ - | sum(json_parallel.a) | - +----------------------+ - | -7 | - +----------------------+ - "###);} + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&result),@r" + +----------------------+ + | sum(json_parallel.a) | + +----------------------+ + | -7 | + +----------------------+ + ");} assert_eq!(n_partitions, actual_partitions); @@ -249,10 +249,10 @@ mod tests { let result = ctx.sql(query).await?.collect().await?; - assert_snapshot!(batches_to_string(&result),@r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&result),@r" + ++ + ++ + "); Ok(()) } @@ -284,15 +284,15 @@ mod tests { } assert_eq!(deserializer.next()?, DeserializerOutput::InputExhausted); - assert_snapshot!(batches_to_string(&[all_batches]),@r###" - +----+----+----+----+----+ - | c1 | c2 | c3 | c4 | c5 | - +----+----+----+----+----+ - | 1 | 2 | 3 | 4 | 5 | - | 6 | 7 | 8 | 9 | 10 | - | 11 | 12 | 13 | 14 | 15 | - +----+----+----+----+----+ - "###); + assert_snapshot!(batches_to_string(&[all_batches]),@r" + +----+----+----+----+----+ + | c1 | c2 | c3 | c4 | c5 | + +----+----+----+----+----+ + | 1 | 2 | 3 | 4 | 5 | + | 6 | 7 | 8 | 9 | 10 | + | 11 | 12 | 13 | 14 | 15 | + +----+----+----+----+----+ + "); Ok(()) } @@ -324,14 +324,14 @@ mod tests { } assert_eq!(deserializer.next()?, DeserializerOutput::RequiresMoreData); - insta::assert_snapshot!(fmt_batches(&[all_batches]),@r###" - +----+----+----+----+----+ - | c1 | c2 | c3 | c4 | c5 | - +----+----+----+----+----+ - | 1 | 2 | 3 | 4 | 5 | - | 6 | 7 | 8 | 9 | 10 | - +----+----+----+----+----+ - "###); + insta::assert_snapshot!(fmt_batches(&[all_batches]),@r" + +----+----+----+----+----+ + | c1 | c2 | c3 | c4 | c5 | + +----+----+----+----+----+ + | 1 | 2 | 3 | 4 | 5 | + | 6 | 7 | 8 | 9 | 10 | + +----+----+----+----+----+ + "); Ok(()) } diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 4881783eeba69..6bbb63f6a17ad 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -39,8 +39,9 @@ pub(crate) mod test_util { use arrow_schema::SchemaRef; use datafusion_catalog::Session; use datafusion_common::Result; + use datafusion_datasource::TableSchema; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; - use datafusion_datasource::{file_format::FileFormat, PartitionedFile}; + use datafusion_datasource::{PartitionedFile, file_format::FileFormat}; use datafusion_execution::object_store::ObjectStoreUrl; use std::sync::Arc; @@ -66,31 +67,34 @@ pub(crate) mod test_util { .await? }; + let table_schema = TableSchema::new(file_schema.clone(), vec![]); + let statistics = format .infer_stats(state, &store, file_schema.clone(), &meta) .await?; - let file_groups = vec![vec![PartitionedFile { - object_meta: meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }] - .into()]; + let file_groups = vec![ + vec![PartitionedFile { + object_meta: meta, + partition_values: vec![], + range: None, + statistics: None, + extensions: None, + metadata_size_hint: None, + }] + .into(), + ]; let exec = format .create_physical_plan( state, FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), - file_schema, - format.file_source(), + format.file_source(table_schema), ) .with_file_groups(file_groups) .with_statistics(statistics) - .with_projection_indices(projection) + .with_projection_indices(projection)? 
.with_limit(limit) .build(), ) @@ -131,7 +135,10 @@ mod tests { .write_parquet(out_dir_url, DataFrameWriteOptions::new(), None) .await .expect_err("should fail because input file does not match inferred schema"); - assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'"); + assert_eq!( + e.strip_backtrace(), + "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'" + ); Ok(()) } } diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index e78c5f09553cc..146c5f6f5fd0f 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -25,9 +25,9 @@ use crate::datasource::file_format::avro::AvroFormat; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; +use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; use crate::datasource::file_format::arrow::ArrowFormat; use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; use crate::datasource::listing::ListingTableUrl; use crate::datasource::{file_format::csv::CsvFormat, listing::ListingOptions}; use crate::error::Result; @@ -523,6 +523,12 @@ impl<'a> NdJsonReadOptions<'a> { self.file_sort_order = file_sort_order; self } + + /// Specify how many rows to read for schema inference + pub fn schema_infer_max_records(mut self, schema_infer_max_records: usize) -> Self { + self.schema_infer_max_records = schema_infer_max_records; + self + } } #[async_trait] diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 52c5393e10319..44cf09c1ae46e 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -107,8 +107,8 @@ pub(crate) mod test_util { mod tests { use std::fmt::{self, Display, Formatter}; - use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; use crate::datasource::file_format::parquet::test_util::store_parquet; @@ -120,6 +120,7 @@ mod tests { use arrow::array::RecordBatch; use arrow_schema::Schema; use datafusion_catalog::Session; + use datafusion_common::ScalarValue::Utf8; use datafusion_common::cast::{ as_binary_array, as_binary_view_array, as_boolean_array, as_float32_array, as_float64_array, as_int32_array, as_timestamp_nanosecond_array, @@ -127,7 +128,6 @@ mod tests { use datafusion_common::config::{ParquetOptions, TableParquetOptions}; use datafusion_common::stats::Precision; use datafusion_common::test_util::batches_to_string; - use datafusion_common::ScalarValue::Utf8; use datafusion_common::{Result, ScalarValue}; use datafusion_datasource::file_format::FileFormat; use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; @@ -135,33 +135,33 @@ mod tests { use datafusion_datasource_parquet::{ ParquetFormat, ParquetFormatFactory, ParquetSink, }; + use datafusion_execution::TaskContext; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; - use datafusion_execution::TaskContext; use datafusion_expr::dml::InsertOp; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; - use datafusion_physical_plan::{collect, 
ExecutionPlan}; + use datafusion_physical_plan::{ExecutionPlan, collect}; use crate::test_util::bounded_stream; use arrow::array::{ - types::Int32Type, Array, ArrayRef, DictionaryArray, Int32Array, Int64Array, - StringArray, + Array, ArrayRef, DictionaryArray, Int32Array, Int64Array, StringArray, + types::Int32Type, }; use arrow::datatypes::{DataType, Field}; use async_trait::async_trait; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource_parquet::metadata::DFParquetMetadata; - use futures::stream::BoxStream; use futures::StreamExt; + use futures::stream::BoxStream; use insta::assert_snapshot; - use object_store::local::LocalFileSystem; use object_store::ObjectMeta; + use object_store::local::LocalFileSystem; use object_store::{ - path::Path, GetOptions, GetResult, ListResult, MultipartUpload, ObjectStore, - PutMultipartOptions, PutOptions, PutPayload, PutResult, + GetOptions, GetResult, ListResult, MultipartUpload, ObjectStore, + PutMultipartOptions, PutOptions, PutPayload, PutResult, path::Path, }; - use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::ParquetRecordBatchStreamBuilder; + use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::file::metadata::{ KeyValue, ParquetColumnIndex, ParquetMetaData, ParquetOffsetIndex, }; @@ -724,7 +724,7 @@ mod tests { // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 assert_eq!( exec.partition_statistics(None)?.total_byte_size, - Precision::Exact(671) + Precision::Absent, ); Ok(()) @@ -770,10 +770,9 @@ mod tests { exec.partition_statistics(None)?.num_rows, Precision::Exact(8) ); - // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 assert_eq!( exec.partition_statistics(None)?.total_byte_size, - Precision::Exact(671) + Precision::Absent, ); let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); @@ -931,7 +930,10 @@ mod tests { values.push(array.value(i)); } - assert_eq!("[1235865600000000000, 1235865660000000000, 1238544000000000000, 1238544060000000000, 1233446400000000000, 1233446460000000000, 1230768000000000000, 1230768060000000000]", format!("{values:?}")); + assert_eq!( + "[1235865600000000000, 1235865660000000000, 1238544000000000000, 1238544060000000000, 1233446400000000000, 1233446460000000000, 1230768000000000000, 1230768060000000000]", + format!("{values:?}") + ); Ok(()) } @@ -1204,10 +1206,10 @@ mod tests { let result = df.collect().await?; - assert_snapshot!(batches_to_string(&result), @r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&result), @r" + ++ + ++ + "); Ok(()) } @@ -1233,10 +1235,10 @@ mod tests { let result = df.collect().await?; - assert_snapshot!(batches_to_string(&result), @r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&result), @r" + ++ + ++ + "); Ok(()) } diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 3333b70676203..93d77e10ba23c 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -113,8 +113,8 @@ mod tests { use crate::prelude::*; use crate::{ datasource::{ - file_format::csv::CsvFormat, file_format::json::JsonFormat, - provider_as_source, DefaultTableSource, MemTable, + DefaultTableSource, MemTable, file_format::csv::CsvFormat, + file_format::json::JsonFormat, provider_as_source, }, execution::options::ArrowReadOptions, test::{ @@ -129,33 +129,26 @@ mod tests { ListingOptions, ListingTable, ListingTableConfig, SchemaSource, 
}; use datafusion_common::{ - assert_contains, plan_err, + DataFusionError, Result, ScalarValue, assert_contains, stats::Precision, test_util::{batches_to_string, datafusion_test_data}, - ColumnStatistics, DataFusionError, Result, ScalarValue, }; + use datafusion_datasource::ListingTableUrl; use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_format::FileFormat; - use datafusion_datasource::schema_adapter::{ - SchemaAdapter, SchemaAdapterFactory, SchemaMapper, - }; - use datafusion_datasource::ListingTableUrl; use datafusion_expr::dml::InsertOp; use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; - use datafusion_physical_expr::expressions::binary; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr::expressions::binary; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::empty::EmptyExec; - use datafusion_physical_plan::{collect, ExecutionPlanProperties}; - use rstest::rstest; + use datafusion_physical_plan::{ExecutionPlanProperties, collect}; use std::collections::HashMap; use std::io::Write; use std::sync::Arc; use tempfile::TempDir; use url::Url; - const DUMMY_NULL_COUNT: Precision = Precision::Exact(42); - /// Creates a test schema with standard field types used in tests fn create_test_schema() -> SchemaRef { Arc::new(Schema::new(vec![ @@ -257,7 +250,7 @@ mod tests { ); assert_eq!( exec.partition_statistics(None)?.total_byte_size, - Precision::Exact(671) + Precision::Absent, ); Ok(()) @@ -289,32 +282,36 @@ mod tests { // sort expr, but non column ( vec![vec![col("int_col").add(lit(1)).sort(true, true)]], - Ok(vec![[PhysicalSortExpr { - expr: binary( - physical_col("int_col", &schema).unwrap(), - Operator::Plus, - physical_lit(1), - &schema, - ) - .unwrap(), - options: SortOptions { - descending: false, - nulls_first: true, - }, - }] - .into()]), + Ok(vec![ + [PhysicalSortExpr { + expr: binary( + physical_col("int_col", &schema).unwrap(), + Operator::Plus, + physical_lit(1), + &schema, + ) + .unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }] + .into(), + ]), ), // ok with one column ( vec![vec![col("string_col").sort(true, false)]], - Ok(vec![[PhysicalSortExpr { - expr: physical_col("string_col", &schema).unwrap(), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }] - .into()]), + Ok(vec![ + [PhysicalSortExpr { + expr: physical_col("string_col", &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }] + .into(), + ]), ), // ok with two columns, different options ( @@ -322,19 +319,21 @@ mod tests { col("string_col").sort(true, false), col("int_col").sort(false, true), ]], - Ok(vec![[ - PhysicalSortExpr::new_default( - physical_col("string_col", &schema).unwrap(), - ) - .asc() - .nulls_last(), - PhysicalSortExpr::new_default( - physical_col("int_col", &schema).unwrap(), - ) - .desc() - .nulls_first(), - ] - .into()]), + Ok(vec![ + [ + PhysicalSortExpr::new_default( + physical_col("string_col", &schema).unwrap(), + ) + .asc() + .nulls_last(), + PhysicalSortExpr::new_default( + physical_col("int_col", &schema).unwrap(), + ) + .desc() + .nulls_first(), + ] + .into(), + ]), ), ]; @@ -453,9 +452,9 @@ mod tests { let table = ListingTable::try_new(config)?; - let (file_list, _) = table.list_files_for_scan(&ctx.state(), &[], None).await?; + let result = table.list_files_for_scan(&ctx.state(), &[], None).await?; - assert_eq!(file_list.len(), output_partitioning); + 
assert_eq!(result.file_groups.len(), output_partitioning); Ok(()) } @@ -488,9 +487,9 @@ mod tests { let table = ListingTable::try_new(config)?; - let (file_list, _) = table.list_files_for_scan(&ctx.state(), &[], None).await?; + let result = table.list_files_for_scan(&ctx.state(), &[], None).await?; - assert_eq!(file_list.len(), output_partitioning); + assert_eq!(result.file_groups.len(), output_partitioning); Ok(()) } @@ -538,9 +537,9 @@ mod tests { let table = ListingTable::try_new(config)?; - let (file_list, _) = table.list_files_for_scan(&ctx.state(), &[], None).await?; + let result = table.list_files_for_scan(&ctx.state(), &[], None).await?; - assert_eq!(file_list.len(), output_partitioning); + assert_eq!(result.file_groups.len(), output_partitioning); Ok(()) } @@ -731,8 +730,8 @@ mod tests { } #[tokio::test] - async fn test_insert_into_append_new_parquet_files_invalid_session_fails( - ) -> Result<()> { + async fn test_insert_into_append_new_parquet_files_invalid_session_fails() + -> Result<()> { let mut config_map: HashMap = HashMap::new(); config_map.insert( "datafusion.execution.parquet.compression".into(), @@ -746,7 +745,10 @@ mod tests { ) .await .expect_err("Example should fail!"); - assert_eq!(e.strip_backtrace(), "Invalid or Unsupported Configuration: zstd compression requires specifying a level such as zstd(4)"); + assert_eq!( + e.strip_backtrace(), + "Invalid or Unsupported Configuration: zstd compression requires specifying a level such as zstd(4)" + ); Ok(()) } @@ -873,13 +875,13 @@ mod tests { let res = collect(plan, session_ctx.task_ctx()).await?; // Insert returns the number of rows written, in our case this would be 6. - insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&res),@r###" - +-------+ - | count | - +-------+ - | 20 | - +-------+ - "###);} + insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&res),@r" + +-------+ + | count | + +-------+ + | 20 | + +-------+ + ");} // Read the records in the table let batches = session_ctx @@ -888,13 +890,13 @@ mod tests { .collect() .await?; - insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&batches),@r###" - +-------+ - | count | - +-------+ - | 20 | - +-------+ - "###);} + insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&batches),@r" + +-------+ + | count | + +-------+ + | 20 | + +-------+ + ");} // Assert that `target_partition_number` many files were added to the table. let num_files = tmp_dir.path().read_dir()?.count(); @@ -909,13 +911,13 @@ mod tests { // Again, execute the physical plan and collect the results let res = collect(plan, session_ctx.task_ctx()).await?; - insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&res),@r###" - +-------+ - | count | - +-------+ - | 20 | - +-------+ - "###);} + insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&res),@r" + +-------+ + | count | + +-------+ + | 20 | + +-------+ + ");} // Read the contents of the table let batches = session_ctx @@ -924,13 +926,13 @@ mod tests { .collect() .await?; - insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&batches),@r###" - +-------+ - | count | - +-------+ - | 40 | - +-------+ - "###);} + insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&batches),@r" + +-------+ + | count | + +-------+ + | 40 | + +-------+ + ");} // Assert that another `target_partition_number` many files were added to the table. 
let num_files = tmp_dir.path().read_dir()?.count(); @@ -988,15 +990,15 @@ mod tests { .collect() .await?; - insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&batches),@r###" - +-----+-----+---+ - | a | b | c | - +-----+-----+---+ - | foo | bar | 1 | - | foo | bar | 2 | - | foo | bar | 3 | - +-----+-----+---+ - "###);} + insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&batches),@r" + +-----+-----+---+ + | a | b | c | + +-----+-----+---+ + | foo | bar | 1 | + | foo | bar | 2 | + | foo | bar | 3 | + +-----+-----+---+ + ");} Ok(()) } @@ -1307,10 +1309,10 @@ mod tests { let table = ListingTable::try_new(config)?; - let (file_list, _) = table.list_files_for_scan(&ctx.state(), &[], None).await?; - assert_eq!(file_list.len(), 1); + let result = table.list_files_for_scan(&ctx.state(), &[], None).await?; + assert_eq!(result.file_groups.len(), 1); - let files = file_list[0].clone(); + let files = result.file_groups[0].clone(); assert_eq!( files @@ -1397,7 +1399,7 @@ mod tests { // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 assert_eq!( exec_enabled.partition_statistics(None)?.total_byte_size, - Precision::Exact(671) + Precision::Absent, ); Ok(()) @@ -1416,7 +1418,9 @@ mod tests { ]; for (format, batch_size, soft_max_rows, expected_files) in test_cases { - println!("Testing insert with format: {format}, batch_size: {batch_size}, expected files: {expected_files}"); + println!( + "Testing insert with format: {format}, batch_size: {batch_size}, expected files: {expected_files}" + ); let mut config_map = HashMap::new(); config_map.insert( @@ -1449,33 +1453,10 @@ mod tests { } #[tokio::test] - async fn test_statistics_mapping_with_custom_factory() -> Result<()> { + async fn test_basic_table_scan() -> Result<()> { let ctx = SessionContext::new(); - let table = create_test_listing_table_with_json_and_adapter( - &ctx, - false, - // NullStatsAdapterFactory sets column_statistics null_count to DUMMY_NULL_COUNT - Arc::new(NullStatsAdapterFactory {}), - )?; - let (groups, stats) = table.list_files_for_scan(&ctx.state(), &[], None).await?; - - assert_eq!(stats.column_statistics[0].null_count, DUMMY_NULL_COUNT); - for g in groups { - if let Some(s) = g.file_statistics(None) { - assert_eq!(s.column_statistics[0].null_count, DUMMY_NULL_COUNT); - } - } - - Ok(()) - } - - #[tokio::test] - async fn test_statistics_mapping_with_default_factory() -> Result<()> { - let ctx = SessionContext::new(); - - // Create a table without providing a custom schema adapter factory - // This should fall back to using DefaultSchemaAdapterFactory + // Test basic table creation and scanning let path = "table/file.json"; register_test_store(&ctx, &[(path, 10)]); @@ -1487,222 +1468,20 @@ mod tests { let config = ListingTableConfig::new(table_path) .with_listing_options(opt) .with_schema(Arc::new(schema)); - // Note: NOT calling .with_schema_adapter_factory() to test default behavior let table = ListingTable::try_new(config)?; - // Verify that no custom schema adapter factory is set - assert!(table.schema_adapter_factory().is_none()); - - // The scan should work correctly with the default schema adapter + // The scan should work correctly let scan_result = table.scan(&ctx.state(), None, &[], None).await; - assert!( - scan_result.is_ok(), - "Scan should succeed with default schema adapter" - ); + assert!(scan_result.is_ok(), "Scan should succeed"); - // Verify that the default adapter handles basic schema compatibility - let (groups, _stats) = 
table.list_files_for_scan(&ctx.state(), &[], None).await?; + // Verify file listing works + let result = table.list_files_for_scan(&ctx.state(), &[], None).await?; assert!( - !groups.is_empty(), - "Should list files successfully with default adapter" + !result.file_groups.is_empty(), + "Should list files successfully" ); Ok(()) } - - #[rstest] - #[case(MapSchemaError::TypeIncompatible, "Cannot map incompatible types")] - #[case(MapSchemaError::GeneralFailure, "Schema adapter mapping failed")] - #[case( - MapSchemaError::InvalidProjection, - "Invalid projection in schema mapping" - )] - #[tokio::test] - async fn test_schema_adapter_map_schema_errors( - #[case] error_type: MapSchemaError, - #[case] expected_error_msg: &str, - ) -> Result<()> { - let ctx = SessionContext::new(); - let table = create_test_listing_table_with_json_and_adapter( - &ctx, - false, - Arc::new(FailingMapSchemaAdapterFactory { error_type }), - )?; - - // The error should bubble up from the scan operation when schema mapping fails - let scan_result = table.scan(&ctx.state(), None, &[], None).await; - - assert!(scan_result.is_err()); - let error_msg = scan_result.unwrap_err().to_string(); - assert!( - error_msg.contains(expected_error_msg), - "Expected error containing '{expected_error_msg}', got: {error_msg}" - ); - - Ok(()) - } - - // Test that errors during file listing also bubble up correctly - #[tokio::test] - async fn test_schema_adapter_error_during_file_listing() -> Result<()> { - let ctx = SessionContext::new(); - let table = create_test_listing_table_with_json_and_adapter( - &ctx, - true, - Arc::new(FailingMapSchemaAdapterFactory { - error_type: MapSchemaError::TypeIncompatible, - }), - )?; - - // The error should bubble up from list_files_for_scan when collecting statistics - let list_result = table.list_files_for_scan(&ctx.state(), &[], None).await; - - assert!(list_result.is_err()); - let error_msg = list_result.unwrap_err().to_string(); - assert!( - error_msg.contains("Cannot map incompatible types"), - "Expected type incompatibility error during file listing, got: {error_msg}" - ); - - Ok(()) - } - - #[derive(Debug, Copy, Clone)] - enum MapSchemaError { - TypeIncompatible, - GeneralFailure, - InvalidProjection, - } - - #[derive(Debug)] - struct FailingMapSchemaAdapterFactory { - error_type: MapSchemaError, - } - - impl SchemaAdapterFactory for FailingMapSchemaAdapterFactory { - fn create( - &self, - projected_table_schema: SchemaRef, - _table_schema: SchemaRef, - ) -> Box { - Box::new(FailingMapSchemaAdapter { - schema: projected_table_schema, - error_type: self.error_type, - }) - } - } - - #[derive(Debug)] - struct FailingMapSchemaAdapter { - schema: SchemaRef, - error_type: MapSchemaError, - } - - impl SchemaAdapter for FailingMapSchemaAdapter { - fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { - let field = self.schema.field(index); - file_schema.fields.find(field.name()).map(|(i, _)| i) - } - - fn map_schema( - &self, - _file_schema: &Schema, - ) -> Result<(Arc, Vec)> { - // Always fail with different error types based on the configured error_type - match self.error_type { - MapSchemaError::TypeIncompatible => { - plan_err!( - "Cannot map incompatible types: Boolean cannot be cast to Utf8" - ) - } - MapSchemaError::GeneralFailure => { - plan_err!("Schema adapter mapping failed due to internal error") - } - MapSchemaError::InvalidProjection => { - plan_err!("Invalid projection in schema mapping: column index out of bounds") - } - } - } - } - - #[derive(Debug)] - struct 
NullStatsAdapterFactory; - - impl SchemaAdapterFactory for NullStatsAdapterFactory { - fn create( - &self, - projected_table_schema: SchemaRef, - _table_schema: SchemaRef, - ) -> Box { - Box::new(NullStatsAdapter { - schema: projected_table_schema, - }) - } - } - - #[derive(Debug)] - struct NullStatsAdapter { - schema: SchemaRef, - } - - impl SchemaAdapter for NullStatsAdapter { - fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { - let field = self.schema.field(index); - file_schema.fields.find(field.name()).map(|(i, _)| i) - } - - fn map_schema( - &self, - file_schema: &Schema, - ) -> Result<(Arc, Vec)> { - let projection = (0..file_schema.fields().len()).collect(); - Ok((Arc::new(NullStatsMapper {}), projection)) - } - } - - #[derive(Debug)] - struct NullStatsMapper; - - impl SchemaMapper for NullStatsMapper { - fn map_batch(&self, batch: RecordBatch) -> Result { - Ok(batch) - } - - fn map_column_statistics( - &self, - stats: &[ColumnStatistics], - ) -> Result> { - Ok(stats - .iter() - .map(|s| { - let mut s = s.clone(); - s.null_count = DUMMY_NULL_COUNT; - s - }) - .collect()) - } - } - - /// Helper function to create a test ListingTable with JSON format and custom schema adapter factory - fn create_test_listing_table_with_json_and_adapter( - ctx: &SessionContext, - collect_stat: bool, - schema_adapter_factory: Arc, - ) -> Result { - let path = "table/file.json"; - register_test_store(ctx, &[(path, 10)]); - - let format = JsonFormat::default(); - let opt = ListingOptions::new(Arc::new(format)).with_collect_stat(collect_stat); - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); - let table_path = ListingTableUrl::parse("test:///table/")?; - - let config = ListingTableConfig::new(table_path) - .with_listing_options(opt) - .with_schema(Arc::new(schema)) - .with_schema_adapter_factory(schema_adapter_factory); - - ListingTable::try_new(config) - } } diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index f98297d0e3f7f..3ca388af0c4c1 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -28,8 +28,8 @@ use crate::datasource::listing::{ use crate::execution::context::SessionState; use arrow::datatypes::DataType; -use datafusion_common::{arrow_datafusion_err, plan_err, DataFusionError, ToDFSchema}; -use datafusion_common::{config_datafusion_err, Result}; +use datafusion_common::{Result, config_datafusion_err}; +use datafusion_common::{ToDFSchema, arrow_datafusion_err, plan_err}; use datafusion_expr::CreateExternalTable; use async_trait::async_trait; @@ -190,6 +190,16 @@ impl TableProviderFactory for ListingTableFactory { .with_definition(cmd.definition.clone()) .with_constraints(cmd.constraints.clone()) .with_column_defaults(cmd.column_defaults.clone()); + + // Pre-warm statistics cache if collect_statistics is enabled + if session_state.config().collect_statistics() { + let filters = &[]; + let limit = None; + if let Err(e) = table.list_files_for_scan(state, filters, limit).await { + log::warn!("Failed to pre-warm statistics cache: {e}"); + } + } + Ok(Arc::new(table)) } } @@ -205,19 +215,23 @@ fn get_extension(path: &str) -> String { #[cfg(test)] mod tests { + use super::*; + use crate::{ + datasource::file_format::csv::CsvFormat, execution::context::SessionContext, + test_util::parquet_test_data, + }; + use datafusion_execution::cache::CacheAccessor; + use 
datafusion_execution::cache::cache_manager::CacheManagerConfig; + use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; use datafusion_execution::config::SessionConfig; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; use glob::Pattern; use std::collections::HashMap; use std::fs; use std::path::PathBuf; - use super::*; - use crate::{ - datasource::file_format::csv::CsvFormat, execution::context::SessionContext, - }; - use datafusion_common::parsers::CompressionTypeVariant; - use datafusion_common::{Constraints, DFSchema, TableReference}; + use datafusion_common::{DFSchema, TableReference}; #[tokio::test] async fn test_create_using_non_std_file_ext() { @@ -231,22 +245,14 @@ mod tests { let context = SessionContext::new(); let state = context.state(); let name = TableReference::bare("foo"); - let cmd = CreateExternalTable { + let cmd = CreateExternalTable::builder( name, - location: csv_file.path().to_str().unwrap().to_string(), - file_type: "csv".to_string(), - schema: Arc::new(DFSchema::empty()), - table_partition_cols: vec![], - if_not_exists: false, - or_replace: false, - temporary: false, - definition: None, - order_exprs: vec![], - unbounded: false, - options: HashMap::from([("format.has_header".into(), "true".into())]), - constraints: Constraints::default(), - column_defaults: HashMap::new(), - }; + csv_file.path().to_str().unwrap().to_string(), + "csv", + Arc::new(DFSchema::empty()), + ) + .with_options(HashMap::from([("format.has_header".into(), "true".into())])) + .build(); let table_provider = factory.create(&state, &cmd).await.unwrap(); let listing_table = table_provider .as_any() @@ -272,22 +278,14 @@ mod tests { let mut options = HashMap::new(); options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned()); options.insert("format.has_header".into(), "true".into()); - let cmd = CreateExternalTable { + let cmd = CreateExternalTable::builder( name, - location: csv_file.path().to_str().unwrap().to_string(), - file_type: "csv".to_string(), - schema: Arc::new(DFSchema::empty()), - table_partition_cols: vec![], - if_not_exists: false, - or_replace: false, - temporary: false, - definition: None, - order_exprs: vec![], - unbounded: false, - options, - constraints: Constraints::default(), - column_defaults: HashMap::new(), - }; + csv_file.path().to_str().unwrap().to_string(), + "csv", + Arc::new(DFSchema::empty()), + ) + .with_options(options) + .build(); let table_provider = factory.create(&state, &cmd).await.unwrap(); let listing_table = table_provider .as_any() @@ -317,22 +315,14 @@ mod tests { options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned()); options.insert("format.has_header".into(), "true".into()); options.insert("format.compression".into(), "gzip".into()); - let cmd = CreateExternalTable { + let cmd = CreateExternalTable::builder( name, - location: dir.path().to_str().unwrap().to_string(), - file_type: "csv".to_string(), - schema: Arc::new(DFSchema::empty()), - table_partition_cols: vec![], - if_not_exists: false, - or_replace: false, - temporary: false, - definition: None, - order_exprs: vec![], - unbounded: false, - options, - constraints: Constraints::default(), - column_defaults: HashMap::new(), - }; + dir.path().to_str().unwrap().to_string(), + "csv", + Arc::new(DFSchema::empty()), + ) + .with_options(options) + .build(); let table_provider = factory.create(&state, &cmd).await.unwrap(); let listing_table = table_provider .as_any() @@ -369,22 +359,14 @@ mod tests { let mut options = HashMap::new(); 
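The tests in this file now build `CreateExternalTable` through a builder instead of spelling out every struct field. A sketch of registering a CSV file that way, assuming the `builder(name, location, file_type, schema)` / `with_options` / `build` shape used throughout this diff; the facade import paths and the helper name are assumptions:

```rust
use std::collections::HashMap;
use std::sync::Arc;

use datafusion::catalog::TableProviderFactory;
use datafusion::datasource::listing_table_factory::ListingTableFactory;
use datafusion::prelude::SessionContext;
use datafusion_common::{DFSchema, Result, TableReference};
use datafusion_expr::CreateExternalTable;

async fn create_csv_table(path: &str) -> Result<()> {
    let factory = ListingTableFactory::new();
    let ctx = SessionContext::new();
    let state = ctx.state();

    // Builder form used by the updated tests: name, location, file type and
    // schema are required; everything else keeps its default unless set.
    let cmd = CreateExternalTable::builder(
        TableReference::bare("foo"),
        path.to_string(),
        "csv",
        Arc::new(DFSchema::empty()),
    )
    .with_options(HashMap::from([("format.has_header".into(), "true".into())]))
    .build();

    let _table = factory.create(&state, &cmd).await?;
    Ok(())
}
```

Fields not set on the builder (partition columns, constraints, column defaults, ordering and so on) keep their defaults, which is what removes the boilerplate from the tests above.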
options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned()); options.insert("format.has_header".into(), "true".into()); - let cmd = CreateExternalTable { + let cmd = CreateExternalTable::builder( name, - location: dir.path().to_str().unwrap().to_string(), - file_type: "csv".to_string(), - schema: Arc::new(DFSchema::empty()), - table_partition_cols: vec![], - if_not_exists: false, - or_replace: false, - temporary: false, - definition: None, - order_exprs: vec![], - unbounded: false, - options, - constraints: Constraints::default(), - column_defaults: HashMap::new(), - }; + dir.path().to_str().unwrap().to_string(), + "csv", + Arc::new(DFSchema::empty()), + ) + .with_options(options) + .build(); let table_provider = factory.create(&state, &cmd).await.unwrap(); let listing_table = table_provider .as_any() @@ -413,22 +395,13 @@ mod tests { let state = context.state(); let name = TableReference::bare("foo"); - let cmd = CreateExternalTable { + let cmd = CreateExternalTable::builder( name, - location: String::from(path.to_str().unwrap()), - file_type: "parquet".to_string(), - schema: Arc::new(DFSchema::empty()), - table_partition_cols: vec![], - if_not_exists: false, - or_replace: false, - temporary: false, - definition: None, - order_exprs: vec![], - unbounded: false, - options: HashMap::new(), - constraints: Constraints::default(), - column_defaults: HashMap::new(), - }; + String::from(path.to_str().unwrap()), + "parquet", + Arc::new(DFSchema::empty()), + ) + .build(); let table_provider = factory.create(&state, &cmd).await.unwrap(); let listing_table = table_provider .as_any() @@ -453,22 +426,13 @@ mod tests { let state = context.state(); let name = TableReference::bare("foo"); - let cmd = CreateExternalTable { + let cmd = CreateExternalTable::builder( name, - location: dir.path().to_str().unwrap().to_string(), - file_type: "parquet".to_string(), - schema: Arc::new(DFSchema::empty()), - table_partition_cols: vec![], - if_not_exists: false, - or_replace: false, - temporary: false, - definition: None, - order_exprs: vec![], - unbounded: false, - options: HashMap::new(), - constraints: Constraints::default(), - column_defaults: HashMap::new(), - }; + dir.path().to_str().unwrap(), + "parquet", + Arc::new(DFSchema::empty()), + ) + .build(); let table_provider = factory.create(&state, &cmd).await.unwrap(); let listing_table = table_provider .as_any() @@ -494,22 +458,13 @@ mod tests { let state = context.state(); let name = TableReference::bare("foo"); - let cmd = CreateExternalTable { + let cmd = CreateExternalTable::builder( name, - location: dir.path().to_str().unwrap().to_string(), - file_type: "parquet".to_string(), - schema: Arc::new(DFSchema::empty()), - table_partition_cols: vec![], - if_not_exists: false, - or_replace: false, - temporary: false, - definition: None, - order_exprs: vec![], - unbounded: false, - options: HashMap::new(), - constraints: Constraints::default(), - column_defaults: HashMap::new(), - }; + dir.path().to_str().unwrap().to_string(), + "parquet", + Arc::new(DFSchema::empty()), + ) + .build(); let table_provider = factory.create(&state, &cmd).await.unwrap(); let listing_table = table_provider .as_any() @@ -519,4 +474,75 @@ mod tests { let listing_options = listing_table.options(); assert!(listing_options.table_partition_cols.is_empty()); } + + #[tokio::test] + async fn test_statistics_cache_prewarming() { + let factory = ListingTableFactory::new(); + + let location = PathBuf::from(parquet_test_data()) + .join("alltypes_tiny_pages_plain.parquet") + 
.to_string_lossy() + .to_string(); + + // Test with collect_statistics enabled + let file_statistics_cache = Arc::new(DefaultFileStatisticsCache::default()); + let cache_config = CacheManagerConfig::default() + .with_files_statistics_cache(Some(file_statistics_cache.clone())); + let runtime = RuntimeEnvBuilder::new() + .with_cache_manager(cache_config) + .build_arc() + .unwrap(); + + let mut config = SessionConfig::new(); + config.options_mut().execution.collect_statistics = true; + let context = SessionContext::new_with_config_rt(config, runtime); + let state = context.state(); + let name = TableReference::bare("test"); + + let cmd = CreateExternalTable::builder( + name, + location.clone(), + "parquet", + Arc::new(DFSchema::empty()), + ) + .build(); + + let _table_provider = factory.create(&state, &cmd).await.unwrap(); + + assert!( + file_statistics_cache.len() > 0, + "Statistics cache should be pre-warmed when collect_statistics is enabled" + ); + + // Test with collect_statistics disabled + let file_statistics_cache = Arc::new(DefaultFileStatisticsCache::default()); + let cache_config = CacheManagerConfig::default() + .with_files_statistics_cache(Some(file_statistics_cache.clone())); + let runtime = RuntimeEnvBuilder::new() + .with_cache_manager(cache_config) + .build_arc() + .unwrap(); + + let mut config = SessionConfig::new(); + config.options_mut().execution.collect_statistics = false; + let context = SessionContext::new_with_config_rt(config, runtime); + let state = context.state(); + let name = TableReference::bare("test"); + + let cmd = CreateExternalTable::builder( + name, + location, + "parquet", + Arc::new(DFSchema::empty()), + ) + .build(); + + let _table_provider = factory.create(&state, &cmd).await.unwrap(); + + assert_eq!( + file_statistics_cache.len(), + 0, + "Statistics cache should not be pre-warmed when collect_statistics is disabled" + ); + } } diff --git a/datafusion/core/src/datasource/memory_test.rs b/datafusion/core/src/datasource/memory_test.rs index c16837c73b4f1..c7721cafb02ea 100644 --- a/datafusion/core/src/datasource/memory_test.rs +++ b/datafusion/core/src/datasource/memory_test.rs @@ -19,7 +19,7 @@ mod tests { use crate::datasource::MemTable; - use crate::datasource::{provider_as_source, DefaultTableSource}; + use crate::datasource::{DefaultTableSource, provider_as_source}; use crate::physical_plan::collect; use crate::prelude::SessionContext; use arrow::array::{AsArray, Int32Array}; @@ -29,8 +29,8 @@ mod tests { use arrow_schema::SchemaRef; use datafusion_catalog::TableProvider; use datafusion_common::{DataFusionError, Result}; - use datafusion_expr::dml::InsertOp; use datafusion_expr::LogicalPlanBuilder; + use datafusion_expr::dml::InsertOp; use futures::StreamExt; use std::collections::HashMap; use std::sync::Arc; @@ -329,12 +329,11 @@ mod tests { ); let col = batch.column(0).as_primitive::(); assert_eq!(col.len(), 1, "expected 1 row, got {}", col.len()); - let val = col - .iter() + + col.iter() .next() .expect("had value") - .expect("expected non null"); - val + .expect("expected non null") } // Test inserting a single batch of data into a single partition diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 37b9663111a53..aefda64d39367 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -31,7 +31,7 @@ mod view_test; // backwards compatibility pub use self::default_table_source::{ - provider_as_source, source_as_provider, DefaultTableSource, + DefaultTableSource, 
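`test_statistics_cache_prewarming` makes the pre-warm observable by installing an explicit file statistics cache in the runtime. A condensed sketch of that wiring, using the same types the test imports:

```rust
use std::sync::Arc;

use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_execution::cache::cache_manager::CacheManagerConfig;
use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;

fn context_with_stats_cache() -> (SessionContext, Arc<DefaultFileStatisticsCache>) {
    // A shared statistics cache we can inspect after table creation.
    let stats_cache = Arc::new(DefaultFileStatisticsCache::default());
    let cache_config = CacheManagerConfig::default()
        .with_files_statistics_cache(Some(stats_cache.clone()));
    let runtime = RuntimeEnvBuilder::new()
        .with_cache_manager(cache_config)
        .build_arc()
        .expect("runtime");

    // Pre-warming only happens when statistics collection is enabled.
    let mut config = SessionConfig::new();
    config.options_mut().execution.collect_statistics = true;

    (SessionContext::new_with_config_rt(config, runtime), stats_cache)
}
```

After `factory.create(...)` (or an equivalent `CREATE EXTERNAL TABLE`), a non-zero `stats_cache.len()` indicates the cache was populated eagerly; calling `len` requires the `CacheAccessor` trait in scope, as in the test's imports.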
provider_as_source, source_as_provider, }; pub use self::memory::MemTable; pub use self::view::ViewTable; @@ -53,32 +53,34 @@ pub use datafusion_physical_expr::create_ordering; mod tests { use crate::prelude::SessionContext; - use ::object_store::{path::Path, ObjectMeta}; + use ::object_store::{ObjectMeta, path::Path}; use arrow::{ - array::{Int32Array, StringArray}, + array::Int32Array, datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; - use datafusion_common::{record_batch, test_util::batches_to_sort_string}; + use datafusion_common::{ + Result, ScalarValue, + test_util::batches_to_sort_string, + tree_node::{Transformed, TransformedResult, TreeNode}, + }; use datafusion_datasource::{ - file::FileSource, - file_scan_config::FileScanConfigBuilder, - schema_adapter::{ - DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, - SchemaMapper, - }, - source::DataSourceExec, - PartitionedFile, + PartitionedFile, file_scan_config::FileScanConfigBuilder, source::DataSourceExec, }; use datafusion_datasource_parquet::source::ParquetSource; + use datafusion_physical_expr::expressions::{Column, Literal}; + use datafusion_physical_expr_adapter::{ + PhysicalExprAdapter, PhysicalExprAdapterFactory, + }; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::collect; use std::{fs, sync::Arc}; use tempfile::TempDir; #[tokio::test] - async fn can_override_schema_adapter() { - // Test shows that SchemaAdapter can add a column that doesn't existing in the - // record batches returned from parquet. This can be useful for schema evolution + async fn can_override_physical_expr_adapter() { + // Test shows that PhysicalExprAdapter can add a column that doesn't exist in the + // record batches returned from parquet. This can be useful for schema evolution // where older files may not have all columns. 
use datafusion_execution::object_store::ObjectStoreUrl; @@ -124,16 +126,12 @@ mod tests { let f2 = Field::new("extra_column", DataType::Utf8, true); let schema = Arc::new(Schema::new(vec![f1.clone(), f2.clone()])); - let source = ParquetSource::default() - .with_schema_adapter_factory(Arc::new(TestSchemaAdapterFactory {})) - .unwrap(); - let base_conf = FileScanConfigBuilder::new( - ObjectStoreUrl::local_filesystem(), - schema, - source, - ) - .with_file(partitioned_file) - .build(); + let source = Arc::new(ParquetSource::new(Arc::clone(&schema))); + let base_conf = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_file(partitioned_file) + .with_expr_adapter(Some(Arc::new(TestPhysicalExprAdapterFactory))) + .build(); let parquet_exec = DataSourceExec::from_data_source(base_conf); @@ -141,134 +139,52 @@ mod tests { let task_ctx = session_ctx.task_ctx(); let read = collect(parquet_exec, task_ctx).await.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read),@r###" + insta::assert_snapshot!(batches_to_sort_string(&read),@r" +----+--------------+ | id | extra_column | +----+--------------+ | 1 | foo | +----+--------------+ - "###); - } - - #[test] - fn default_schema_adapter() { - let table_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), - ]); - - // file has a subset of the table schema fields and different type - let file_schema = Schema::new(vec![ - Field::new("c", DataType::Float64, true), // not in table schema - Field::new("b", DataType::Float64, true), - ]); - - let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); - let (mapper, indices) = adapter.map_schema(&file_schema).unwrap(); - assert_eq!(indices, vec![1]); - - let file_batch = record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap(); - - let mapped_batch = mapper.map_batch(file_batch).unwrap(); - - // the mapped batch has the correct schema and the "b" column has been cast to Utf8 - let expected_batch = record_batch!( - ("a", Int32, vec![None, None]), // missing column filled with nulls - ("b", Utf8, vec!["1.0", "2.0"]) // b was cast to string and order was changed - ) - .unwrap(); - assert_eq!(mapped_batch, expected_batch); - } - - #[test] - fn default_schema_adapter_non_nullable_columns() { - let table_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), // "a"" is declared non nullable - Field::new("b", DataType::Utf8, true), - ]); - let file_schema = Schema::new(vec![ - // since file doesn't have "a" it will be filled with nulls - Field::new("b", DataType::Float64, true), - ]); - - let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); - let (mapper, indices) = adapter.map_schema(&file_schema).unwrap(); - assert_eq!(indices, vec![0]); - - let file_batch = record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap(); - - // Mapping fails because it tries to fill in a non-nullable column with nulls - let err = mapper.map_batch(file_batch).unwrap_err().to_string(); - assert!(err.contains("Invalid argument error: Column 'a' is declared as non-nullable but contains null values"), "{err}"); + "); } #[derive(Debug)] - struct TestSchemaAdapterFactory; + struct TestPhysicalExprAdapterFactory; - impl SchemaAdapterFactory for TestSchemaAdapterFactory { + impl PhysicalExprAdapterFactory for TestPhysicalExprAdapterFactory { fn create( &self, - projected_table_schema: SchemaRef, - _table_schema: SchemaRef, - ) -> Box { - Box::new(TestSchemaAdapter { - table_schema: 
projected_table_schema, + _logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Arc { + Arc::new(TestPhysicalExprAdapter { + physical_file_schema, }) } } - struct TestSchemaAdapter { - /// Schema for the table - table_schema: SchemaRef, + #[derive(Debug)] + struct TestPhysicalExprAdapter { + physical_file_schema: SchemaRef, } - impl SchemaAdapter for TestSchemaAdapter { - fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { - let field = self.table_schema.field(index); - Some(file_schema.fields.find(field.name())?.0) - } - - fn map_schema( - &self, - file_schema: &Schema, - ) -> datafusion_common::Result<(Arc, Vec)> { - let mut projection = Vec::with_capacity(file_schema.fields().len()); - - for (file_idx, file_field) in file_schema.fields.iter().enumerate() { - if self.table_schema.fields().find(file_field.name()).is_some() { - projection.push(file_idx); + impl PhysicalExprAdapter for TestPhysicalExprAdapter { + fn rewrite(&self, expr: Arc) -> Result> { + expr.transform(|e| { + if let Some(column) = e.as_any().downcast_ref::() { + // If column is "extra_column" and missing from physical schema, inject "foo" + if column.name() == "extra_column" + && self.physical_file_schema.index_of("extra_column").is_err() + { + return Ok(Transformed::yes(Arc::new(Literal::new( + ScalarValue::Utf8(Some("foo".to_string())), + )) + as Arc)); + } } - } - - Ok((Arc::new(TestSchemaMapping {}), projection)) - } - } - - #[derive(Debug)] - struct TestSchemaMapping {} - - impl SchemaMapper for TestSchemaMapping { - fn map_batch( - &self, - batch: RecordBatch, - ) -> datafusion_common::Result { - let f1 = Field::new("id", DataType::Int32, true); - let f2 = Field::new("extra_column", DataType::Utf8, true); - - let schema = Arc::new(Schema::new(vec![f1, f2])); - - let extra_column = Arc::new(StringArray::from(vec!["foo"])); - let mut new_columns = batch.columns().to_vec(); - new_columns.push(extra_column); - - Ok(RecordBatch::try_new(schema, new_columns).unwrap()) - } - - fn map_column_statistics( - &self, - _file_col_statistics: &[datafusion_common::ColumnStatistics], - ) -> datafusion_common::Result> { - unimplemented!() + Ok(Transformed::no(e)) + }) + .data() } } } diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index 9068c9758179d..2954a47403299 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -31,21 +31,21 @@ mod tests { use crate::test::object_store::local_unpartitioned_file; use arrow::datatypes::{DataType, Field, SchemaBuilder}; use datafusion_common::test_util::batches_to_string; - use datafusion_common::{test_util, Result, ScalarValue}; + use datafusion_common::{Result, ScalarValue, test_util}; use datafusion_datasource::file_format::FileFormat; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; - use datafusion_datasource::PartitionedFile; - use datafusion_datasource_avro::source::AvroSource; + use datafusion_datasource::{PartitionedFile, TableSchema}; use datafusion_datasource_avro::AvroFormat; + use datafusion_datasource_avro::source::AvroSource; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_physical_plan::ExecutionPlan; use datafusion_datasource::source::DataSourceExec; use futures::StreamExt; use insta::assert_snapshot; + use object_store::ObjectStore; use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; - use object_store::ObjectStore; 
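The rewritten test above swaps the removed `SchemaAdapter` machinery for a `PhysicalExprAdapter` that rewrites references to a column missing from the file into a literal. A sketch of a slightly more general adapter/factory pair, assuming the trait shapes shown above (the generic parameters elided in this hunk are taken to be `dyn PhysicalExpr` and `dyn PhysicalExprAdapter`); the struct names and the parameterisation by column name and value are mine:

```rust
use std::sync::Arc;

use arrow_schema::SchemaRef;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{Result, ScalarValue};
use datafusion_physical_expr::expressions::{Column, Literal};
use datafusion_physical_expr_adapter::{PhysicalExprAdapter, PhysicalExprAdapterFactory};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

/// Fills a column that is absent from the physical (file) schema with a
/// constant, mirroring the test adapter above.
#[derive(Debug)]
struct FillMissingWithLiteral {
    physical_file_schema: SchemaRef,
    column: String,
    value: ScalarValue,
}

impl PhysicalExprAdapter for FillMissingWithLiteral {
    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
        expr.transform(|e| {
            if let Some(col) = e.as_any().downcast_ref::<Column>() {
                if col.name() == self.column.as_str()
                    && self.physical_file_schema.index_of(self.column.as_str()).is_err()
                {
                    // Replace the missing column with a literal default value.
                    return Ok(Transformed::yes(Arc::new(Literal::new(
                        self.value.clone(),
                    )) as Arc<dyn PhysicalExpr>));
                }
            }
            Ok(Transformed::no(e))
        })
        .data()
    }
}

#[derive(Debug)]
struct FillMissingFactory {
    column: String,
    value: ScalarValue,
}

impl PhysicalExprAdapterFactory for FillMissingFactory {
    fn create(
        &self,
        _logical_file_schema: SchemaRef,
        physical_file_schema: SchemaRef,
    ) -> Arc<dyn PhysicalExprAdapter> {
        Arc::new(FillMissingWithLiteral {
            physical_file_schema,
            column: self.column.clone(),
            value: self.value.clone(),
        })
    }
}
```

As in the test, the factory is attached to the scan with `FileScanConfigBuilder::...with_expr_adapter(Some(Arc::new(...)))`, so the rewrite runs against each file's physical schema.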
use rstest::*; use url::Url; @@ -81,15 +81,11 @@ mod tests { .infer_schema(&state, &store, std::slice::from_ref(&meta)) .await?; - let source = Arc::new(AvroSource::new()); - let conf = FileScanConfigBuilder::new( - ObjectStoreUrl::local_filesystem(), - file_schema, - source, - ) - .with_file(meta.into()) - .with_projection_indices(Some(vec![0, 1, 2])) - .build(); + let source = Arc::new(AvroSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_file(meta.into()) + .with_projection_indices(Some(vec![0, 1, 2]))? + .build(); let source_exec = DataSourceExec::from_data_source(conf); assert_eq!( @@ -109,20 +105,20 @@ mod tests { .expect("plan iterator empty") .expect("plan iterator returned an error"); - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r###" - +----+----------+-------------+ - | id | bool_col | tinyint_col | - +----+----------+-------------+ - | 4 | true | 0 | - | 5 | false | 1 | - | 6 | true | 0 | - | 7 | false | 1 | - | 2 | true | 0 | - | 3 | false | 1 | - | 0 | true | 0 | - | 1 | false | 1 | - +----+----------+-------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r" + +----+----------+-------------+ + | id | bool_col | tinyint_col | + +----+----------+-------------+ + | 4 | true | 0 | + | 5 | false | 1 | + | 6 | true | 0 | + | 7 | false | 1 | + | 2 | true | 0 | + | 3 | false | 1 | + | 0 | true | 0 | + | 1 | false | 1 | + +----+----------+-------------+ + ");} let batch = results.next().await; assert!(batch.is_none()); @@ -157,10 +153,10 @@ mod tests { // Include the missing column in the projection let projection = Some(vec![0, 1, 2, actual_schema.fields().len()]); - let source = Arc::new(AvroSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let source = Arc::new(AvroSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(object_store_url, source) .with_file(meta.into()) - .with_projection_indices(projection) + .with_projection_indices(projection)? .build(); let source_exec = DataSourceExec::from_data_source(conf); @@ -182,20 +178,20 @@ mod tests { .expect("plan iterator empty") .expect("plan iterator returned an error"); - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r###" - +----+----------+-------------+-------------+ - | id | bool_col | tinyint_col | missing_col | - +----+----------+-------------+-------------+ - | 4 | true | 0 | | - | 5 | false | 1 | | - | 6 | true | 0 | | - | 7 | false | 1 | | - | 2 | true | 0 | | - | 3 | false | 1 | | - | 0 | true | 0 | | - | 1 | false | 1 | | - +----+----------+-------------+-------------+ - "###);} + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&[batch]), @r" + +----+----------+-------------+-------------+ + | id | bool_col | tinyint_col | missing_col | + +----+----------+-------------+-------------+ + | 4 | true | 0 | | + | 5 | false | 1 | | + | 6 | true | 0 | | + | 7 | false | 1 | | + | 2 | true | 0 | | + | 3 | false | 1 | | + | 0 | true | 0 | | + | 1 | false | 1 | | + +----+----------+-------------+-------------+ + ");} let batch = results.next().await; assert!(batch.is_none()); @@ -227,13 +223,16 @@ mod tests { partitioned_file.partition_values = vec![ScalarValue::from("2021-10-26")]; let projection = Some(vec![0, 1, file_schema.fields().len(), 2]); - let source = Arc::new(AvroSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let table_schema = TableSchema::new( + file_schema.clone(), + vec![Arc::new(Field::new("date", DataType::Utf8, false))], + ); + let source = Arc::new(AvroSource::new(table_schema.clone())); + let conf = FileScanConfigBuilder::new(object_store_url, source) // select specific columns of the files as well as the partitioning // column which is supposed to be the last column in the table schema. - .with_projection_indices(projection) + .with_projection_indices(projection)? .with_file(partitioned_file) - .with_table_partition_cols(vec![Field::new("date", DataType::Utf8, false)]) .build(); let source_exec = DataSourceExec::from_data_source(conf); @@ -256,20 +255,20 @@ mod tests { .expect("plan iterator empty") .expect("plan iterator returned an error"); - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r###" - +----+----------+------------+-------------+ - | id | bool_col | date | tinyint_col | - +----+----------+------------+-------------+ - | 4 | true | 2021-10-26 | 0 | - | 5 | false | 2021-10-26 | 1 | - | 6 | true | 2021-10-26 | 0 | - | 7 | false | 2021-10-26 | 1 | - | 2 | true | 2021-10-26 | 0 | - | 3 | false | 2021-10-26 | 1 | - | 0 | true | 2021-10-26 | 0 | - | 1 | false | 2021-10-26 | 1 | - +----+----------+------------+-------------+ - "###);} + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&[batch]), @r" + +----+----------+------------+-------------+ + | id | bool_col | date | tinyint_col | + +----+----------+------------+-------------+ + | 4 | true | 2021-10-26 | 0 | + | 5 | false | 2021-10-26 | 1 | + | 6 | true | 2021-10-26 | 0 | + | 7 | false | 2021-10-26 | 1 | + | 2 | true | 2021-10-26 | 0 | + | 3 | false | 2021-10-26 | 1 | + | 0 | true | 2021-10-26 | 0 | + | 1 | false | 2021-10-26 | 1 | + +----+----------+------------+-------------+ + ");} let batch = results.next().await; assert!(batch.is_none()); diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 4f46a57d8b137..0e40ed2df2066 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -29,18 +29,21 @@ mod tests { use std::io::Write; use std::sync::Arc; + use datafusion_datasource::TableSchema; use datafusion_datasource_csv::CsvFormat; use object_store::ObjectStore; + use crate::datasource::file_format::FileFormat; use crate::prelude::CsvReadOptions; use crate::prelude::SessionContext; use crate::test::partitioned_file_groups; + use datafusion_common::config::CsvOptions; use datafusion_common::test_util::arrow_test_data; use datafusion_common::test_util::batches_to_string; - use datafusion_common::{assert_batches_eq, Result}; + use datafusion_common::{Result, assert_batches_eq}; use datafusion_execution::config::SessionConfig; - use datafusion_physical_plan::metrics::MetricsSet; use datafusion_physical_plan::ExecutionPlan; + use datafusion_physical_plan::metrics::MetricsSet; #[cfg(feature = "compression")] use datafusion_datasource::file_compression_type::FileCompressionType; @@ -94,32 +97,39 @@ mod tests { async fn csv_exec_with_projection( file_compression_type: FileCompressionType, ) -> Result<()> { + use datafusion_datasource::TableSchema; + let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema(); let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); let file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), )?; - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let config = FileScanConfigBuilder::from(partitioned_csv_config( - file_schema, - file_groups, - source, - )) - .with_file_compression_type(file_compression_type) - .with_newlines_in_values(false) - .with_projection_indices(Some(vec![0, 2, 4])) - .build(); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(Arc::clone(&file_schema)); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) + .with_file_compression_type(file_compression_type) + .with_projection_indices(Some(vec![0, 2, 4]))? + .build(); assert_eq!(13, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); @@ -131,17 +141,17 @@ mod tests { assert_eq!(3, batch.num_columns()); assert_eq!(100, batch.num_rows()); - insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r###" - +----+-----+------------+ - | c1 | c3 | c5 | - +----+-----+------------+ - | c | 1 | 2033001162 | - | d | -40 | 706441268 | - | b | 29 | 994303988 | - | a | -85 | 1171968280 | - | b | -82 | 1824882165 | - +----+-----+------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r" + +----+-----+------------+ + | c1 | c3 | c5 | + +----+-----+------------+ + | c | 1 | 2033001162 | + | d | -40 | 706441268 | + | b | 29 | 994303988 | + | a | -85 | 1171968280 | + | b | -82 | 1824882165 | + +----+-----+------------+ + ");} Ok(()) } @@ -158,6 +168,8 @@ mod tests { async fn csv_exec_with_mixed_order_projection( file_compression_type: FileCompressionType, ) -> Result<()> { + use datafusion_datasource::TableSchema; + let cfg = SessionConfig::new().set_str("datafusion.catalog.has_header", "true"); let session_ctx = SessionContext::new_with_config(cfg); let task_ctx = session_ctx.task_ctx(); @@ -165,26 +177,31 @@ mod tests { let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); let file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), )?; - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let config = FileScanConfigBuilder::from(partitioned_csv_config( - file_schema, - file_groups, - source, - )) - .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()) - .with_projection_indices(Some(vec![4, 0, 2])) - .build(); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(Arc::clone(&file_schema)); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) + .with_file_compression_type(file_compression_type.to_owned()) + .with_projection_indices(Some(vec![4, 0, 2]))? + .build(); assert_eq!(13, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); assert_eq!(3, csv.schema().fields().len()); @@ -194,17 +211,17 @@ mod tests { assert_eq!(3, batch.num_columns()); assert_eq!(100, batch.num_rows()); - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r###" - +------------+----+-----+ - | c5 | c1 | c3 | - +------------+----+-----+ - | 2033001162 | c | 1 | - | 706441268 | d | -40 | - | 994303988 | b | 29 | - | 1171968280 | a | -85 | - | 1824882165 | b | -82 | - +------------+----+-----+ - "###);} + insta::allow_duplicates! 
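These CSV tests now construct the source from a `TableSchema` plus a `CsvOptions` value instead of the old positional `CsvSource::new(true, b',', b'"')`. A sketch of that construction on its own, assuming the `datafusion_datasource_*` module paths used in these tests and a schema with at least five fields (such as the 13-column aggregate test schema) so the projection indices are valid; file groups would still be added via `with_file_groups` or `with_file` as the tests do:

```rust
use std::sync::Arc;

use arrow_schema::SchemaRef;
use datafusion_common::config::CsvOptions;
use datafusion_common::Result;
use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
use datafusion_datasource::source::DataSourceExec;
use datafusion_datasource::TableSchema;
use datafusion_datasource_csv::source::CsvSource;
use datafusion_execution::object_store::ObjectStoreUrl;
use datafusion_physical_plan::ExecutionPlan;

fn csv_scan_sketch(file_schema: SchemaRef) -> Result<Arc<dyn ExecutionPlan>> {
    // Header, delimiter and quote now travel in CsvOptions rather than as
    // positional arguments to CsvSource::new.
    let options = CsvOptions {
        has_header: Some(true),
        delimiter: b',',
        quote: b'"',
        ..Default::default()
    };

    // With no partition columns, the table schema is derived from the file schema.
    let table_schema = TableSchema::from_file_schema(file_schema);
    let source = Arc::new(CsvSource::new(table_schema).with_csv_options(options));

    // The builder no longer takes the schema separately; it comes from the
    // source. Projection indices are validated eagerly, hence the `?`.
    let config = FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source)
        .with_projection_indices(Some(vec![0, 2, 4]))?
        .build();
    Ok(DataSourceExec::from_data_source(config))
}
```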
{assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r" + +------------+----+-----+ + | c5 | c1 | c3 | + +------------+----+-----+ + | 2033001162 | c | 1 | + | 706441268 | d | -40 | + | 994303988 | b | 29 | + | 1171968280 | a | -85 | + | 1824882165 | b | -82 | + +------------+----+-----+ + ");} Ok(()) } @@ -221,6 +238,7 @@ mod tests { async fn csv_exec_with_limit( file_compression_type: FileCompressionType, ) -> Result<()> { + use datafusion_datasource::TableSchema; use futures::StreamExt; let cfg = SessionConfig::new().set_str("datafusion.catalog.has_header", "true"); @@ -230,26 +248,31 @@ mod tests { let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); let file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), )?; - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let config = FileScanConfigBuilder::from(partitioned_csv_config( - file_schema, - file_groups, - source, - )) - .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()) - .with_limit(Some(5)) - .build(); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(Arc::clone(&file_schema)); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) + .with_file_compression_type(file_compression_type.to_owned()) + .with_limit(Some(5)) + .build(); assert_eq!(13, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); assert_eq!(13, csv.schema().fields().len()); @@ -259,17 +282,17 @@ mod tests { assert_eq!(13, batch.num_columns()); assert_eq!(5, batch.num_rows()); - insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&[batch]), @r###" - +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ - | c1 | c2 | c3 | c4 | c5 | c6 | c7 | c8 | c9 | c10 | c11 | c12 | c13 | - +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ - | c | 2 | 1 | 18109 | 2033001162 | -6513304855495910254 | 25 | 43062 | 1491205016 | 5863949479783605708 | 0.110830784 | 0.9294097332465232 | 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW | - | d | 5 | -40 | 22614 | 706441268 | -7542719935673075327 | 155 | 14337 | 3373581039 | 11720144131976083864 | 0.69632107 | 0.3114712539863804 | C2GT5KVyOPZpgKVl110TyZO0NcJ434 | - | b | 1 | 29 | -18218 | 994303988 | 5983957848665088916 | 204 | 9489 | 3275293996 | 14857091259186476033 | 0.53840446 | 0.17909035118828576 | AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz | - | a | 1 | -85 | -15154 | 1171968280 | 1919439543497968449 | 77 | 52286 | 774637006 | 12101411955859039553 | 0.12285209 | 0.6864391962767343 | 0keZ5G8BffGwgF2RwQD59TFzMStxCB | - | b | 5 | -82 | 22080 | 1824882165 | 7373730676428214987 | 208 | 34331 | 3342719438 | 3330177516592499461 | 0.82634634 | 0.40975383525297016 | Ig1QcuKsjHXkproePdERo2w0mYzIqd | - +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r" + +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ + | c1 | c2 | c3 | c4 | c5 | c6 | c7 | c8 | c9 | c10 | c11 | c12 | c13 | + +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ + | c | 2 | 1 | 18109 | 2033001162 | -6513304855495910254 | 25 | 43062 | 1491205016 | 5863949479783605708 | 0.110830784 | 0.9294097332465232 | 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW | + | d | 5 | -40 | 22614 | 706441268 | -7542719935673075327 | 155 | 14337 | 3373581039 | 11720144131976083864 | 0.69632107 | 0.3114712539863804 | C2GT5KVyOPZpgKVl110TyZO0NcJ434 | + | b | 1 | 29 | -18218 | 994303988 | 5983957848665088916 | 204 | 9489 | 3275293996 | 14857091259186476033 | 0.53840446 | 0.17909035118828576 | AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz | + | a | 1 | -85 | -15154 | 1171968280 | 1919439543497968449 | 77 | 52286 | 774637006 | 12101411955859039553 | 0.12285209 | 0.6864391962767343 | 0keZ5G8BffGwgF2RwQD59TFzMStxCB | + | b | 5 | -82 | 22080 | 1824882165 | 7373730676428214987 | 208 | 34331 | 3342719438 | 3330177516592499461 | 0.82634634 | 0.40975383525297016 | Ig1QcuKsjHXkproePdERo2w0mYzIqd | + +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ + ");} Ok(()) } @@ -287,32 +310,39 @@ mod tests { async fn csv_exec_with_missing_column( file_compression_type: FileCompressionType, ) -> Result<()> { + use datafusion_datasource::TableSchema; + let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema_with_missing_col(); let path = format!("{}/csv", arrow_test_data()); let 
filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); let file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), )?; - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let config = FileScanConfigBuilder::from(partitioned_csv_config( - file_schema, - file_groups, - source, - )) - .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()) - .with_limit(Some(5)) - .build(); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(Arc::clone(&file_schema)); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) + .with_file_compression_type(file_compression_type.to_owned()) + .with_limit(Some(5)) + .build(); assert_eq!(14, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); assert_eq!(14, csv.schema().fields().len()); @@ -341,6 +371,7 @@ mod tests { file_compression_type: FileCompressionType, ) -> Result<()> { use datafusion_common::ScalarValue; + use datafusion_datasource::TableSchema; let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); @@ -348,12 +379,13 @@ mod tests { let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); let mut file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), )?; @@ -362,19 +394,25 @@ mod tests { let num_file_schema_fields = file_schema.fields().len(); - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let config = FileScanConfigBuilder::from(partitioned_csv_config( - file_schema, - file_groups, - source, - )) - .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()) - .with_table_partition_cols(vec![Field::new("date", DataType::Utf8, false)]) - // We should be able to project on the partition column - // Which is supposed to be after the file fields - .with_projection_indices(Some(vec![0, num_file_schema_fields])) - .build(); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::new( + Arc::clone(&file_schema), + vec![Arc::new(Field::new("date", DataType::Utf8, false))], + ); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) + .with_file_compression_type(file_compression_type.to_owned()) + // We should be able to project on the partition column + // Which is supposed to be after the file fields + .with_projection_indices(Some(vec![0, num_file_schema_fields]))? + .build(); // we don't have `/date=xx/` in the path but that is ok because // partitions are resolved during scan anyway @@ -388,17 +426,17 @@ mod tests { assert_eq!(2, batch.num_columns()); assert_eq!(100, batch.num_rows()); - insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r###" - +----+------------+ - | c1 | date | - +----+------------+ - | c | 2021-10-26 | - | d | 2021-10-26 | - | b | 2021-10-26 | - | a | 2021-10-26 | - | b | 2021-10-26 | - +----+------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r" + +----+------------+ + | c1 | date | + +----+------------+ + | c | 2021-10-26 | + | d | 2021-10-26 | + | b | 2021-10-26 | + | a | 2021-10-26 | + | b | 2021-10-26 | + +----+------------+ + ");} let metrics = csv.metrics().expect("doesn't found metrics"); let time_elapsed_processing = get_value(&metrics, "time_elapsed_processing"); @@ -452,26 +490,31 @@ mod tests { let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); let file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), ) .unwrap(); - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let config = FileScanConfigBuilder::from(partitioned_csv_config( - file_schema, - file_groups, - source, - )) - .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()) - .build(); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(Arc::clone(&file_schema)); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) + .with_file_compression_type(file_compression_type.to_owned()) + .build(); let csv = DataSourceExec::from_data_source(config); let it = csv.execute(0, task_ctx).unwrap(); @@ -527,14 +570,14 @@ mod tests { let result = df.collect().await.unwrap(); - assert_snapshot!(batches_to_string(&result), @r###" - +---+---+ - | a | b | - +---+---+ - | 1 | 2 | - | 3 | 4 | - +---+---+ - "###); + assert_snapshot!(batches_to_string(&result), @r" + +---+---+ + | a | b | + +---+---+ + | 1 | 2 | + | 3 | 4 | + +---+---+ + "); } #[tokio::test] @@ -556,14 +599,14 @@ mod tests { let result = df.collect().await.unwrap(); - assert_snapshot!(batches_to_string(&result),@r###" - +---+---+ - | a | b | - +---+---+ - | 1 | 2 | - | 3 | 4 | - +---+---+ - "###); + assert_snapshot!(batches_to_string(&result),@r" + +---+---+ + | a | b | + +---+---+ + | 1 | 2 | + | 3 | 4 | + +---+---+ + "); let e = session_ctx .read_csv("memory:///", CsvReadOptions::new().terminator(Some(b'\n'))) @@ -572,7 +615,10 @@ mod tests { .collect() .await .unwrap_err(); - assert_eq!(e.strip_backtrace(), "Arrow error: Csv error: incorrect number of fields for line 1, expected 2 got more than 2") + assert_eq!( + e.strip_backtrace(), + "Arrow error: Csv error: incorrect number of fields for line 1, expected 2 got more than 2" + ) } #[tokio::test] @@ -593,22 +639,22 @@ mod tests { .await?; let df = ctx.sql(r#"select * from t1"#).await?.collect().await?; - assert_snapshot!(batches_to_string(&df),@r###" - +------+--------+ - | col1 | col2 | - +------+--------+ - | id0 | value0 | - | id1 | value1 | - | id2 | value2 | - | id3 | value3 | - +------+--------+ - "###); + assert_snapshot!(batches_to_string(&df),@r" + +------+--------+ + | col1 | col2 | + +------+--------+ + | id0 | value0 | + | id1 | value1 | + | id2 | value2 | + | id3 | value3 | + 
+------+--------+ + "); Ok(()) } #[tokio::test] - async fn test_create_external_table_with_terminator_with_newlines_in_values( - ) -> Result<()> { + async fn test_create_external_table_with_terminator_with_newlines_in_values() + -> Result<()> { let ctx = SessionContext::new(); ctx.sql(r#" CREATE EXTERNAL TABLE t1 ( @@ -658,7 +704,10 @@ mod tests { ) .await .expect_err("should fail because input file does not match inferred schema"); - assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'"); + assert_eq!( + e.strip_backtrace(), + "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'" + ); Ok(()) } diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index f7d5c710bf48a..8de6a60258f08 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -34,9 +34,9 @@ mod tests { use crate::execution::SessionState; use crate::prelude::{CsvReadOptions, NdJsonReadOptions, SessionContext}; use crate::test::partitioned_file_groups; + use datafusion_common::Result; use datafusion_common::cast::{as_int32_array, as_int64_array, as_string_array}; use datafusion_common::test_util::batches_to_string; - use datafusion_common::Result; use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_format::FileFormat; use datafusion_datasource_json::JsonFormat; @@ -51,9 +51,9 @@ mod tests { use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; use insta::assert_snapshot; + use object_store::ObjectStore; use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; - use object_store::ObjectStore; use rstest::*; use tempfile::TempDir; use url::Url; @@ -69,11 +69,13 @@ mod tests { let store = state.runtime_env().object_store(&store_url).unwrap(); let filename = "1.json"; + let json_format: Arc = Arc::new(JsonFormat::default()); + let file_groups = partitioned_file_groups( TEST_DATA_BASE, filename, 1, - Arc::new(JsonFormat::default()), + &json_format, file_compression_type.to_owned(), work_dir, ) @@ -104,11 +106,13 @@ mod tests { ctx.register_object_store(&url, store.clone()); let filename = "1.json"; let tmp_dir = TempDir::new()?; + let json_format: Arc = Arc::new(JsonFormat::default()); + let file_groups = partitioned_file_groups( TEST_DATA_BASE, filename, 1, - Arc::new(JsonFormat::default()), + &json_format, file_compression_type.to_owned(), tmp_dir.path(), ) @@ -138,16 +142,16 @@ mod tests { let frame = ctx.read_json(path, read_options).await.unwrap(); let results = frame.collect().await.unwrap(); - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&results), @r###" - +-----+------------------+---------------+------+ - | a | b | c | d | - +-----+------------------+---------------+------+ - | 1 | [2.0, 1.3, -6.1] | [false, true] | 4 | - | -10 | [2.0, 1.3, -6.1] | [true, true] | 4 | - | 2 | [2.0, , -6.1] | [false, ] | text | - | | | | | - +-----+------------------+---------------+------+ - "###);} + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&results), @r" + +-----+------------------+---------------+------+ + | a | b | c | d | + +-----+------------------+---------------+------+ + | 1 | [2.0, 1.3, -6.1] | [false, true] | 4 | + | -10 | [2.0, 1.3, -6.1] | [true, true] | 4 | + | 2 | [2.0, , -6.1] | [false, ] | text | + | | | | | + +-----+------------------+---------------+------+ + ");} Ok(()) } @@ -176,8 +180,8 @@ mod tests { let (object_store_url, file_groups, file_schema) = prepare_store(&state, file_compression_type.to_owned(), tmp_dir.path()).await; - let source = Arc::new(JsonSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let source = Arc::new(JsonSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(object_store_url, source) .with_file_groups(file_groups) .with_limit(Some(3)) .with_file_compression_type(file_compression_type.to_owned()) @@ -251,8 +255,8 @@ mod tests { let file_schema = Arc::new(builder.finish()); let missing_field_idx = file_schema.fields.len() - 1; - let source = Arc::new(JsonSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let source = Arc::new(JsonSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(object_store_url, source) .with_file_groups(file_groups) .with_limit(Some(3)) .with_file_compression_type(file_compression_type.to_owned()) @@ -294,10 +298,11 @@ mod tests { let (object_store_url, file_groups, file_schema) = prepare_store(&state, file_compression_type.to_owned(), tmp_dir.path()).await; - let source = Arc::new(JsonSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let source = Arc::new(JsonSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(object_store_url, source) .with_file_groups(file_groups) .with_projection_indices(Some(vec![0, 2])) + .unwrap() .with_file_compression_type(file_compression_type.to_owned()) .build(); let exec = DataSourceExec::from_data_source(conf); @@ -342,10 +347,10 @@ mod tests { let (object_store_url, file_groups, file_schema) = prepare_store(&state, file_compression_type.to_owned(), tmp_dir.path()).await; - let source = Arc::new(JsonSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let source = Arc::new(JsonSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(object_store_url, source) .with_file_groups(file_groups) - .with_projection_indices(Some(vec![3, 0, 2])) + .with_projection_indices(Some(vec![3, 0, 2]))? .with_file_compression_type(file_compression_type.to_owned()) .build(); let exec = DataSourceExec::from_data_source(conf); @@ -494,7 +499,10 @@ mod tests { .write_json(out_dir_url, DataFrameWriteOptions::new(), None) .await .expect_err("should fail because input file does not match inferred schema"); - assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'"); + assert_eq!( + e.strip_backtrace(), + "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. 
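The JSON tests follow the same shape: the source owns the schema (`JsonSource::new(file_schema)`) and the builder only carries file groups, projection, limit and compression. A sketch assuming inputs like those returned by the `prepare_store` helper above; module paths are assumptions:

```rust
use std::sync::Arc;

use arrow_schema::SchemaRef;
use datafusion_common::Result;
use datafusion_datasource::file_compression_type::FileCompressionType;
use datafusion_datasource::file_groups::FileGroup;
use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
use datafusion_datasource::source::DataSourceExec;
use datafusion_datasource_json::source::JsonSource;
use datafusion_execution::object_store::ObjectStoreUrl;
use datafusion_physical_plan::ExecutionPlan;

fn json_scan_sketch(
    object_store_url: ObjectStoreUrl,
    file_groups: Vec<FileGroup>,
    file_schema: SchemaRef,
) -> Result<Arc<dyn ExecutionPlan>> {
    // The schema lives on the source now.
    let source = Arc::new(JsonSource::new(Arc::clone(&file_schema)));
    let conf = FileScanConfigBuilder::new(object_store_url, source)
        .with_file_groups(file_groups)
        // Projection indices are validated eagerly, hence the `?`.
        .with_projection_indices(Some(vec![0, 2]))?
        .with_limit(Some(3))
        .with_file_compression_type(FileCompressionType::UNCOMPRESSED)
        .build();
    Ok(DataSourceExec::from_data_source(conf))
}
```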
Row data: '[d,4]'" + ); Ok(()) } diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 1ac292e260fdf..04c8ea129d05c 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -43,146 +43,11 @@ pub use datafusion_datasource::file::FileSource; pub use datafusion_datasource::file_groups::FileGroup; pub use datafusion_datasource::file_groups::FileGroupPartitioner; pub use datafusion_datasource::file_scan_config::{ - wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, - FileScanConfigBuilder, + FileScanConfig, FileScanConfigBuilder, wrap_partition_type_in_dict, + wrap_partition_value_in_dict, }; pub use datafusion_datasource::file_sink_config::*; pub use datafusion_datasource::file_stream::{ FileOpenFuture, FileOpener, FileStream, OnError, }; - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow::array::{ - cast::AsArray, - types::{Float32Type, Float64Type, UInt32Type}, - BinaryArray, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, - StringArray, UInt64Array, - }; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::SchemaRef; - - use crate::datasource::schema_adapter::{ - DefaultSchemaAdapterFactory, SchemaAdapterFactory, - }; - - #[test] - fn schema_mapping_map_batch() { - let table_schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Utf8, true), - Field::new("c2", DataType::UInt32, true), - Field::new("c3", DataType::Float64, true), - ])); - - let adapter = DefaultSchemaAdapterFactory - .create(table_schema.clone(), table_schema.clone()); - - let file_schema = Schema::new(vec![ - Field::new("c1", DataType::Utf8, true), - Field::new("c2", DataType::UInt64, true), - Field::new("c3", DataType::Float32, true), - ]); - - let (mapping, _) = adapter.map_schema(&file_schema).expect("map schema failed"); - - let c1 = StringArray::from(vec!["hello", "world"]); - let c2 = UInt64Array::from(vec![9_u64, 5_u64]); - let c3 = Float32Array::from(vec![2.0_f32, 7.0_f32]); - let batch = RecordBatch::try_new( - Arc::new(file_schema), - vec![Arc::new(c1), Arc::new(c2), Arc::new(c3)], - ) - .unwrap(); - - let mapped_batch = mapping.map_batch(batch).unwrap(); - - assert_eq!(mapped_batch.schema(), table_schema); - assert_eq!(mapped_batch.num_columns(), 3); - assert_eq!(mapped_batch.num_rows(), 2); - - let c1 = mapped_batch.column(0).as_string::(); - let c2 = mapped_batch.column(1).as_primitive::(); - let c3 = mapped_batch.column(2).as_primitive::(); - - assert_eq!(c1.value(0), "hello"); - assert_eq!(c1.value(1), "world"); - assert_eq!(c2.value(0), 9_u32); - assert_eq!(c2.value(1), 5_u32); - assert_eq!(c3.value(0), 2.0_f64); - assert_eq!(c3.value(1), 7.0_f64); - } - - #[test] - fn schema_adapter_map_schema_with_projection() { - let table_schema = Arc::new(Schema::new(vec![ - Field::new("c0", DataType::Utf8, true), - Field::new("c1", DataType::Utf8, true), - Field::new("c2", DataType::Float64, true), - Field::new("c3", DataType::Int32, true), - Field::new("c4", DataType::Float32, true), - ])); - - let file_schema = Schema::new(vec![ - Field::new("id", DataType::Int32, true), - Field::new("c1", DataType::Boolean, true), - Field::new("c2", DataType::Float32, true), - Field::new("c3", DataType::Binary, true), - Field::new("c4", DataType::Int64, true), - ]); - - let indices = vec![1, 2, 4]; - let schema = SchemaRef::from(table_schema.project(&indices).unwrap()); - let adapter = 
DefaultSchemaAdapterFactory.create(schema, table_schema.clone()); - let (mapping, projection) = adapter.map_schema(&file_schema).unwrap(); - - let id = Int32Array::from(vec![Some(1), Some(2), Some(3)]); - let c1 = BooleanArray::from(vec![Some(true), Some(false), Some(true)]); - let c2 = Float32Array::from(vec![Some(2.0_f32), Some(7.0_f32), Some(3.0_f32)]); - let c3 = BinaryArray::from_opt_vec(vec![ - Some(b"hallo"), - Some(b"danke"), - Some(b"super"), - ]); - let c4 = Int64Array::from(vec![1, 2, 3]); - let batch = RecordBatch::try_new( - Arc::new(file_schema), - vec![ - Arc::new(id), - Arc::new(c1), - Arc::new(c2), - Arc::new(c3), - Arc::new(c4), - ], - ) - .unwrap(); - let rows_num = batch.num_rows(); - let projected = batch.project(&projection).unwrap(); - let mapped_batch = mapping.map_batch(projected).unwrap(); - - assert_eq!( - mapped_batch.schema(), - Arc::new(table_schema.project(&indices).unwrap()) - ); - assert_eq!(mapped_batch.num_columns(), indices.len()); - assert_eq!(mapped_batch.num_rows(), rows_num); - - let c1 = mapped_batch.column(0).as_string::(); - let c2 = mapped_batch.column(1).as_primitive::(); - let c4 = mapped_batch.column(2).as_primitive::(); - - assert_eq!(c1.value(0), "true"); - assert_eq!(c1.value(1), "false"); - assert_eq!(c1.value(2), "true"); - - assert_eq!(c2.value(0), 2.0_f64); - assert_eq!(c2.value(1), 7.0_f64); - assert_eq!(c2.value(2), 3.0_f64); - - assert_eq!(c4.value(0), 1.0_f32); - assert_eq!(c4.value(1), 2.0_f32); - assert_eq!(c4.value(2), 3.0_f32); - } -} diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet.rs index 0ffb252a66052..4703b55ecc0de 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet.rs @@ -38,7 +38,7 @@ mod tests { use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; use arrow::array::{ - ArrayRef, AsArray, Date64Array, Int32Array, Int64Array, Int8Array, StringArray, + ArrayRef, AsArray, Date64Array, Int8Array, Int32Array, Int64Array, StringArray, StringViewArray, StructArray, TimestampNanosecondArray, }; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaBuilder}; @@ -48,7 +48,7 @@ mod tests { use bytes::{BufMut, BytesMut}; use datafusion_common::config::TableParquetOptions; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; - use datafusion_common::{assert_contains, Result, ScalarValue}; + use datafusion_common::{Result, ScalarValue, assert_contains}; use datafusion_datasource::file_format::FileFormat; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; @@ -60,7 +60,7 @@ mod tests { DefaultParquetFileReaderFactory, ParquetFileReaderFactory, ParquetFormat, }; use datafusion_execution::object_store::ObjectStoreUrl; - use datafusion_expr::{col, lit, when, Expr}; + use datafusion_expr::{Expr, col, lit, when}; use datafusion_physical_expr::planner::logical2physical; use datafusion_physical_plan::analyze::AnalyzeExec; use datafusion_physical_plan::collect; @@ -161,7 +161,7 @@ mod tests { .as_ref() .map(|p| logical2physical(p, &table_schema)); - let mut source = ParquetSource::default(); + let mut source = ParquetSource::new(table_schema); if let Some(predicate) = predicate { source = source.with_predicate(predicate); } @@ -186,23 +186,20 @@ mod tests { source = source.with_bloom_filter_on_read(false); } - 
source.with_schema(TableSchema::new(Arc::clone(&table_schema), vec![])) + Arc::new(source) } fn build_parquet_exec( &self, - file_schema: SchemaRef, file_group: FileGroup, source: Arc, ) -> Arc { - let base_config = FileScanConfigBuilder::new( - ObjectStoreUrl::local_filesystem(), - file_schema, - source, - ) - .with_file_group(file_group) - .with_projection_indices(self.projection.clone()) - .build(); + let base_config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_file_group(file_group) + .with_projection_indices(self.projection.clone()) + .unwrap() + .build(); DataSourceExec::from_data_source(base_config) } @@ -231,11 +228,8 @@ mod tests { // build a ParquetExec to return the results let parquet_source = self.build_file_source(Arc::clone(table_schema)); - let parquet_exec = self.build_parquet_exec( - Arc::clone(table_schema), - file_group.clone(), - Arc::clone(&parquet_source), - ); + let parquet_exec = + self.build_parquet_exec(file_group.clone(), Arc::clone(&parquet_source)); let analyze_exec = Arc::new(AnalyzeExec::new( false, @@ -243,7 +237,6 @@ mod tests { vec![MetricType::SUMMARY, MetricType::DEV], // use a new ParquetSource to avoid sharing execution metrics self.build_parquet_exec( - Arc::clone(table_schema), file_group.clone(), self.build_file_source(Arc::clone(table_schema)), ), @@ -313,7 +306,7 @@ mod tests { let batch = RecordBatch::try_new(file_schema.clone(), vec![c1]).unwrap(); - // Since c2 is missing from the file and we didn't supply a custom `SchemaAdapterFactory`, + // Since c2 is missing from the file and we didn't supply a custom `PhysicalExprAdapterFactory`, // the default behavior is to fill in missing columns with nulls. // Thus this predicate will come back as false. let filter = col("c2").eq(lit(1_i32)); @@ -344,13 +337,13 @@ mod tests { .await; let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+ | c1 | c2 | +----+----+ | 1 | | +----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -371,7 +364,7 @@ mod tests { let batch = RecordBatch::try_new(file_schema.clone(), vec![c1]).unwrap(); - // Since c2 is missing from the file and we didn't supply a custom `SchemaAdapterFactory`, + // Since c2 is missing from the file and we didn't supply a custom `PhysicalExprAdapterFactory`, // the default behavior is to fill in missing columns with nulls. // Thus this predicate will come back as false. let filter = col("c2").eq(lit("abc")); @@ -402,13 +395,13 @@ mod tests { .await; let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+ | c1 | c2 | +----+----+ | 1 | | +----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -433,7 +426,7 @@ mod tests { let batch = RecordBatch::try_new(file_schema.clone(), vec![c1, c3]).unwrap(); - // Since c2 is missing from the file and we didn't supply a custom `SchemaAdapterFactory`, + // Since c2 is missing from the file and we didn't supply a custom `PhysicalExprAdapterFactory`, // the default behavior is to fill in missing columns with nulls. // Thus this predicate will come back as false. 
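For Parquet the pattern is the same, with the optional predicate attached to the source before the exec is built; and as the updated comments in this file note, without a custom `PhysicalExprAdapterFactory` a column missing from the file is filled with nulls, so predicates on it evaluate against null. A sketch of building such a scan, assuming the constructor and builder shapes used in this file; the function name and module paths are mine:

```rust
use std::sync::Arc;

use arrow_schema::SchemaRef;
use datafusion_common::Result;
use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
use datafusion_datasource::source::DataSourceExec;
use datafusion_datasource::PartitionedFile;
use datafusion_datasource_parquet::source::ParquetSource;
use datafusion_execution::object_store::ObjectStoreUrl;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_plan::ExecutionPlan;

fn parquet_scan_sketch(
    table_schema: SchemaRef,
    file: PartitionedFile,
    predicate: Option<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn ExecutionPlan>> {
    // ParquetSource now takes the table schema up front; the predicate is optional.
    let mut source = ParquetSource::new(table_schema);
    if let Some(predicate) = predicate {
        source = source.with_predicate(predicate);
    }

    let config =
        FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), Arc::new(source))
            .with_file(file)
            .build();
    Ok(DataSourceExec::from_data_source(config))
}
```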
let filter = col("c2").eq(lit("abc")); @@ -464,13 +457,13 @@ mod tests { .await; let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+----+ | c1 | c2 | c3 | +----+----+----+ | 1 | | 7 | +----+----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -495,7 +488,7 @@ mod tests { let batch = RecordBatch::try_new(file_schema.clone(), vec![c3.clone(), c3]).unwrap(); - // Since c2 is missing from the file and we didn't supply a custom `SchemaAdapterFactory`, + // Since c2 is missing from the file and we didn't supply a custom `PhysicalExprAdapterFactory`, // the default behavior is to fill in missing columns with nulls. // Thus this predicate will come back as false. let filter = col("c2").eq(lit("abc")); @@ -526,13 +519,13 @@ mod tests { .await; let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+----+ | c1 | c2 | c3 | +----+----+----+ | | | 7 | +----+----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -575,13 +568,13 @@ mod tests { let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+----+ | c1 | c2 | c3 | +----+----+----+ | 1 | | 10 | +----+----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -605,7 +598,7 @@ mod tests { let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+----+ | c1 | c2 | c3 | +----+----+----+ @@ -613,7 +606,7 @@ mod tests { | 4 | | 40 | | 5 | | 50 | +----+----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -642,7 +635,7 @@ mod tests { .await .unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read), @r###" + insta::assert_snapshot!(batches_to_sort_string(&read), @r" +-----+----+----+ | c1 | c2 | c3 | +-----+----+----+ @@ -656,7 +649,7 @@ mod tests { | bar | | | | bar | | | +-----+----+----+ - "###); + "); } #[tokio::test] @@ -757,18 +750,18 @@ mod tests { .await .unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read),@r###" - +-----+----+----+ - | c1 | c3 | c2 | - +-----+----+----+ - | | | | - | | 10 | 1 | - | | 20 | | - | | 20 | 2 | - | Foo | 10 | | - | bar | | | - +-----+----+----+ - "###); + insta::assert_snapshot!(batches_to_sort_string(&read),@r" + +-----+----+----+ + | c1 | c3 | c2 | + +-----+----+----+ + | | | | + | | 10 | 1 | + | | 20 | | + | | 20 | 2 | + | Foo | 10 | | + | bar | | | + +-----+----+----+ + "); } #[tokio::test] @@ -789,14 +782,14 @@ mod tests { .round_trip(vec![batch1, batch2]) .await; - insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r###" + insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r" +----+----+----+ | c1 | c3 | c2 | +----+----+----+ | | 10 | 1 | | | 20 | 2 | +----+----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); // Note there are were 6 rows in total (across three batches) assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 4); @@ -832,7 
+825,7 @@ mod tests { .await .unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read), @r###" + insta::assert_snapshot!(batches_to_sort_string(&read), @r" +-----+-----+ | c1 | c4 | +-----+-----+ @@ -843,7 +836,7 @@ mod tests { | bar | | | bar | | +-----+-----+ - "###); + "); } #[tokio::test] @@ -1056,18 +1049,18 @@ mod tests { // In a real query where this predicate was pushed down from a filter stage instead of created directly in the `DataSourceExec`, // the filter stage would be preserved as a separate execution plan stage so the actual query results would be as expected. - insta::assert_snapshot!(batches_to_sort_string(&read),@r###" - +-----+----+ - | c1 | c2 | - +-----+----+ - | | | - | | | - | | 1 | - | | 2 | - | Foo | | - | bar | | - +-----+----+ - "###); + insta::assert_snapshot!(batches_to_sort_string(&read),@r" + +-----+----+ + | c1 | c2 | + +-----+----+ + | | | + | | | + | | 1 | + | | 2 | + | Foo | | + | bar | | + +-----+----+ + "); } #[tokio::test] @@ -1092,13 +1085,13 @@ mod tests { .round_trip(vec![batch1, batch2]) .await; - insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r###" + insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r" +----+----+ | c1 | c2 | +----+----+ | | 1 | +----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); // Note there are were 6 rows in total (across three batches) assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 5); @@ -1152,7 +1145,7 @@ mod tests { .round_trip(vec![batch1, batch2, batch3, batch4]) .await; - insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r###" + insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r" +------+----+ | c1 | c2 | +------+----+ @@ -1169,7 +1162,7 @@ mod tests { | Foo2 | | | Foo3 | | +------+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); // There are 4 rows pruned in each of batch2, batch3, and @@ -1201,14 +1194,14 @@ mod tests { .await .unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read),@r###" - +-----+----+ - | c1 | c2 | - +-----+----+ - | Foo | 1 | - | bar | | - +-----+----+ - "###); + insta::assert_snapshot!(batches_to_sort_string(&read),@r" + +-----+----+ + | c1 | c2 | + +-----+----+ + | Foo | 1 | + | bar | | + +-----+----+ + "); } #[tokio::test] @@ -1231,15 +1224,15 @@ mod tests { .await .unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read),@r###" - +-----+----+ - | c1 | c2 | - +-----+----+ - | | 2 | - | Foo | 1 | - | bar | | - +-----+----+ - "###); + insta::assert_snapshot!(batches_to_sort_string(&read),@r" + +-----+----+ + | c1 | c2 | + +-----+----+ + | | 2 | + | Foo | 1 | + | bar | | + +-----+----+ + "); } #[tokio::test] @@ -1264,7 +1257,7 @@ mod tests { ("c3", c3.clone()), ]); - // batch2: c3(int8), c2(int64), c1(string), c4(string) + // batch2: c3(date64), c2(int64), c1(string) let batch2 = create_batch(vec![("c3", c4), ("c2", c2), ("c1", c1)]); let table_schema = Schema::new(vec![ @@ -1278,8 +1271,10 @@ mod tests { .with_table_schema(Arc::new(table_schema)) .round_trip_to_batches(vec![batch1, batch2]) .await; - assert_contains!(read.unwrap_err().to_string(), - "Cannot cast file schema field c3 of type Date64 to table schema field of type Int8"); + assert_contains!( + read.unwrap_err().to_string(), + "Cannot cast column 'c3' from 'Date64' (physical data type) to 'Int8' (logical data type)" + ); } #[tokio::test] @@ -1329,7 +1324,7 @@ mod tests { async fn parquet_exec_with_int96_from_spark() -> Result<()> { // arrow-rs relies on the chrono 
library to convert between timestamps and strings, so // instead compare as Int64. The underlying type should be a PrimitiveArray of Int64 - // anyway, so this should be a zero-copy non-modifying cast at the SchemaAdapter. + // anyway, so this should be a zero-copy non-modifying cast. let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])); let testdata = datafusion_common::test_util::parquet_test_data(); @@ -1550,8 +1545,7 @@ mod tests { ) -> Result<()> { let config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), - file_schema, - Arc::new(ParquetSource::default()), + Arc::new(ParquetSource::new(file_schema)), ) .with_file_groups(file_groups) .build(); @@ -1653,23 +1647,27 @@ mod tests { ), ]); - let source = Arc::new(ParquetSource::default()); - let config = FileScanConfigBuilder::new(object_store_url, schema.clone(), source) - .with_file(partitioned_file) - // file has 10 cols so index 12 should be month and 13 should be day - .with_projection_indices(Some(vec![0, 1, 2, 12, 13])) - .with_table_partition_cols(vec![ - Field::new("year", DataType::Utf8, false), - Field::new("month", DataType::UInt8, false), - Field::new( + let table_schema = TableSchema::new( + Arc::clone(&schema), + vec![ + Arc::new(Field::new("year", DataType::Utf8, false)), + Arc::new(Field::new("month", DataType::UInt8, false)), + Arc::new(Field::new( "day", DataType::Dictionary( Box::new(DataType::UInt16), Box::new(DataType::Utf8), ), false, - ), - ]) + )), + ], + ); + let source = Arc::new(ParquetSource::new(table_schema.clone())); + let config = FileScanConfigBuilder::new(object_store_url, source) + .with_file(partitioned_file) + // file has 10 cols so index 12 should be month and 13 should be day + .with_projection_indices(Some(vec![0, 1, 2, 12, 13])) + .unwrap() .build(); let parquet_exec = DataSourceExec::from_data_source(config); @@ -1684,20 +1682,20 @@ mod tests { let batch = results.next().await.unwrap()?; assert_eq!(batch.schema().as_ref(), &expected_schema); - assert_snapshot!(batches_to_string(&[batch]),@r###" - +----+----------+-------------+-------+-----+ - | id | bool_col | tinyint_col | month | day | - +----+----------+-------------+-------+-----+ - | 4 | true | 0 | 10 | 26 | - | 5 | false | 1 | 10 | 26 | - | 6 | true | 0 | 10 | 26 | - | 7 | false | 1 | 10 | 26 | - | 2 | true | 0 | 10 | 26 | - | 3 | false | 1 | 10 | 26 | - | 0 | true | 0 | 10 | 26 | - | 1 | false | 1 | 10 | 26 | - +----+----------+-------------+-------+-----+ - "###); + assert_snapshot!(batches_to_string(&[batch]),@r" + +----+----------+-------------+-------+-----+ + | id | bool_col | tinyint_col | month | day | + +----+----------+-------------+-------+-----+ + | 4 | true | 0 | 10 | 26 | + | 5 | false | 1 | 10 | 26 | + | 6 | true | 0 | 10 | 26 | + | 7 | false | 1 | 10 | 26 | + | 2 | true | 0 | 10 | 26 | + | 3 | false | 1 | 10 | 26 | + | 0 | true | 0 | 10 | 26 | + | 1 | false | 1 | 10 | 26 | + +----+----------+-------------+-------+-----+ + "); let batch = results.next().await; assert!(batch.is_none()); @@ -1731,8 +1729,7 @@ mod tests { let file_schema = Arc::new(Schema::empty()); let config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), - file_schema, - Arc::new(ParquetSource::default()), + Arc::new(ParquetSource::new(file_schema)), ) .with_file(partitioned_file) .build(); @@ -1770,14 +1767,14 @@ mod tests { let metrics = rt.parquet_exec.metrics().unwrap(); - assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()),@r###" - +-----+ - | int | - +-----+ - | 4 | - | 5 | - 
+-----+ - "###); + assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()),@r" + +-----+ + | int | + +-----+ + | 4 | + | 5 | + +-----+ + "); let (page_index_pruned, page_index_matched) = get_pruning_metric(&metrics, "page_index_rows_pruned"); assert_eq!(page_index_pruned, 4); @@ -1823,14 +1820,14 @@ mod tests { let metrics = rt.parquet_exec.metrics().unwrap(); // assert the batches and some metrics - assert_snapshot!(batches_to_string(&rt.batches.unwrap()),@r###" - +-----+ - | c1 | - +-----+ - | Foo | - | zzz | - +-----+ - "###); + assert_snapshot!(batches_to_string(&rt.batches.unwrap()),@r" + +-----+ + | c1 | + +-----+ + | Foo | + | zzz | + +-----+ + "); // pushdown predicates have eliminated all 4 bar rows and the // null row for 5 rows total @@ -1879,6 +1876,100 @@ mod tests { assert_contains!(&explain, "projection=[c1]"); } + #[tokio::test] + async fn parquet_exec_metrics_with_multiple_predicates() { + // Test that metrics are correctly calculated when multiple predicates + // are pushed down (connected with AND). This ensures we don't double-count + // rows when multiple predicates filter the data sequentially. + + // Create a batch with two columns: c1 (string) and c2 (int32) + // Total: 10 rows + let c1: ArrayRef = Arc::new(StringArray::from(vec![ + Some("foo"), // 0 - passes c1 filter, fails c2 filter (5 <= 10) + Some("bar"), // 1 - fails c1 filter + Some("bar"), // 2 - fails c1 filter + Some("baz"), // 3 - passes both filters (20 > 10) + Some("foo"), // 4 - passes both filters (12 > 10) + Some("bar"), // 5 - fails c1 filter + Some("baz"), // 6 - passes both filters (25 > 10) + Some("foo"), // 7 - passes c1 filter, fails c2 filter (7 <= 10) + Some("bar"), // 8 - fails c1 filter + Some("qux"), // 9 - passes both filters (30 > 10) + ])); + + let c2: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(5), + Some(15), + Some(8), + Some(20), + Some(12), + Some(9), + Some(25), + Some(7), + Some(18), + Some(30), + ])); + + let batch = create_batch(vec![("c1", c1), ("c2", c2)]); + + // Create filter: c1 != 'bar' AND c2 > 10 + // + // First predicate (c1 != 'bar'): + // - Rows passing: 0, 3, 4, 6, 7, 9 (6 rows) + // - Rows pruned: 1, 2, 5, 8 (4 rows) + // + // Second predicate (c2 > 10) on remaining 6 rows: + // - Rows passing: 3, 4, 6, 9 (4 rows with c2 = 20, 12, 25, 30) + // - Rows pruned: 0, 7 (2 rows with c2 = 5, 7) + // + // Expected final metrics: + // - pushdown_rows_matched: 4 (final result) + // - pushdown_rows_pruned: 4 + 2 = 6 (cumulative) + // - Total: 4 + 6 = 10 + + let filter = col("c1").not_eq(lit("bar")).and(col("c2").gt(lit(10))); + + let rt = RoundTrip::new() + .with_predicate(filter) + .with_pushdown_predicate() + .round_trip(vec![batch]) + .await; + + let metrics = rt.parquet_exec.metrics().unwrap(); + + // Verify the result rows + assert_snapshot!(batches_to_string(&rt.batches.unwrap()),@r" + +-----+----+ + | c1 | c2 | + +-----+----+ + | baz | 20 | + | foo | 12 | + | baz | 25 | + | qux | 30 | + +-----+----+ + "); + + // Verify metrics - this is the key test + let pushdown_rows_matched = get_value(&metrics, "pushdown_rows_matched"); + let pushdown_rows_pruned = get_value(&metrics, "pushdown_rows_pruned"); + + assert_eq!( + pushdown_rows_matched, 4, + "Expected 4 rows to pass both predicates" + ); + assert_eq!( + pushdown_rows_pruned, 6, + "Expected 6 rows to be pruned (4 by first predicate + 2 by second predicate)" + ); + + // The sum should equal the total number of rows + assert_eq!( + pushdown_rows_matched + pushdown_rows_pruned, + 10, + "matched + pruned should 
equal total rows" + ); + } + #[tokio::test] async fn parquet_exec_has_no_pruning_predicate_if_can_not_prune() { // batch1: c1(string) @@ -2119,13 +2210,13 @@ mod tests { let sql = "select * from base_table where name='test02'"; let batch = ctx.sql(sql).await.unwrap().collect().await.unwrap(); assert_eq!(batch.len(), 1); - insta::assert_snapshot!(batches_to_string(&batch),@r###" - +---------------------+----+--------+ - | struct | id | name | - +---------------------+----+--------+ - | {id: 4, name: aaa2} | 2 | test02 | - +---------------------+----+--------+ - "###); + insta::assert_snapshot!(batches_to_string(&batch),@r" + +---------------------+----+--------+ + | struct | id | name | + +---------------------+----+--------+ + | {id: 4, name: aaa2} | 2 | test02 | + +---------------------+----+--------+ + "); Ok(()) } @@ -2148,13 +2239,13 @@ mod tests { let sql = "select * from base_table where name='test02'"; let batch = ctx.sql(sql).await.unwrap().collect().await.unwrap(); assert_eq!(batch.len(), 1); - insta::assert_snapshot!(batches_to_string(&batch),@r###" - +---------------------+----+--------+ - | struct | id | name | - +---------------------+----+--------+ - | {id: 4, name: aaa2} | 2 | test02 | - +---------------------+----+--------+ - "###); + insta::assert_snapshot!(batches_to_string(&batch),@r" + +---------------------+----+--------+ + | struct | id | name | + +---------------------+----+--------+ + | {id: 4, name: aaa2} | 2 | test02 | + +---------------------+----+--------+ + "); Ok(()) } @@ -2279,11 +2370,11 @@ mod tests { let size_hint_calls = reader_factory.metadata_size_hint_calls.clone(); let source = Arc::new( - ParquetSource::default() + ParquetSource::new(Arc::clone(&schema)) .with_parquet_file_reader_factory(reader_factory) .with_metadata_size_hint(456), ); - let config = FileScanConfigBuilder::new(store_url, schema, source) + let config = FileScanConfigBuilder::new(store_url, source) .with_file( PartitionedFile { object_meta: ObjectMeta { diff --git a/datafusion/core/src/datasource/view_test.rs b/datafusion/core/src/datasource/view_test.rs index 85ad9ff664ade..35418d6dea632 100644 --- a/datafusion/core/src/datasource/view_test.rs +++ b/datafusion/core/src/datasource/view_test.rs @@ -46,13 +46,13 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---+ | b | +---+ | 2 | +---+ - "###); + "); Ok(()) } @@ -96,14 +96,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+---------+---------+ | column1 | column2 | column3 | +---------+---------+---------+ | 1 | 2 | 3 | | 4 | 5 | 6 | +---------+---------+---------+ - "###); + "); let view_sql = "CREATE VIEW replace_xyz AS SELECT * REPLACE (column1*2 as column1) FROM xyz"; @@ -115,14 +115,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+---------+---------+ | column1 | column2 | column3 | +---------+---------+---------+ | 2 | 2 | 3 | | 8 | 5 | 6 | +---------+---------+---------+ - "###); + "); Ok(()) } @@ -146,14 +146,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------------+ | column1_alias | +---------------+ | 1 | | 4 | +---------------+ - "###); + "); Ok(()) } @@ -177,14 +177,14 @@ 
mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------------+---------------+ | column2_alias | column1_alias | +---------------+---------------+ | 2 | 1 | | 5 | 4 | +---------------+---------------+ - "###); + "); Ok(()) } @@ -213,14 +213,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+ | column1 | +---------+ | 1 | | 4 | +---------+ - "###); + "); Ok(()) } @@ -249,13 +249,13 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+ | column1 | +---------+ | 4 | +---------+ - "###); + "); Ok(()) } @@ -287,14 +287,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+---------+---------+ | column2 | column1 | column3 | +---------+---------+---------+ | 2 | 1 | 3 | | 5 | 4 | 6 | +---------+---------+---------+ - "###); + "); Ok(()) } @@ -358,7 +358,10 @@ mod tests { .to_string(); assert!(formatted.contains("DataSourceExec: ")); assert!(formatted.contains("file_type=parquet")); - assert!(formatted.contains("projection=[bool_col, int_col], limit=10")); + assert!( + formatted.contains("projection=[bool_col, int_col], limit=10"), + "{formatted}" + ); Ok(()) } @@ -442,14 +445,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+ | column1 | +---------+ | 1 | | 4 | +---------+ - "###); + "); Ok(()) } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 687779787ab50..a769bb01b4354 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -20,6 +20,7 @@ use std::collections::HashSet; use std::fmt::Debug; use std::sync::{Arc, Weak}; +use std::time::Duration; use super::options::ReadOptions; use crate::datasource::dynamic_file::DynamicListTableFactory; @@ -33,20 +34,20 @@ use crate::{ datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }, - datasource::{provider_as_source, MemTable, ViewTable}, + datasource::{MemTable, ViewTable, provider_as_source}, error::Result, execution::{ + FunctionRegistry, options::ArrowReadOptions, runtime_env::{RuntimeEnv, RuntimeEnvBuilder}, - FunctionRegistry, }, logical_expr::AggregateUDF, logical_expr::ScalarUDF, logical_expr::{ CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction, CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable, - DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, SetVariable, - TableType, UNNAMED_TABLE, + DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, ResetVariable, + SetVariable, TableType, UNNAMED_TABLE, }, physical_expr::PhysicalExpr, physical_plan::ExecutionPlan, @@ -58,32 +59,43 @@ pub use crate::execution::session_state::SessionState; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion_catalog::memory::MemorySchemaProvider; use datafusion_catalog::MemoryCatalogProvider; +use datafusion_catalog::memory::MemorySchemaProvider; use datafusion_catalog::{ DynamicFileCatalog, TableFunction, TableFunctionImpl, UrlTableFactory, }; -use 
datafusion_common::config::ConfigOptions; +use datafusion_common::config::{ConfigField, ConfigOptions}; use datafusion_common::metadata::ScalarAndMetadata; use datafusion_common::{ + DFSchema, DataFusionError, ParamValues, SchemaReference, TableReference, config::{ConfigExtension, TableOptions}, exec_datafusion_err, exec_err, internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, tree_node::{TreeNodeRecursion, TreeNodeVisitor}, - DFSchema, DataFusionError, ParamValues, SchemaReference, TableReference, +}; +pub use datafusion_execution::TaskContext; +use datafusion_execution::cache::cache_manager::{ + DEFAULT_LIST_FILES_CACHE_MEMORY_LIMIT, DEFAULT_LIST_FILES_CACHE_TTL, + DEFAULT_METADATA_CACHE_LIMIT, }; pub use datafusion_execution::config::SessionConfig; +use datafusion_execution::disk_manager::{ + DEFAULT_MAX_TEMP_DIRECTORY_SIZE, DiskManagerBuilder, +}; use datafusion_execution::registry::SerializerRegistry; -pub use datafusion_execution::TaskContext; pub use datafusion_expr::execution_props::ExecutionProps; +#[cfg(feature = "sql")] +use datafusion_expr::planner::RelationPlanner; +use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{ + Expr, UserDefinedLogicalNode, WindowUDF, expr_rewriter::FunctionRewrite, logical_plan::{DdlStatement, Statement}, planner::ExprPlanner, - Expr, UserDefinedLogicalNode, WindowUDF, }; -use datafusion_optimizer::analyzer::type_coercion::TypeCoercion; use datafusion_optimizer::Analyzer; +use datafusion_optimizer::analyzer::type_coercion::TypeCoercion; +use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_optimizer::{AnalyzerRule, OptimizerRule}; use datafusion_session::SessionStore; @@ -476,6 +488,11 @@ impl SessionContext { self.state.write().append_optimizer_rule(optimizer_rule); } + /// Removes an optimizer rule by name, returning `true` if it existed. + pub fn remove_optimizer_rule(&self, name: &str) -> bool { + self.state.write().remove_optimizer_rule(name) + } + /// Adds an analyzer rule to the end of the existing rules. /// /// See [`SessionState`] for more control of when the rule is applied. 
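A minimal usage sketch of the new `SessionContext::remove_optimizer_rule` API added in the hunk above, modeled on the test added later in this patch; the `tokio` setup and `datafusion::prelude` import paths are assumptions for a standalone example, not part of the patch:

    use datafusion::error::Result;
    use datafusion::prelude::SessionContext;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();

        // Returns true only if a rule with that name was registered.
        assert!(ctx.remove_optimizer_rule("simplify_expressions"));
        assert!(!ctx.remove_optimizer_rule("simplify_expressions"));

        // With the rule gone, `1 + 1` is no longer folded into a constant
        // during logical optimization.
        let plan = ctx
            .sql("SELECT 1 + 1")
            .await?
            .into_optimized_plan()?
            .to_string();
        println!("{plan}");
        Ok(())
    }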
@@ -678,7 +695,7 @@ impl SessionContext { match ddl { DdlStatement::CreateExternalTable(cmd) => { (Box::pin(async move { self.create_external_table(&cmd).await }) - as std::pin::Pin + Send>>) + as std::pin::Pin + Send>>) .await } DdlStatement::CreateMemoryTable(cmd) => { @@ -709,7 +726,12 @@ impl SessionContext { } // TODO what about the other statements (like TransactionStart and TransactionEnd) LogicalPlan::Statement(Statement::SetVariable(stmt)) => { - self.set_variable(stmt).await + self.set_variable(stmt).await?; + self.return_empty_dataframe() + } + LogicalPlan::Statement(Statement::ResetVariable(stmt)) => { + self.reset_variable(stmt).await?; + self.return_empty_dataframe() } LogicalPlan::Statement(Statement::Prepare(Prepare { name, @@ -774,7 +796,7 @@ impl SessionContext { /// * [`SessionState::create_physical_expr`] for a lower level API /// /// [simplified]: datafusion_optimizer::simplify_expressions - /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs + /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/query_planning/expr_api.rs pub fn create_physical_expr( &self, expr: Expr, @@ -1052,22 +1074,22 @@ impl SessionContext { } else if allow_missing { return self.return_empty_dataframe(); } else { - return self.schema_doesnt_exist_err(name); + return self.schema_doesnt_exist_err(&name); } }; let dereg = catalog.deregister_schema(name.schema_name(), cascade)?; match (dereg, allow_missing) { (None, true) => self.return_empty_dataframe(), - (None, false) => self.schema_doesnt_exist_err(name), + (None, false) => self.schema_doesnt_exist_err(&name), (Some(_), _) => self.return_empty_dataframe(), } } - fn schema_doesnt_exist_err(&self, schemaref: SchemaReference) -> Result { + fn schema_doesnt_exist_err(&self, schemaref: &SchemaReference) -> Result { exec_err!("Schema '{schemaref}' doesn't exist.") } - async fn set_variable(&self, stmt: SetVariable) -> Result { + async fn set_variable(&self, stmt: SetVariable) -> Result<()> { let SetVariable { variable, value, .. 
} = stmt; @@ -1097,11 +1119,37 @@ impl SessionContext { for udf in udfs_to_update { state.register_udf(udf)?; } + } - drop(state); + Ok(()) + } + + async fn reset_variable(&self, stmt: ResetVariable) -> Result<()> { + let variable = stmt.variable; + if variable.starts_with("datafusion.runtime.") { + return self.reset_runtime_variable(&variable); } - self.return_empty_dataframe() + let mut state = self.state.write(); + state.config_mut().options_mut().reset(&variable)?; + + // Refresh UDFs to ensure configuration-dependent behavior updates + let config_options = state.config().options(); + let udfs_to_update: Vec<_> = state + .scalar_functions() + .values() + .filter_map(|udf| { + udf.inner() + .with_updated_config(config_options) + .map(Arc::new) + }) + .collect(); + + for udf in udfs_to_update { + state.register_udf(udf)?; + } + + Ok(()) } fn set_runtime_variable(&self, variable: &str, value: &str) -> Result<()> { @@ -1124,6 +1172,53 @@ impl SessionContext { let limit = Self::parse_memory_limit(value)?; builder.with_metadata_cache_limit(limit) } + "list_files_cache_limit" => { + let limit = Self::parse_memory_limit(value)?; + builder.with_object_list_cache_limit(limit) + } + "list_files_cache_ttl" => { + let duration = Self::parse_duration(value)?; + builder.with_object_list_cache_ttl(Some(duration)) + } + _ => return plan_err!("Unknown runtime configuration: {variable}"), + // Remember to update `reset_runtime_variable()` when adding new options + }; + + *state = SessionStateBuilder::from(state.clone()) + .with_runtime_env(Arc::new(builder.build()?)) + .build(); + + Ok(()) + } + + fn reset_runtime_variable(&self, variable: &str) -> Result<()> { + let key = variable.strip_prefix("datafusion.runtime.").unwrap(); + + let mut state = self.state.write(); + + let mut builder = RuntimeEnvBuilder::from_runtime_env(state.runtime_env()); + match key { + "memory_limit" => { + builder.memory_pool = None; + } + "max_temp_directory_size" => { + builder = + builder.with_max_temp_directory_size(DEFAULT_MAX_TEMP_DIRECTORY_SIZE); + } + "temp_directory" => { + builder.disk_manager_builder = Some(DiskManagerBuilder::default()); + } + "metadata_cache_limit" => { + builder = builder.with_metadata_cache_limit(DEFAULT_METADATA_CACHE_LIMIT); + } + "list_files_cache_limit" => { + builder = builder + .with_object_list_cache_limit(DEFAULT_LIST_FILES_CACHE_MEMORY_LIMIT); + } + "list_files_cache_ttl" => { + builder = + builder.with_object_list_cache_ttl(DEFAULT_LIST_FILES_CACHE_TTL); + } _ => return plan_err!("Unknown runtime configuration: {variable}"), }; @@ -1164,6 +1259,36 @@ impl SessionContext { } } + fn parse_duration(duration: &str) -> Result { + let mut minutes = None; + let mut seconds = None; + + for duration in duration.split_inclusive(&['m', 's']) { + let (number, unit) = duration.split_at(duration.len() - 1); + let number: u64 = number.parse().map_err(|_| { + plan_datafusion_err!("Failed to parse number from duration '{duration}'") + })?; + + match unit { + "m" if minutes.is_none() && seconds.is_none() => minutes = Some(number), + "s" if seconds.is_none() => seconds = Some(number), + _ => plan_err!( + "Invalid duration, unit must be either 'm' (minutes), or 's' (seconds), and be in the correct order" + )?, + } + } + + let duration = Duration::from_secs( + minutes.unwrap_or_default() * 60 + seconds.unwrap_or_default(), + ); + + if duration.is_zero() { + return plan_err!("Duration must be greater than 0 seconds"); + } + + Ok(duration) + } + async fn create_custom_table( &self, cmd: &CreateExternalTable, 
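The `reset_variable`, `reset_runtime_variable`, and `parse_duration` additions above complement the existing `SET` path for `datafusion.runtime.*` options; a hedged sketch of how this might look from SQL, assuming the planner maps `RESET <option>` to the new `ResetVariable` statement wired up in this hunk. Durations accept minutes and/or seconds in that order, e.g. '90s', '2m', '1m30s':

    use datafusion::error::Result;
    use datafusion::prelude::SessionContext;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();

        // Set a runtime option; the value goes through `parse_duration`.
        ctx.sql("SET datafusion.runtime.list_files_cache_ttl = '1m30s'")
            .await?
            .collect()
            .await?;

        // Restore the built-in default for the same option
        // (assumes `RESET` parses to the new ResetVariable statement).
        ctx.sql("RESET datafusion.runtime.list_files_cache_ttl")
            .await?
            .collect()
            .await?;

        Ok(())
    }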
@@ -1197,13 +1322,12 @@ impl SessionContext { .and_then(|c| c.schema(&resolved.schema)) }; - if let Some(schema) = maybe_schema { - if let Some(table_provider) = schema.table(&table).await? { - if table_provider.table_type() == table_type { - schema.deregister_table(&table)?; - return Ok(true); - } - } + if let Some(schema) = maybe_schema + && let Some(table_provider) = schema.table(&table).await? + && table_provider.table_type() == table_type + { + schema.deregister_table(&table)?; + return Ok(true); } Ok(false) @@ -1219,7 +1343,7 @@ impl SessionContext { _ => { return Err(DataFusionError::Configuration( "Function factory has not been configured".to_string(), - )) + )); } } }; @@ -1269,14 +1393,18 @@ impl SessionContext { exec_datafusion_err!("Prepared statement '{}' does not exist", name) })?; + let state = self.state.read(); + let context = SimplifyContext::new(state.execution_props()); + let simplifier = ExprSimplifier::new(context); + // Only allow literals as parameters for now. let mut params: Vec = parameters .into_iter() - .map(|e| match e { + .map(|e| match simplifier.simplify(e)? { Expr::Literal(scalar, metadata) => { Ok(ScalarAndMetadata::new(scalar, metadata)) } - _ => not_impl_err!("Unsupported parameter type: {}", e), + e => not_impl_err!("Unsupported parameter type: {e}"), }) .collect::>()?; @@ -1359,6 +1487,18 @@ impl SessionContext { self.state.write().register_udwf(Arc::new(f)).ok(); } + #[cfg(feature = "sql")] + /// Registers a [`RelationPlanner`] to customize SQL table-factor planning. + /// + /// Planners are invoked in reverse registration order, allowing newer + /// planners to take precedence over existing ones. + pub fn register_relation_planner( + &self, + planner: Arc, + ) -> Result<()> { + self.state.write().register_relation_planner(planner) + } + /// Deregisters a UDF within this context. pub fn deregister_udf(&self, name: &str) { self.state.write().deregister_udf(name).ok(); @@ -1788,6 +1928,12 @@ impl FunctionRegistry for SessionContext { } } +impl datafusion_execution::TaskContextProvider for SessionContext { + fn task_ctx(&self) -> Arc { + SessionContext::task_ctx(self) + } +} + /// Create a new task context instance from SessionContext impl From<&SessionContext> for TaskContext { fn from(session: &SessionContext) -> Self { @@ -1831,7 +1977,7 @@ pub trait QueryPlanner: Debug { /// because the implementation and requirements vary widely. Please see /// [function_factory example] for a reference implementation. /// -/// [function_factory example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/function_factory.rs +/// [function_factory example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/function_factory.rs /// /// # Examples of syntax that can be supported /// @@ -2531,4 +2677,69 @@ mod tests { } } } + + #[tokio::test] + async fn remove_optimizer_rule() -> Result<()> { + let get_optimizer_rules = |ctx: &SessionContext| { + ctx.state() + .optimizer() + .rules + .iter() + .map(|r| r.name().to_owned()) + .collect::>() + }; + + let ctx = SessionContext::new(); + assert!(get_optimizer_rules(&ctx).contains("simplify_expressions")); + + // default plan + let plan = ctx + .sql("select 1 + 1") + .await? + .into_optimized_plan()? 
+ .to_string(); + assert_snapshot!(plan, @r" + Projection: Int64(2) AS Int64(1) + Int64(1) + EmptyRelation: rows=1 + "); + + assert!(ctx.remove_optimizer_rule("simplify_expressions")); + assert!(!get_optimizer_rules(&ctx).contains("simplify_expressions")); + + // plan without the simplify_expressions rule + let plan = ctx + .sql("select 1 + 1") + .await? + .into_optimized_plan()? + .to_string(); + assert_snapshot!(plan, @r" + Projection: Int64(1) + Int64(1) + EmptyRelation: rows=1 + "); + + // attempting to remove a non-existing rule returns false + assert!(!ctx.remove_optimizer_rule("simplify_expressions")); + + Ok(()) + } + + #[test] + fn test_parse_duration() { + // Valid durations + for (duration, want) in [ + ("1s", Duration::from_secs(1)), + ("1m", Duration::from_secs(60)), + ("1m0s", Duration::from_secs(60)), + ("1m1s", Duration::from_secs(61)), + ] { + let have = SessionContext::parse_duration(duration).unwrap(); + assert_eq!(want, have); + } + + // Invalid durations + for duration in ["0s", "0m", "1s0m", "1s1m"] { + let have = SessionContext::parse_duration(duration); + assert!(have.is_err()); + } + } } diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs index 731f7e59ecfaf..823dc946ea732 100644 --- a/datafusion/core/src/execution/context/parquet.rs +++ b/datafusion/core/src/execution/context/parquet.rs @@ -113,7 +113,7 @@ mod tests { }; use datafusion_execution::config::SessionConfig; - use tempfile::{tempdir, TempDir}; + use tempfile::{TempDir, tempdir}; #[tokio::test] async fn read_with_glob_path() -> Result<()> { @@ -355,7 +355,9 @@ mod tests { let expected_path = binding[0].as_str(); assert_eq!( read_df.unwrap_err().strip_backtrace(), - format!("Execution error: File path '{expected_path}' does not match the expected extension '.parquet'") + format!( + "Execution error: File path '{expected_path}' does not match the expected extension '.parquet'" + ) ); // Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension. 
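Several hunks in this patch move the parquet tests from `ParquetSource::default()` plus a separate schema argument to `ParquetSource::new(schema)`, with `FileScanConfigBuilder::new` dropping its schema parameter and `with_projection_indices` now returning a `Result`. A minimal sketch of the post-patch construction pattern; the exact import paths and the single-file helper are assumptions for illustration, not taken from the patch:

    use std::sync::Arc;

    use arrow::datatypes::SchemaRef;
    use datafusion_common::Result;
    use datafusion_datasource::PartitionedFile;
    use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
    use datafusion_datasource::source::DataSourceExec;
    use datafusion_datasource_parquet::source::ParquetSource;
    use datafusion_execution::object_store::ObjectStoreUrl;

    /// Build a parquet `DataSourceExec` that scans a single file.
    fn parquet_exec_for_file(
        table_schema: SchemaRef,
        file: PartitionedFile,
    ) -> Result<Arc<DataSourceExec>> {
        // The source now owns the table schema ...
        let source = Arc::new(ParquetSource::new(table_schema));
        // ... so the builder takes only the object store URL and the source;
        // projection indices are validated, hence the `?`.
        let config = FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source)
            .with_file(file)
            .with_projection_indices(Some(vec![0]))?
            .build();
        Ok(DataSourceExec::from_data_source(config))
    }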
diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index c15b7eae08432..6a9ebcdf51250 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -27,14 +27,14 @@ use crate::catalog::{CatalogProviderList, SchemaProvider, TableProviderFactory}; use crate::datasource::file_format::FileFormatFactory; #[cfg(feature = "sql")] use crate::datasource::provider_as_source; -use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner}; use crate::execution::SessionStateDefaults; +use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner}; use crate::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; use arrow_schema::{DataType, FieldRef}; +use datafusion_catalog::MemoryCatalogProviderList; use datafusion_catalog::information_schema::{ - InformationSchemaProvider, INFORMATION_SCHEMA, + INFORMATION_SCHEMA, InformationSchemaProvider, }; -use datafusion_catalog::MemoryCatalogProviderList; use datafusion_catalog::{TableFunction, TableFunctionImpl}; use datafusion_common::alias::AliasGenerator; #[cfg(feature = "sql")] @@ -43,21 +43,21 @@ use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions}; use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; use datafusion_common::tree_node::TreeNode; use datafusion_common::{ - config_err, exec_err, plan_datafusion_err, DFSchema, DataFusionError, - ResolvedTableReference, TableReference, + DFSchema, DataFusionError, ResolvedTableReference, TableReference, config_err, + exec_err, plan_datafusion_err, }; +use datafusion_execution::TaskContext; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_execution::TaskContext; +#[cfg(feature = "sql")] +use datafusion_expr::TableSource; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::FunctionRewrite; use datafusion_expr::planner::ExprPlanner; #[cfg(feature = "sql")] -use datafusion_expr::planner::TypePlanner; +use datafusion_expr::planner::{RelationPlanner, TypePlanner}; use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; use datafusion_expr::simplify::SimplifyInfo; -#[cfg(feature = "sql")] -use datafusion_expr::TableSource; use datafusion_expr::{ AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, ScalarUDF, WindowUDF, }; @@ -67,8 +67,8 @@ use datafusion_optimizer::{ }; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_optimizer::optimizer::PhysicalOptimizer; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::optimizer::PhysicalOptimizer; use datafusion_physical_plan::ExecutionPlan; use datafusion_session::Session; #[cfg(feature = "sql")] @@ -139,6 +139,8 @@ pub struct SessionState { analyzer: Analyzer, /// Provides support for customizing the SQL planner, e.g. to add support for custom operators like `->>` or `?` expr_planners: Vec>, + #[cfg(feature = "sql")] + relation_planners: Vec>, /// Provides support for customizing the SQL type planning #[cfg(feature = "sql")] type_planner: Option>, @@ -185,6 +187,7 @@ pub struct SessionState { /// It will be invoked on `CREATE FUNCTION` statements. 
/// thus, changing dialect o PostgreSql is required function_factory: Option>, + cache_factory: Option>, /// Cache logical plans of prepared statements for later execution. /// Key is the prepared statement name. prepared_plans: HashMap>, @@ -206,8 +209,12 @@ impl Debug for SessionState { .field("table_options", &self.table_options) .field("table_factories", &self.table_factories) .field("function_factory", &self.function_factory) + .field("cache_factory", &self.cache_factory) .field("expr_planners", &self.expr_planners); + #[cfg(feature = "sql")] + let ret = ret.field("relation_planners", &self.relation_planners); + #[cfg(feature = "sql")] let ret = ret.field("type_planner", &self.type_planner); @@ -345,6 +352,13 @@ impl SessionState { self.optimizer.rules.push(optimizer_rule); } + /// Removes an optimizer rule by name, returning `true` if it existed. + pub(crate) fn remove_optimizer_rule(&mut self, name: &str) -> bool { + let original_len = self.optimizer.rules.len(); + self.optimizer.rules.retain(|r| r.name() != name); + self.optimizer.rules.len() < original_len + } + /// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements pub fn set_function_factory(&mut self, function_factory: Arc) { self.function_factory = Some(function_factory); @@ -355,6 +369,16 @@ impl SessionState { self.function_factory.as_ref() } + /// Register a [`CacheFactory`] for custom caching strategy + pub fn set_cache_factory(&mut self, cache_factory: Arc) { + self.cache_factory = Some(cache_factory); + } + + /// Get the cache factory + pub fn cache_factory(&self) -> Option<&Arc> { + self.cache_factory.as_ref() + } + /// Get the table factories pub fn table_factories(&self) -> &HashMap> { &self.table_factories @@ -480,10 +504,10 @@ impl SessionState { let resolved = self.resolve_table_ref(reference); if let Entry::Vacant(v) = provider.tables.entry(resolved) { let resolved = v.key(); - if let Ok(schema) = self.schema_for_ref(resolved.clone()) { - if let Some(table) = schema.table(&resolved.table).await? { - v.insert(provider_as_source(table)); - } + if let Ok(schema) = self.schema_for_ref(resolved.clone()) + && let Some(table) = schema.table(&resolved.table).await? + { + v.insert(provider_as_source(table)); } } } @@ -547,6 +571,16 @@ impl SessionState { let sql_expr = self.sql_to_expr_with_alias(sql, &dialect)?; + self.create_logical_expr_from_sql_expr(sql_expr, df_schema) + } + + /// Creates a datafusion style AST [`Expr`] from a SQL expression. + #[cfg(feature = "sql")] + pub fn create_logical_expr_from_sql_expr( + &self, + sql_expr: SQLExprWithAlias, + df_schema: &DFSchema, + ) -> datafusion_common::Result { let provider = SessionContextProvider { state: self, tables: HashMap::new(), @@ -571,6 +605,24 @@ impl SessionState { &self.expr_planners } + #[cfg(feature = "sql")] + /// Returns the registered relation planners in priority order. + pub fn relation_planners(&self) -> &[Arc] { + &self.relation_planners + } + + #[cfg(feature = "sql")] + /// Registers a [`RelationPlanner`] to customize SQL relation planning. + /// + /// Newly registered planners are given higher priority than existing ones. 
+ pub fn register_relation_planner( + &mut self, + planner: Arc, + ) -> datafusion_common::Result<()> { + self.relation_planners.insert(0, planner); + Ok(()) + } + /// Returns the [`QueryPlanner`] for this session pub fn query_planner(&self) -> &Arc { &self.query_planner @@ -685,7 +737,7 @@ impl SessionState { /// * [`create_physical_expr`] for a lower-level API /// /// [simplified]: datafusion_optimizer::simplify_expressions - /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs + /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/query_planning/expr_api.rs /// [`SessionContext::create_physical_expr`]: crate::execution::context::SessionContext::create_physical_expr pub fn create_physical_expr( &self, @@ -788,10 +840,18 @@ impl SessionState { overwrite: bool, ) -> Result<(), DataFusionError> { let ext = file_format.get_ext().to_lowercase(); - match (self.file_formats.entry(ext.clone()), overwrite){ - (Entry::Vacant(e), _) => {e.insert(file_format);}, - (Entry::Occupied(mut e), true) => {e.insert(file_format);}, - (Entry::Occupied(_), false) => return config_err!("File type already registered for extension {ext}. Set overwrite to true to replace this extension."), + match (self.file_formats.entry(ext.clone()), overwrite) { + (Entry::Vacant(e), _) => { + e.insert(file_format); + } + (Entry::Occupied(mut e), true) => { + e.insert(file_format); + } + (Entry::Occupied(_), false) => { + return config_err!( + "File type already registered for extension {ext}. Set overwrite to true to replace this extension." + ); + } }; Ok(()) } @@ -914,6 +974,8 @@ pub struct SessionStateBuilder { analyzer: Option, expr_planners: Option>>, #[cfg(feature = "sql")] + relation_planners: Option>>, + #[cfg(feature = "sql")] type_planner: Option>, optimizer: Option, physical_optimizers: Option, @@ -931,6 +993,7 @@ pub struct SessionStateBuilder { table_factories: Option>>, runtime_env: Option>, function_factory: Option>, + cache_factory: Option>, // fields to support convenience functions analyzer_rules: Option>>, optimizer_rules: Option>>, @@ -951,6 +1014,8 @@ impl SessionStateBuilder { analyzer: None, expr_planners: None, #[cfg(feature = "sql")] + relation_planners: None, + #[cfg(feature = "sql")] type_planner: None, optimizer: None, physical_optimizers: None, @@ -968,6 +1033,7 @@ impl SessionStateBuilder { table_factories: None, runtime_env: None, function_factory: None, + cache_factory: None, // fields to support convenience functions analyzer_rules: None, optimizer_rules: None, @@ -1001,6 +1067,8 @@ impl SessionStateBuilder { analyzer: Some(existing.analyzer), expr_planners: Some(existing.expr_planners), #[cfg(feature = "sql")] + relation_planners: Some(existing.relation_planners), + #[cfg(feature = "sql")] type_planner: existing.type_planner, optimizer: Some(existing.optimizer), physical_optimizers: Some(existing.physical_optimizers), @@ -1020,7 +1088,7 @@ impl SessionStateBuilder { table_factories: Some(existing.table_factories), runtime_env: Some(existing.runtime_env), function_factory: existing.function_factory, - + cache_factory: existing.cache_factory, // fields to support convenience functions analyzer_rules: None, optimizer_rules: None, @@ -1141,6 +1209,16 @@ impl SessionStateBuilder { self } + #[cfg(feature = "sql")] + /// Sets the [`RelationPlanner`]s used to customize SQL relation planning. 
+ pub fn with_relation_planners( + mut self, + relation_planners: Vec>, + ) -> Self { + self.relation_planners = Some(relation_planners); + self + } + /// Set the [`TypePlanner`] used to customize the behavior of the SQL planner. #[cfg(feature = "sql")] pub fn with_type_planner(mut self, type_planner: Arc) -> Self { @@ -1309,6 +1387,15 @@ impl SessionStateBuilder { self } + /// Set a [`CacheFactory`] for custom caching strategy + pub fn with_cache_factory( + mut self, + cache_factory: Option>, + ) -> Self { + self.cache_factory = cache_factory; + self + } + /// Register an `ObjectStore` to the [`RuntimeEnv`]. See [`RuntimeEnv::register_object_store`] /// for more details. /// @@ -1355,6 +1442,8 @@ impl SessionStateBuilder { analyzer, expr_planners, #[cfg(feature = "sql")] + relation_planners, + #[cfg(feature = "sql")] type_planner, optimizer, physical_optimizers, @@ -1372,6 +1461,7 @@ impl SessionStateBuilder { table_factories, runtime_env, function_factory, + cache_factory, analyzer_rules, optimizer_rules, physical_optimizer_rules, @@ -1385,6 +1475,8 @@ impl SessionStateBuilder { analyzer: analyzer.unwrap_or_default(), expr_planners: expr_planners.unwrap_or_default(), #[cfg(feature = "sql")] + relation_planners: relation_planners.unwrap_or_default(), + #[cfg(feature = "sql")] type_planner, optimizer: optimizer.unwrap_or_default(), physical_optimizers: physical_optimizers.unwrap_or_default(), @@ -1408,6 +1500,7 @@ impl SessionStateBuilder { table_factories: table_factories.unwrap_or_default(), runtime_env, function_factory, + cache_factory, prepared_plans: HashMap::new(), }; @@ -1521,6 +1614,12 @@ impl SessionStateBuilder { &mut self.expr_planners } + #[cfg(feature = "sql")] + /// Returns a mutable reference to the current [`RelationPlanner`] list. 
+ pub fn relation_planners(&mut self) -> &mut Option>> { + &mut self.relation_planners + } + /// Returns the current type_planner value #[cfg(feature = "sql")] pub fn type_planner(&mut self) -> &mut Option> { @@ -1611,6 +1710,11 @@ impl SessionStateBuilder { &mut self.function_factory } + /// Returns the cache factory + pub fn cache_factory(&mut self) -> &mut Option> { + &mut self.cache_factory + } + /// Returns the current analyzer_rules value pub fn analyzer_rules( &mut self, @@ -1649,6 +1753,7 @@ impl Debug for SessionStateBuilder { .field("table_options", &self.table_options) .field("table_factories", &self.table_factories) .field("function_factory", &self.function_factory) + .field("cache_factory", &self.cache_factory) .field("expr_planners", &self.expr_planners); #[cfg(feature = "sql")] let ret = ret.field("type_planner", &self.type_planner); @@ -1695,6 +1800,10 @@ impl ContextProvider for SessionContextProvider<'_> { self.state.expr_planners() } + fn get_relation_planners(&self) -> &[Arc] { + self.state.relation_planners() + } + fn get_type_planner(&self) -> Option> { if let Some(type_planner) = &self.state.type_planner { Some(Arc::clone(type_planner)) @@ -1764,7 +1873,7 @@ impl ContextProvider for SessionContextProvider<'_> { } fn get_variable_type(&self, variable_names: &[String]) -> Option { - use datafusion_expr::var_provider::{is_system_variables, VarType}; + use datafusion_expr::var_provider::{VarType, is_system_variables}; if variable_names.is_empty() { return None; @@ -1947,6 +2056,12 @@ impl FunctionRegistry for SessionState { } } +impl datafusion_execution::TaskContextProvider for SessionState { + fn task_ctx(&self) -> Arc { + SessionState::task_ctx(self) + } +} + impl OptimizerConfig for SessionState { fn query_execution_start_time(&self) -> DateTime { self.execution_props.query_execution_start_time @@ -2037,14 +2152,27 @@ pub(crate) struct PreparedPlan { pub(crate) plan: Arc, } +/// A [`CacheFactory`] can be registered via [`SessionState`] +/// to create a custom logical plan for [`crate::dataframe::DataFrame::cache`]. +/// Additionally, a custom [`crate::physical_planner::ExtensionPlanner`]/[`QueryPlanner`] +/// may need to be implemented to handle such plans. 
+pub trait CacheFactory: Debug + Send + Sync { + /// Create a logical plan for caching + fn create( + &self, + plan: LogicalPlan, + session_state: &SessionState, + ) -> datafusion_common::Result; +} + #[cfg(test)] mod tests { use super::{SessionContextProvider, SessionStateBuilder}; use crate::common::assert_contains; use crate::config::ConfigOptions; + use crate::datasource::MemTable; use crate::datasource::empty::EmptyTable; use crate::datasource::provider_as_source; - use crate::datasource::MemTable; use crate::execution::context::SessionState; use crate::logical_expr::planner::ExprPlanner; use crate::logical_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; @@ -2054,13 +2182,13 @@ mod tests { use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_catalog::MemoryCatalogProviderList; - use datafusion_common::config::Dialect; use datafusion_common::DFSchema; use datafusion_common::Result; + use datafusion_common::config::Dialect; use datafusion_execution::config::SessionConfig; use datafusion_expr::Expr; - use datafusion_optimizer::optimizer::OptimizerRule; use datafusion_optimizer::Optimizer; + use datafusion_optimizer::optimizer::OptimizerRule; use datafusion_physical_plan::display::DisplayableExecutionPlan; use datafusion_sql::planner::{PlannerContext, SqlToRel}; use std::collections::HashMap; @@ -2097,6 +2225,36 @@ mod tests { assert!(sql_to_expr(&state).is_err()) } + #[test] + #[cfg(feature = "sql")] + fn test_create_logical_expr_from_sql_expr() { + let state = SessionStateBuilder::new().with_default_features().build(); + + let provider = SessionContextProvider { + state: &state, + tables: HashMap::new(), + }; + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let df_schema = DFSchema::try_from(schema).unwrap(); + let dialect = state.config.options().sql_parser.dialect; + let query = SqlToRel::new_with_options(&provider, state.get_parser_options()); + + for sql in ["[1,2,3]", "a > 10", "SUM(a)"] { + let sql_expr = state.sql_to_expr(sql, &dialect).unwrap(); + let from_str = query + .sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new()) + .unwrap(); + + let sql_expr_with_alias = + state.sql_to_expr_with_alias(sql, &dialect).unwrap(); + let from_expr = state + .create_logical_expr_from_sql_expr(sql_expr_with_alias, &df_schema) + .unwrap(); + assert_eq!(from_str, from_expr); + } + } + #[test] fn test_from_existing() -> Result<()> { fn employee_batch() -> RecordBatch { @@ -2137,13 +2295,15 @@ mod tests { .table_exist("employee"); assert!(is_exist); let new_state = SessionStateBuilder::new_from_existing(session_state).build(); - assert!(new_state - .catalog_list() - .catalog(default_catalog.as_str()) - .unwrap() - .schema(default_schema.as_str()) - .unwrap() - .table_exist("employee")); + assert!( + new_state + .catalog_list() + .catalog(default_catalog.as_str()) + .unwrap() + .schema(default_schema.as_str()) + .unwrap() + .table_exist("employee") + ); // if `with_create_default_catalog_and_schema` is disabled, the new one shouldn't create default catalog and schema let disable_create_default = @@ -2151,10 +2311,12 @@ mod tests { let without_default_state = SessionStateBuilder::new() .with_config(disable_create_default) .build(); - assert!(without_default_state - .catalog_list() - .catalog(&default_catalog) - .is_none()); + assert!( + without_default_state + .catalog_list() + .catalog(&default_catalog) + .is_none() + ); let new_state = 
SessionStateBuilder::new_from_existing(without_default_state).build(); assert!(new_state.catalog_list().catalog(&default_catalog).is_none()); diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index 62a575541a5d8..721710d4e057e 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -17,6 +17,7 @@ use crate::catalog::listing_schema::ListingSchemaProvider; use crate::catalog::{CatalogProvider, TableProviderFactory}; +use crate::datasource::file_format::FileFormatFactory; use crate::datasource::file_format::arrow::ArrowFormatFactory; #[cfg(feature = "avro")] use crate::datasource::file_format::avro::AvroFormatFactory; @@ -24,7 +25,6 @@ use crate::datasource::file_format::csv::CsvFormatFactory; use crate::datasource::file_format::json::JsonFormatFactory; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormatFactory; -use crate::datasource::file_format::FileFormatFactory; use crate::datasource::provider::DefaultTableFactory; use crate::execution::context::SessionState; #[cfg(feature = "nested_expressions")] @@ -103,7 +103,7 @@ impl SessionStateDefaults { /// returns the list of default [`ScalarUDF`]s pub fn default_scalar_functions() -> Vec> { - #[cfg_attr(not(feature = "nested_expressions"), allow(unused_mut))] + #[cfg_attr(not(feature = "nested_expressions"), expect(unused_mut))] let mut functions: Vec> = functions::all_default_functions(); #[cfg(feature = "nested_expressions")] @@ -155,7 +155,7 @@ impl SessionStateDefaults { } /// registers all the builtin array functions - #[cfg_attr(not(feature = "nested_expressions"), allow(unused_variables))] + #[cfg_attr(not(feature = "nested_expressions"), expect(unused_variables))] pub fn register_array_functions(state: &mut SessionState) { // register crate of array expressions (if enabled) #[cfg(feature = "nested_expressions")] diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 381dd5e9e8482..e83934a8e281d 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +#![deny(clippy::allow_attributes)] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" @@ -35,6 +36,9 @@ ) )] #![warn(missing_docs, clippy::needless_borrow)] +// Use `allow` instead of `expect` for test configuration to explicitly +// disable the lint for all test code rather than expecting violations +#![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! [DataFusion] is an extensible query engine written in Rust that //! uses [Apache Arrow] as its in-memory format. DataFusion's target users are @@ -358,7 +362,7 @@ //! [`TreeNode`]: datafusion_common::tree_node::TreeNode //! [`tree_node module`]: datafusion_expr::logical_plan::tree_node //! [`ExprSimplifier`]: crate::optimizer::simplify_expressions::ExprSimplifier -//! [`expr_api`.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs +//! [`expr_api`.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/query_planning/expr_api.rs //! //! ### Physical Plans //! 
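The lint changes above (`#![deny(clippy::allow_attributes)]` in lib.rs and `allow(unused_mut)` becoming `expect(unused_mut)` in session_state_defaults.rs) hinge on the difference between the two attributes; a small illustrative sketch, not taken from the DataFusion code itself:

    // `expect` behaves like `allow`, but the compiler emits
    // `unfulfilled_lint_expectations` if the lint never actually fires,
    // so stale suppressions get flagged.
    #[expect(unused_mut)]
    fn with_unneeded_mut() -> i32 {
        let mut x = 1; // `mut` is unnecessary, so the expectation is fulfilled
        x
    }

    // `allow` silences the lint unconditionally, whether or not it would fire.
    #[allow(unused_variables)]
    fn maybe_unused(flag: bool) {
        let helper = 42;
        if flag {
            println!("{helper}");
        }
    }

    fn main() {
        println!("{}", with_unneeded_mut());
        maybe_unused(true);
    }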
@@ -647,7 +651,7 @@ //! //! [Tokio]: https://tokio.rs //! [`Runtime`]: tokio::runtime::Runtime -//! [thread_pools example]: https://github.com/apache/datafusion/tree/main/datafusion-examples/examples/thread_pools.rs +//! [thread_pools example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/query_planning/thread_pools.rs //! [`task`]: tokio::task //! [Using Rustlang’s Async Tokio Runtime for CPU-Bound Tasks]: https://thenewstack.io/using-rustlangs-async-tokio-runtime-for-cpu-bound-tasks/ //! [`RepartitionExec`]: physical_plan::repartition::RepartitionExec diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index c280b50a9f07a..9eaf1403e5757 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use crate::datasource::file_format::file_type_to_format; use crate::datasource::listing::ListingTableUrl; use crate::datasource::physical_plan::FileSinkConfig; -use crate::datasource::{source_as_provider, DefaultTableSource}; +use crate::datasource::{DefaultTableSource, source_as_provider}; use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_expr::utils::generate_sort_key; @@ -52,29 +52,32 @@ use crate::physical_plan::union::UnionExec; use crate::physical_plan::unnest::UnnestExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{ - displayable, windows, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, - Partitioning, PhysicalExpr, WindowExpr, + ExecutionPlan, ExecutionPlanProperties, InputOrderMode, Partitioning, PhysicalExpr, + WindowExpr, displayable, windows, }; use crate::schema_equivalence::schema_satisfied_by; -use arrow::array::{builder::StringBuilder, RecordBatch}; +use arrow::array::{RecordBatch, builder::StringBuilder}; use arrow::compute::SortOptions; use arrow::datatypes::Schema; +use arrow_schema::Field; use datafusion_catalog::ScanArgs; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::format::ExplainAnalyzeLevel; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; -use datafusion_common::TableReference; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, - ScalarValue, + DFSchema, ScalarValue, exec_err, internal_datafusion_err, internal_err, not_impl_err, + plan_err, +}; +use datafusion_common::{ + TableReference, assert_eq_or_internal_err, assert_or_internal_err, }; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::memory::MemorySourceConfig; use datafusion_expr::dml::{CopyTo, InsertOp}; use datafusion_expr::expr::{ - physical_name, AggregateFunction, AggregateFunctionParams, Alias, GroupingSet, - NullTreatment, WindowFunction, WindowFunctionParams, + AggregateFunction, AggregateFunctionParams, Alias, GroupingSet, NullTreatment, + WindowFunction, WindowFunctionParams, physical_name, }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; @@ -87,7 +90,7 @@ use datafusion_expr::{ use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::{ - create_physical_sort_exprs, LexOrdering, PhysicalSortExpr, + LexOrdering, PhysicalSortExpr, create_physical_sort_exprs, }; use 
datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::empty::EmptyExec; @@ -101,7 +104,7 @@ use datafusion_physical_plan::unnest::ListUnnest; use async_trait::async_trait; use datafusion_physical_plan::async_func::{AsyncFuncExec, AsyncMapper}; use futures::{StreamExt, TryStreamExt}; -use itertools::{multiunzip, Itertools}; +use itertools::{Itertools, multiunzip}; use log::debug; use tokio::sync::Mutex; @@ -347,11 +350,11 @@ impl DefaultPhysicalPlanner { .flatten() .collect::>(); // Ideally this never happens if we have a valid LogicalPlan tree - if outputs.len() != 1 { - return internal_err!( - "Failed to convert LogicalPlan to ExecutionPlan: More than one root detected" - ); - } + assert_eq_or_internal_err!( + outputs.len(), + 1, + "Failed to convert LogicalPlan to ExecutionPlan: More than one root detected" + ); let plan = outputs.pop().unwrap(); Ok(plan) } @@ -496,7 +499,7 @@ impl DefaultPhysicalPlanner { output_schema, }) => { let output_schema = Arc::clone(output_schema.inner()); - self.plan_describe(Arc::clone(schema), output_schema)? + self.plan_describe(&Arc::clone(schema), output_schema)? } // 1 Child @@ -525,12 +528,22 @@ impl DefaultPhysicalPlanner { let keep_partition_by_columns = match source_option_tuples .get("execution.keep_partition_by_columns") - .map(|v| v.trim()) { - None => session_state.config().options().execution.keep_partition_by_columns, + .map(|v| v.trim()) + { + None => { + session_state + .config() + .options() + .execution + .keep_partition_by_columns + } Some("true") => true, Some("false") => false, - Some(value) => - return Err(DataFusionError::Configuration(format!("provided value for 'execution.keep_partition_by_columns' was not recognized: \"{value}\""))), + Some(value) => { + return Err(DataFusionError::Configuration(format!( + "provided value for 'execution.keep_partition_by_columns' was not recognized: \"{value}\"" + ))); + } }; let sink_format = file_type_to_format(file_type)? @@ -588,17 +601,18 @@ impl DefaultPhysicalPlanner { } } LogicalPlan::Window(Window { window_expr, .. }) => { - if window_expr.is_empty() { - return internal_err!("Impossibly got empty window expression"); - } + assert_or_internal_err!( + !window_expr.is_empty(), + "Impossibly got empty window expression" + ); let input_exec = children.one()?; let get_sort_keys = |expr: &Expr| match expr { Expr::WindowFunction(window_fun) => { let WindowFunctionParams { - ref partition_by, - ref order_by, + partition_by, + order_by, .. } = &window_fun.as_ref().params; generate_sort_key(partition_by, order_by) @@ -608,8 +622,8 @@ impl DefaultPhysicalPlanner { match &**expr { Expr::WindowFunction(window_fun) => { let WindowFunctionParams { - ref partition_by, - ref order_by, + partition_by, + order_by, .. } = &window_fun.as_ref().params; generate_sort_key(partition_by, order_by) @@ -622,11 +636,11 @@ impl DefaultPhysicalPlanner { let sort_keys = get_sort_keys(&window_expr[0])?; if window_expr.len() > 1 { debug_assert!( - window_expr[1..] - .iter() - .all(|expr| get_sort_keys(expr).unwrap() == sort_keys), - "all window expressions shall have the same sort keys, as guaranteed by logical planning" - ); + window_expr[1..] 
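The planner changes above replace hand-written `if ... { return internal_err!(...) }` guards with `assert_eq_or_internal_err!` / `assert_or_internal_err!`, which this file now imports from `datafusion_common`. A small sketch of the pattern, assuming the macros accept the same argument shapes as the call sites in this patch; the helper function is illustrative only:

    use datafusion_common::{Result, assert_eq_or_internal_err, assert_or_internal_err};

    // Hypothetical helper, not part of the patch: validates planner output.
    fn single_root(outputs: &mut Vec<String>) -> Result<String> {
        // Replaces: if outputs.len() != 1 { return internal_err!("...") }
        assert_eq_or_internal_err!(outputs.len(), 1, "More than one root detected");
        assert_or_internal_err!(!outputs[0].is_empty(), "root plan must not be empty");
        Ok(outputs.pop().unwrap())
    }
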
+ .iter() + .all(|expr| get_sort_keys(expr).unwrap() == sort_keys), + "all window expressions shall have the same sort keys, as guaranteed by logical planning" + ); } let logical_schema = node.schema(); @@ -683,6 +697,17 @@ impl DefaultPhysicalPlanner { ) { let mut differences = Vec::new(); + + if physical_input_schema.metadata() + != physical_input_schema_from_logical.metadata() + { + differences.push(format!( + "schema metadata differs: (physical) {:?} vs (logical) {:?}", + physical_input_schema.metadata(), + physical_input_schema_from_logical.metadata() + )); + } + if physical_input_schema.fields().len() != physical_input_schema_from_logical.fields().len() { @@ -712,11 +737,20 @@ impl DefaultPhysicalPlanner { if physical_field.is_nullable() && !logical_field.is_nullable() { differences.push(format!("field nullability at index {} [{}]: (physical) {} vs (logical) {}", i, physical_field.name(), physical_field.is_nullable(), logical_field.is_nullable())); } + if physical_field.metadata() != logical_field.metadata() { + differences.push(format!( + "field metadata at index {} [{}]: (physical) {:?} vs (logical) {:?}", + i, + physical_field.name(), + physical_field.metadata(), + logical_field.metadata() + )); + } } - return internal_err!("Physical input schema should be the same as the one converted from logical input schema. Differences: {}", differences - .iter() - .map(|s| format!("\n\t- {s}")) - .join("")); + return internal_err!( + "Physical input schema should be the same as the one converted from logical input schema. Differences: {}", + differences.iter().map(|s| format!("\n\t- {s}")).join("") + ); } let groups = self.create_grouping_physical_expr( @@ -776,7 +810,7 @@ impl DefaultPhysicalPlanner { _ => { return internal_err!( "Unexpected result from try_plan_async_exprs" - ) + ); } } } @@ -850,6 +884,7 @@ impl DefaultPhysicalPlanner { )? { PlanAsyncExpr::Sync(PlannedExprResult::Expr(runtime_expr)) => { FilterExec::try_new(Arc::clone(&runtime_expr[0]), physical_input)? + .with_batch_size(session_state.config().batch_size())? } PlanAsyncExpr::Async( async_map, @@ -868,11 +903,12 @@ impl DefaultPhysicalPlanner { .with_projection(Some( (0..input.schema().fields().len()).collect(), ))? + .with_batch_size(session_state.config().batch_size())? } _ => { return internal_err!( "Unexpected result from try_plan_async_exprs" - ) + ); } }; @@ -1207,7 +1243,7 @@ impl DefaultPhysicalPlanner { let filter_df_fields = filter_df_fields .into_iter() .map(|(qualifier, field)| { - (qualifier.cloned(), Arc::new(field.clone())) + (qualifier.cloned(), Arc::clone(field)) }) .collect(); @@ -1463,19 +1499,24 @@ impl DefaultPhysicalPlanner { } let plan = match maybe_plan { - Some(v) => Ok(v), - _ => plan_err!("No installed planner was able to convert the custom node to an execution plan: {:?}", node) - }?; + Some(v) => Ok(v), + _ => plan_err!( + "No installed planner was able to convert the custom node to an execution plan: {:?}", + node + ), + }?; // Ensure the ExecutionPlan's schema matches the // declared logical schema to catch and warn about // logic errors when creating user defined plans. if !node.schema().matches_arrow_schema(&plan.schema()) { return plan_err!( - "Extension planner for {:?} created an ExecutionPlan with mismatched schema. \ + "Extension planner for {:?} created an ExecutionPlan with mismatched schema. 
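The new metadata comparisons above extend the planner's `differences` report so that schema-level and field-level metadata mismatches are named explicitly. A self-contained sketch of the same field-by-field diffing idea over plain Arrow schemas, assuming only the public `arrow-schema` API:

    use std::collections::HashMap;
    use arrow_schema::{DataType, Field, Schema};

    // Collect human-readable differences, mirroring the checks above.
    fn schema_differences(physical: &Schema, logical: &Schema) -> Vec<String> {
        let mut differences = Vec::new();
        if physical.metadata() != logical.metadata() {
            differences.push("schema metadata differs".to_string());
        }
        for (i, (p, l)) in physical
            .fields()
            .iter()
            .zip(logical.fields().iter())
            .enumerate()
        {
            if p.name() != l.name() {
                differences.push(format!("field name at index {i}"));
            }
            if p.data_type() != l.data_type() {
                differences.push(format!("field data type at index {i}"));
            }
            if p.metadata() != l.metadata() {
                differences.push(format!("field metadata at index {i}"));
            }
        }
        differences
    }

    fn main() {
        let logical = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
        let physical = Schema::new(vec![Field::new("c1", DataType::Int64, false)])
            .with_metadata(HashMap::from([("key".to_string(), "value".to_string())]));
        // One schema-level and one field-level difference.
        assert_eq!(schema_differences(&physical, &logical).len(), 2);
    }
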
\ LogicalPlan schema: {:?}, ExecutionPlan schema: {:?}", - node, node.schema(), plan.schema() - ); + node, + node.schema(), + plan.schema() + ); } else { plan } @@ -1502,17 +1543,17 @@ impl DefaultPhysicalPlanner { LogicalPlan::Explain(_) => { return internal_err!( "Unsupported logical plan: Explain must be root of the plan" - ) + ); } LogicalPlan::Distinct(_) => { return internal_err!( "Unsupported logical plan: Distinct should be replaced to Aggregate" - ) + ); } LogicalPlan::Analyze(_) => { return internal_err!( "Unsupported logical plan: Analyze must be root of the plan" - ) + ); } }; Ok(exec_node) @@ -1556,7 +1597,8 @@ impl DefaultPhysicalPlanner { } } else if group_expr.is_empty() { // No GROUP BY clause - create empty PhysicalGroupBy - Ok(PhysicalGroupBy::new(vec![], vec![], vec![])) + // no expressions, no null expressions and no grouping expressions + Ok(PhysicalGroupBy::new(vec![], vec![], vec![], false)) } else { Ok(PhysicalGroupBy::new_single( group_expr @@ -1628,6 +1670,7 @@ fn merge_grouping_set_physical_expr( grouping_set_expr, null_exprs, merged_sets, + true, )) } @@ -1670,7 +1713,7 @@ fn create_cube_physical_expr( } } - Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups)) + Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups, true)) } /// Expand and align a ROLLUP expression. This is a special case of GROUPING SETS @@ -1715,7 +1758,7 @@ fn create_rollup_physical_expr( groups.push(group) } - Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups)) + Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups, true)) } /// For a given logical expr, get a properly typed NULL ScalarValue physical expression @@ -1752,11 +1795,11 @@ fn qualify_join_schema_sides( let join_fields = join_schema.fields(); // Validate lengths - if join_fields.len() != left_fields.len() + right_fields.len() { - return internal_err!( - "Join schema field count must match left and right field count." - ); - } + assert_eq_or_internal_err!( + join_fields.len(), + left_fields.len() + right_fields.len(), + "Join schema field count must match left and right field count." + ); // Validate field names match for (i, (field, expected)) in join_fields @@ -1764,14 +1807,12 @@ fn qualify_join_schema_sides( .zip(left_fields.iter().chain(right_fields.iter())) .enumerate() { - if field.name() != expected.name() { - return internal_err!( - "Field name mismatch at index {}: expected '{}', found '{}'", - i, - expected.name(), - field.name() - ); - } + assert_eq_or_internal_err!( + field.name(), + expected.name(), + "Field name mismatch at index {}", + i + ); } // qualify sides @@ -1858,9 +1899,10 @@ pub fn create_window_expr_with_name( if !is_window_frame_bound_valid(window_frame) { return plan_err!( - "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", - window_frame.start_bound, window_frame.end_bound - ); + "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", + window_frame.start_bound, + window_frame.end_bound + ); } let window_frame = Arc::new(window_frame.clone()); @@ -2243,6 +2285,7 @@ impl DefaultPhysicalPlanner { /// Optimize a physical plan by applying each physical optimizer, /// calling observer(plan, optimizer after each one) + #[expect(clippy::needless_pass_by_value)] pub fn optimize_physical_plan( &self, plan: Arc, @@ -2277,7 +2320,7 @@ impl DefaultPhysicalPlanner { // This only checks the schema in release build, and performs additional checks in debug mode. 
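The switch from `Arc::new(field.clone())` to `Arc::clone(field)` above avoids deep-copying a field that is already reference-counted. A tiny standard-library-only sketch of the difference:

    use std::sync::Arc;

    fn main() {
        let field = Arc::new(String::from("c1"));
        let shared = Arc::clone(&field);         // bumps the refcount, no copy
        let copied = Arc::new((*field).clone()); // allocates a second String
        assert!(Arc::ptr_eq(&field, &shared));
        assert!(!Arc::ptr_eq(&field, &copied));
    }
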
OptimizationInvariantChecker::new(optimizer) - .check(&new_plan, before_schema)?; + .check(&new_plan, &before_schema)?; debug!( "Optimized physical plan by {}:\n{}\n", @@ -2310,7 +2353,7 @@ impl DefaultPhysicalPlanner { // return an record_batch which describes a table's schema. fn plan_describe( &self, - table_schema: Arc, + table_schema: &Arc, output_schema: Arc, ) -> Result> { let mut column_names = StringBuilder::new(); @@ -2513,11 +2556,14 @@ impl<'a> OptimizationInvariantChecker<'a> { pub fn check( &mut self, plan: &Arc, - previous_schema: Arc, + previous_schema: &Arc, ) -> Result<()> { // if the rule is not permitted to change the schema, confirm that it did not change. - if self.rule.schema_check() && plan.schema() != previous_schema { - internal_err!("PhysicalOptimizer rule '{}' failed. Schema mismatch. Expected original schema: {:?}, got new schema: {:?}", + if self.rule.schema_check() + && !is_allowed_schema_change(previous_schema.as_ref(), plan.schema().as_ref()) + { + internal_err!( + "PhysicalOptimizer rule '{}' failed. Schema mismatch. Expected original schema: {:?}, got new schema: {:?}", self.rule.name(), previous_schema, plan.schema() @@ -2532,6 +2578,38 @@ impl<'a> OptimizationInvariantChecker<'a> { } } +/// Checks if the change from `old` schema to `new` is allowed or not. +/// +/// The current implementation only allows nullability of individual fields to change +/// from 'nullable' to 'not nullable'. This can happen due to physical expressions knowing +/// more about their null-ness than their logical counterparts. +/// This change is allowed because for any field the non-nullable domain `F` is a strict subset +/// of the nullable domain `F ∪ { NULL }`. A physical schema that guarantees a stricter subset +/// of values will not violate any assumptions made based on the less strict schema. 
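Several signatures above (for example `plan_describe` and `OptimizationInvariantChecker::check`) now borrow the schema `Arc` instead of taking it by value; the generic parameters are elided in this rendering of the patch, so the sketch below uses a placeholder payload type and shows only the calling convention:

    use std::sync::Arc;

    // Hypothetical stand-in for the schema parameter; not code from the patch.
    fn describe(schema: &Arc<Vec<String>>) -> usize {
        schema.len() // a borrow is enough; no refcount bump or clone at the call site
    }

    fn main() {
        let schema = Arc::new(vec!["c1".to_string(), "c2".to_string()]);
        assert_eq!(describe(&schema), 2);
        assert_eq!(Arc::strong_count(&schema), 1); // caller still owns the only Arc
    }
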
+fn is_allowed_schema_change(old: &Schema, new: &Schema) -> bool { + if new.metadata != old.metadata { + return false; + } + + if new.fields.len() != old.fields.len() { + return false; + } + + let new_fields = new.fields.iter().map(|f| f.as_ref()); + let old_fields = old.fields.iter().map(|f| f.as_ref()); + old_fields + .zip(new_fields) + .all(|(old, new)| is_allowed_field_change(old, new)) +} + +fn is_allowed_field_change(old_field: &Field, new_field: &Field) -> bool { + new_field.name() == old_field.name() + && new_field.data_type() == old_field.data_type() + && new_field.metadata() == old_field.metadata() + && (new_field.is_nullable() == old_field.is_nullable() + || !new_field.is_nullable()) +} + impl<'n> TreeNodeVisitor<'n> for OptimizationInvariantChecker<'_> { type Node = Arc; @@ -2580,11 +2658,11 @@ mod tests { use std::ops::{BitAnd, Not}; use super::*; - use crate::datasource::file_format::options::CsvReadOptions; use crate::datasource::MemTable; + use crate::datasource::file_format::options::CsvReadOptions; use crate::physical_plan::{ - expressions, DisplayAs, DisplayFormatType, PlanProperties, - SendableRecordBatchStream, + DisplayAs, DisplayFormatType, PlanProperties, SendableRecordBatchStream, + expressions, }; use crate::prelude::{SessionConfig, SessionContext}; use crate::test_util::{scan_empty, scan_empty_with_partitions}; @@ -2595,12 +2673,12 @@ mod tests { use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; use datafusion_common::{ - assert_contains, DFSchemaRef, TableReference, ToDFSchema as _, + DFSchemaRef, TableReference, ToDFSchema as _, assert_contains, }; - use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; + use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::builder::subquery_alias; - use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore}; + use datafusion_expr::{LogicalPlanBuilder, UserDefinedLogicalNodeCore, col, lit}; use datafusion_functions_aggregate::count::count_all; use datafusion_functions_aggregate::expr_fn::sum; use datafusion_physical_expr::EquivalenceProperties; @@ -2773,6 +2851,7 @@ mod tests { true, ], ], + has_grouping_set: true, }, ) "#); @@ -2883,6 +2962,7 @@ mod tests { false, ], ], + has_grouping_set: true, }, ) "#); @@ -3000,8 +3080,7 @@ mod tests { .create_physical_plan(&logical_plan, &session_state) .await; - let expected_error = - "No installed planner was able to convert the custom node to an execution plan: NoOp"; + let expected_error = "No installed planner was able to convert the custom node to an execution plan: NoOp"; match plan { Ok(_) => panic!("Expected planning failure"), Err(e) => assert!( @@ -3067,7 +3146,7 @@ mod tests { assert_contains!( &e, - r#"Error during planning: Can not find compatible types to compare Boolean with [Struct("foo": Boolean), Utf8]"# + r#"Error during planning: Can not find compatible types to compare Boolean with [Struct("foo": non-null Boolean), Utf8]"# ); Ok(()) @@ -3258,18 +3337,27 @@ mod tests { if let Some(plan) = plan.as_any().downcast_ref::() { let stringified_plans = plan.stringified_plans(); assert!(stringified_plans.len() >= 4); - assert!(stringified_plans - .iter() - .any(|p| matches!(p.plan_type, PlanType::FinalLogicalPlan))); - assert!(stringified_plans - .iter() - .any(|p| matches!(p.plan_type, PlanType::InitialPhysicalPlan))); - assert!(stringified_plans - .iter() - .any(|p| matches!(p.plan_type, PlanType::OptimizedPhysicalPlan { .. 
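A short, self-contained exercise of the rule `is_allowed_field_change` encodes above: nullability may only be tightened (nullable to non-nullable), never loosened. The helper below reduces the check to name, type, and nullability for illustration:

    use arrow_schema::{DataType, Field};

    // Same nullability clause as `is_allowed_field_change` above.
    fn narrowing_ok(old: &Field, new: &Field) -> bool {
        new.name() == old.name()
            && new.data_type() == old.data_type()
            && (new.is_nullable() == old.is_nullable() || !new.is_nullable())
    }

    fn main() {
        let nullable = Field::new("a", DataType::Int32, true);
        let non_nullable = Field::new("a", DataType::Int32, false);
        assert!(narrowing_ok(&nullable, &non_nullable));  // physical plan proved non-null: allowed
        assert!(!narrowing_ok(&non_nullable, &nullable)); // weakening the guarantee: rejected
    }
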
}))); - assert!(stringified_plans - .iter() - .any(|p| matches!(p.plan_type, PlanType::FinalPhysicalPlan))); + assert!( + stringified_plans + .iter() + .any(|p| matches!(p.plan_type, PlanType::FinalLogicalPlan)) + ); + assert!( + stringified_plans + .iter() + .any(|p| matches!(p.plan_type, PlanType::InitialPhysicalPlan)) + ); + assert!( + stringified_plans.iter().any(|p| matches!( + p.plan_type, + PlanType::OptimizedPhysicalPlan { .. } + )) + ); + assert!( + stringified_plans + .iter() + .any(|p| matches!(p.plan_type, PlanType::FinalPhysicalPlan)) + ); } else { panic!( "Plan was not an explain plan: {}", @@ -3636,8 +3724,12 @@ digraph { } fn check_invariants(&self, check: InvariantLevel) -> Result<()> { match check { - InvariantLevel::Always => plan_err!("extension node failed it's user-defined always-invariant check"), - InvariantLevel::Executable => panic!("the OptimizationInvariantChecker should not be checking for executableness"), + InvariantLevel::Always => plan_err!( + "extension node failed it's user-defined always-invariant check" + ), + InvariantLevel::Executable => panic!( + "the OptimizationInvariantChecker should not be checking for executableness" + ), } } fn schema(&self) -> SchemaRef { @@ -3706,24 +3798,26 @@ digraph { // Test: check should pass with same schema let equal_schema = ok_plan.schema(); - OptimizationInvariantChecker::new(&rule).check(&ok_plan, equal_schema)?; + OptimizationInvariantChecker::new(&rule).check(&ok_plan, &equal_schema)?; // Test: should fail with schema changed let different_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, false)])); let expected_err = OptimizationInvariantChecker::new(&rule) - .check(&ok_plan, different_schema) + .check(&ok_plan, &different_schema) .unwrap_err(); assert!(expected_err.to_string().contains("PhysicalOptimizer rule 'OptimizerRuleWithSchemaCheck' failed. Schema mismatch. Expected original schema")); // Test: should fail when extension node fails it's own invariant check let failing_node: Arc = Arc::new(InvariantFailsExtensionNode); let expected_err = OptimizationInvariantChecker::new(&rule) - .check(&failing_node, ok_plan.schema()) + .check(&failing_node, &ok_plan.schema()) .unwrap_err(); - assert!(expected_err - .to_string() - .contains("extension node failed it's user-defined always-invariant check")); + assert!( + expected_err.to_string().contains( + "extension node failed it's user-defined always-invariant check" + ) + ); // Test: should fail when descendent extension node fails let failing_node: Arc = Arc::new(InvariantFailsExtensionNode); @@ -3732,11 +3826,13 @@ digraph { Arc::clone(&child), ])?; let expected_err = OptimizationInvariantChecker::new(&rule) - .check(&invalid_plan, ok_plan.schema()) + .check(&invalid_plan, &ok_plan.schema()) .unwrap_err(); - assert!(expected_err - .to_string() - .contains("extension node failed it's user-defined always-invariant check")); + assert!( + expected_err.to_string().contains( + "extension node failed it's user-defined always-invariant check" + ) + ); Ok(()) } @@ -3879,4 +3975,229 @@ digraph { Ok(()) } + + // --- Tests for aggregate schema mismatch error messages --- + + use crate::catalog::TableProvider; + use datafusion_catalog::Session; + use datafusion_expr::TableType; + + /// A TableProvider that returns schemas for logical planning vs physical planning. + /// Used to test schema mismatch error messages. 
+ #[derive(Debug)] + struct MockSchemaTableProvider { + logical_schema: SchemaRef, + physical_schema: SchemaRef, + } + + #[async_trait] + impl TableProvider for MockSchemaTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.logical_schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(Arc::new(NoOpExecutionPlan::new(Arc::clone( + &self.physical_schema, + )))) + } + } + + /// Attempts to plan a query with potentially mismatched schemas. + async fn plan_with_schemas( + logical_schema: SchemaRef, + physical_schema: SchemaRef, + query: &str, + ) -> Result> { + let provider = MockSchemaTableProvider { + logical_schema, + physical_schema, + }; + let ctx = SessionContext::new(); + ctx.register_table("test", Arc::new(provider)).unwrap(); + + ctx.sql(query).await.unwrap().create_physical_plan().await + } + + #[tokio::test] + // When schemas match, planning proceeds past the schema_satisfied_by check. + // It then panics on unimplemented error in NoOpExecutionPlan. + #[should_panic(expected = "NoOpExecutionPlan")] + async fn test_aggregate_schema_check_passes() { + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + + plan_with_schemas( + Arc::clone(&schema), + schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_metadata() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = Arc::new( + Schema::new(vec![Field::new("c1", DataType::Int32, false)]) + .with_metadata(HashMap::from([("key".into(), "value".into())])), + ); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "schema metadata differs"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_field_count() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "Different number of fields"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_field_name() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = Arc::new(Schema::new(vec![Field::new( + "different_name", + DataType::Int32, + false, + )])); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "field name at index"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_field_type() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int64, false)])); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "field 
data type at index"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_field_nullability() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "field nullability at index"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_field_metadata() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false) + .with_metadata(HashMap::from([("key".into(), "value".into())])), + ])); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "field metadata at index"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_multiple() { + let logical_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Utf8, false), + ])); + let physical_schema = Arc::new( + Schema::new(vec![ + Field::new("c1", DataType::Int64, true) + .with_metadata(HashMap::from([("key".into(), "value".into())])), + Field::new("c2", DataType::Utf8, false), + ]) + .with_metadata(HashMap::from([( + "schema_key".into(), + "schema_value".into(), + )])), + ); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + // Verify all applicable error fragments are present + let err_str = err.to_string(); + assert_contains!(&err_str, "schema metadata differs"); + assert_contains!(&err_str, "field data type at index"); + assert_contains!(&err_str, "field nullability at index"); + assert_contains!(&err_str, "field metadata at index"); + } } diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index d723620d32323..50e4a2649c923 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -34,10 +34,10 @@ pub use crate::execution::options::{ pub use datafusion_common::Column; pub use datafusion_expr::{ + Expr, expr_fn::*, lit, lit_timestamp_nano, logical_plan::{JoinType, Partitioning}, - Expr, }; pub use datafusion_functions::expr_fn::*; #[cfg(feature = "nested_expressions")] diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 68f83e7f1f115..717182f1d3d5b 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -25,9 +25,9 @@ use std::io::{BufReader, BufWriter}; use std::path::Path; use std::sync::Arc; +use crate::datasource::file_format::FileFormat; use crate::datasource::file_format::csv::CsvFormat; use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::file_format::FileFormat; use crate::datasource::physical_plan::CsvSource; use crate::datasource::{MemTable, TableProvider}; @@ -35,28 +35,31 @@ use crate::error::Result; use crate::logical_expr::LogicalPlan; use crate::test_util::{aggr_test_schema, arrow_test_data}; +use datafusion_common::config::CsvOptions; + use arrow::array::{self, Array, ArrayRef, Decimal128Builder, Int32Array}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; #[cfg(feature = 
"compression")] use datafusion_common::DataFusionError; +use datafusion_datasource::TableSchema; use datafusion_datasource::source::DataSourceExec; -#[cfg(feature = "compression")] -use bzip2::write::BzEncoder; #[cfg(feature = "compression")] use bzip2::Compression as BzCompression; +#[cfg(feature = "compression")] +use bzip2::write::BzEncoder; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource_csv::partitioned_csv_config; #[cfg(feature = "compression")] +use flate2::Compression as GzCompression; +#[cfg(feature = "compression")] use flate2::write::GzEncoder; #[cfg(feature = "compression")] -use flate2::Compression as GzCompression; +use liblzma::write::XzEncoder; use object_store::local_unpartitioned_file; #[cfg(feature = "compression")] -use xz2::write::XzEncoder; -#[cfg(feature = "compression")] use zstd::Encoder as ZstdEncoder; pub fn create_table_dual() -> Arc { @@ -84,17 +87,26 @@ pub fn scan_partitioned_csv( let schema = aggr_test_schema(); let filename = "aggregate_test_100.csv"; let path = format!("{}/csv", arrow_test_data()); + let csv_format: Arc = Arc::new(CsvFormat::default()); + let file_groups = partitioned_file_groups( path.as_str(), filename, partitions, - Arc::new(CsvFormat::default()), + &csv_format, FileCompressionType::UNCOMPRESSED, work_dir, )?; - let source = Arc::new(CsvSource::new(true, b'"', b'"')); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(schema); + let source = Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); let config = - FileScanConfigBuilder::from(partitioned_csv_config(schema, file_groups, source)) + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) 
.with_file_compression_type(FileCompressionType::UNCOMPRESSED) .build(); Ok(DataSourceExec::from_data_source(config)) @@ -105,7 +117,7 @@ pub fn partitioned_file_groups( path: &str, filename: &str, partitions: usize, - file_format: Arc, + file_format: &Arc, file_compression_type: FileCompressionType, work_dir: &Path, ) -> Result> { @@ -189,7 +201,7 @@ pub fn partitioned_file_groups( .collect::>()) } -pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { +pub fn assert_fields_eq(plan: &LogicalPlan, expected: &[&str]) { let actual: Vec = plan .schema() .fields() diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index d31c2719973ec..a0438e3d74ab2 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -20,20 +20,20 @@ use crate::{ execution::{context::SessionState, session_state::SessionStateBuilder}, object_store::{ - memory::InMemory, path::Path, Error, GetOptions, GetResult, ListResult, - MultipartUpload, ObjectMeta, ObjectStore, PutMultipartOptions, PutOptions, - PutPayload, PutResult, + Error, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, + ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, + memory::InMemory, path::Path, }, prelude::SessionContext, }; -use futures::{stream::BoxStream, FutureExt}; +use futures::{FutureExt, stream::BoxStream}; use std::{ fmt::{Debug, Display, Formatter}, sync::Arc, }; use tokio::{ sync::Barrier, - time::{timeout, Duration}, + time::{Duration, timeout}, }; use url::Url; diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index 7149c5b0bd8ca..466ee38a426fd 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -25,6 +25,7 @@ pub mod csv; use futures::Stream; use std::any::Any; use std::collections::HashMap; +use std::fmt::Formatter; use std::fs::File; use std::io::Write; use std::path::Path; @@ -36,16 +37,20 @@ use crate::dataframe::DataFrame; use crate::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use crate::datasource::{empty::EmptyTable, provider_as_source}; use crate::error::Result; +use crate::execution::session_state::CacheFactory; use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::ExecutionPlan; use crate::prelude::{CsvReadOptions, SessionContext}; -use crate::execution::SendableRecordBatchStream; +use crate::execution::{SendableRecordBatchStream, SessionState, SessionStateBuilder}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_catalog::Session; -use datafusion_common::TableReference; -use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; +use datafusion_common::{DFSchemaRef, TableReference}; +use datafusion_expr::{ + CreateExternalTable, Expr, LogicalPlan, SortExpr, TableType, + UserDefinedLogicalNodeCore, +}; use std::pin::Pin; use async_trait::async_trait; @@ -282,3 +287,67 @@ impl RecordBatchStream for BoundedStream { self.record_batch.schema() } } + +#[derive(Hash, Eq, PartialEq, PartialOrd, Debug)] +struct CacheNode { + input: LogicalPlan, +} + +impl UserDefinedLogicalNodeCore for CacheNode { + fn name(&self) -> &str { + "CacheNode" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { 
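`assert_fields_eq` above now takes `&[&str]` rather than an owned `Vec<&str>`, so call sites can pass a borrowed literal slice. A minimal sketch of the same signature change with a hypothetical helper:

    // Hypothetical helper, not the one in the patch.
    fn assert_names_eq(actual: &[String], expected: &[&str]) {
        let actual: Vec<&str> = actual.iter().map(String::as_str).collect();
        assert_eq!(actual, expected);
    }

    fn main() {
        let actual = vec!["c1".to_string(), "c2".to_string()];
        assert_names_eq(&actual, &["c1", "c2"]); // no Vec allocation at the call site
    }
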
+ write!(f, "CacheNode") + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> Result { + assert_eq!(inputs.len(), 1, "input size inconsistent"); + Ok(Self { + input: inputs[0].clone(), + }) + } +} + +#[derive(Debug)] +struct TestCacheFactory {} + +impl CacheFactory for TestCacheFactory { + fn create( + &self, + plan: LogicalPlan, + _session_state: &SessionState, + ) -> Result { + Ok(LogicalPlan::Extension(datafusion_expr::Extension { + node: Arc::new(CacheNode { input: plan }), + })) + } +} + +/// Create a test table registered to a session context with an associated cache factory +pub async fn test_table_with_cache_factory() -> Result { + let session_state = SessionStateBuilder::new() + .with_cache_factory(Some(Arc::new(TestCacheFactory {}))) + .build(); + let ctx = SessionContext::new_with_state(session_state); + let name = "aggregate_test_100"; + register_aggregate_csv(&ctx, name).await?; + ctx.table(name).await +} diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 203d9e97d2a8c..44e884c23a681 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -32,17 +32,15 @@ use crate::logical_expr::execution_props::ExecutionProps; use crate::logical_expr::simplify::SimplifyContext; use crate::optimizer::simplify_expressions::ExprSimplifier; use crate::physical_expr::create_physical_expr; +use crate::physical_plan::ExecutionPlan; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::metrics::MetricsSet; -use crate::physical_plan::ExecutionPlan; use crate::prelude::{Expr, SessionConfig, SessionContext}; -use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; -use datafusion_datasource::TableSchema; -use object_store::path::Path; use object_store::ObjectMeta; +use object_store::path::Path; use parquet::arrow::ArrowWriter; use parquet::file::properties::WriterProperties; @@ -157,20 +155,21 @@ impl TestParquetFile { maybe_filter: Option, ) -> Result> { let parquet_options = ctx.copied_table_options().parquet; - let source = Arc::new(ParquetSource::new(parquet_options.clone())); - let scan_config_builder = FileScanConfigBuilder::new( - self.object_store_url.clone(), - Arc::clone(&self.schema), - source, - ) - .with_file(PartitionedFile { - object_meta: self.object_meta.clone(), - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }); + let source = Arc::new( + ParquetSource::new(Arc::clone(&self.schema)) + .with_table_parquet_options(parquet_options.clone()), + ); + let scan_config_builder = + FileScanConfigBuilder::new(self.object_store_url.clone(), source).with_file( + PartitionedFile { + object_meta: self.object_meta.clone(), + partition_values: vec![], + range: None, + statistics: None, + extensions: None, + metadata_size_hint: None, + }, + ); let df_schema = Arc::clone(&self.schema).to_dfschema_ref()?; @@ -184,10 +183,10 @@ impl TestParquetFile { create_physical_expr(&filter, &df_schema, &ExecutionProps::default())?; let source = Arc::new( - ParquetSource::new(parquet_options) + ParquetSource::new(Arc::clone(&self.schema)) + .with_table_parquet_options(parquet_options) .with_predicate(Arc::clone(&physical_filter_expr)), - ) - .with_schema(TableSchema::from_file_schema(Arc::clone(&self.schema))); + ); let config = scan_config_builder.with_source(source).build(); let parquet_exec = 
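`TestCacheFactory` above wraps the table's plan in a `CacheNode` extension. If downstream code needs to detect that wrapper, the usual pattern is to match `LogicalPlan::Extension` and inspect the node; a hedged sketch (the `is_cache_node` helper is illustrative, not part of the patch):

    use datafusion_expr::LogicalPlan;

    // Illustrative helper: true if the plan root is the CacheNode wrapper above.
    fn is_cache_node(plan: &LogicalPlan) -> bool {
        matches!(plan, LogicalPlan::Extension(ext) if ext.node.name() == "CacheNode")
    }
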
DataSourceExec::from_data_source(config); @@ -204,13 +203,12 @@ impl TestParquetFile { /// Recursively searches for DataSourceExec and returns the metrics /// on the first one it finds pub fn parquet_metrics(plan: &Arc) -> Option { - if let Some(data_source_exec) = plan.as_any().downcast_ref::() { - if data_source_exec + if let Some(data_source_exec) = plan.as_any().downcast_ref::() + && data_source_exec .downcast_to_file_source::() .is_some() - { - return data_source_exec.metrics(); - } + { + return data_source_exec.metrics(); } for child in plan.children() { diff --git a/datafusion/core/tests/catalog/memory.rs b/datafusion/core/tests/catalog/memory.rs index 06ed141b2e8bd..5258f3bf97574 100644 --- a/datafusion/core/tests/catalog/memory.rs +++ b/datafusion/core/tests/catalog/memory.rs @@ -116,10 +116,12 @@ async fn test_mem_provider() { assert!(provider.deregister_table(table_name).unwrap().is_none()); let test_table = EmptyTable::new(Arc::new(Schema::empty())); // register table successfully - assert!(provider - .register_table(table_name.to_string(), Arc::new(test_table)) - .unwrap() - .is_none()); + assert!( + provider + .register_table(table_name.to_string(), Arc::new(test_table)) + .unwrap() + .is_none() + ); assert!(provider.table_exist(table_name)); let other_table = EmptyTable::new(Arc::new(Schema::empty())); let result = provider.register_table(table_name.to_string(), Arc::new(other_table)); diff --git a/datafusion/core/tests/catalog_listing/mod.rs b/datafusion/core/tests/catalog_listing/mod.rs new file mode 100644 index 0000000000000..cb6cac4fb0672 --- /dev/null +++ b/datafusion/core/tests/catalog_listing/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod pruned_partition_list; diff --git a/datafusion/core/tests/catalog_listing/pruned_partition_list.rs b/datafusion/core/tests/catalog_listing/pruned_partition_list.rs new file mode 100644 index 0000000000000..f4782ee13c24d --- /dev/null +++ b/datafusion/core/tests/catalog_listing/pruned_partition_list.rs @@ -0,0 +1,251 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
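`parquet_metrics` above now uses a let-chain (`if let ... && ...`), which flattens the previous nested `if let` plus inner `if`. A standalone sketch of the construct, which requires the 2024 edition the crate now targets:

    fn first_even(values: &[&str]) -> Option<i64> {
        for v in values {
            // One `if` holds both the pattern match and the follow-up condition.
            if let Ok(n) = v.parse::<i64>()
                && n % 2 == 0
            {
                return Some(n);
            }
        }
        None
    }

    fn main() {
        assert_eq!(first_even(&["x", "3", "4"]), Some(4));
    }
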
See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow_schema::DataType; +use futures::{FutureExt, StreamExt as _, TryStreamExt as _}; +use object_store::{ObjectStore as _, memory::InMemory, path::Path}; + +use datafusion::execution::SessionStateBuilder; +use datafusion_catalog_listing::helpers::{ + describe_partition, list_partitions, pruned_partition_list, +}; +use datafusion_common::ScalarValue; +use datafusion_datasource::ListingTableUrl; +use datafusion_expr::{Expr, col, lit}; +use datafusion_session::Session; + +#[tokio::test] +async fn test_pruned_partition_list_empty() { + let (store, state) = make_test_store_and_state(&[ + ("tablepath/mypartition=val1/notparquetfile", 100), + ("tablepath/mypartition=val1/ignoresemptyfile.parquet", 0), + ("tablepath/file.parquet", 100), + ("tablepath/notapartition/file.parquet", 100), + ("tablepath/notmypartition=val1/file.parquet", 100), + ]); + let filter = Expr::eq(col("mypartition"), lit("val1")); + let pruned = pruned_partition_list( + state.as_ref(), + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + &[filter], + ".parquet", + &[(String::from("mypartition"), DataType::Utf8)], + ) + .await + .expect("partition pruning failed") + .collect::>() + .await; + + assert_eq!(pruned.len(), 0); +} + +#[tokio::test] +async fn test_pruned_partition_list() { + let (store, state) = make_test_store_and_state(&[ + ("tablepath/mypartition=val1/file.parquet", 100), + ("tablepath/mypartition=val2/file.parquet", 100), + ("tablepath/mypartition=val1/ignoresemptyfile.parquet", 0), + ("tablepath/mypartition=val1/other=val3/file.parquet", 100), + ("tablepath/notapartition/file.parquet", 100), + ("tablepath/notmypartition=val1/file.parquet", 100), + ]); + let filter = Expr::eq(col("mypartition"), lit("val1")); + let pruned = pruned_partition_list( + state.as_ref(), + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + &[filter], + ".parquet", + &[(String::from("mypartition"), DataType::Utf8)], + ) + .await + .expect("partition pruning failed") + .try_collect::>() + .await + .unwrap(); + + assert_eq!(pruned.len(), 2); + let f1 = &pruned[0]; + assert_eq!( + f1.object_meta.location.as_ref(), + "tablepath/mypartition=val1/file.parquet" + ); + assert_eq!(&f1.partition_values, &[ScalarValue::from("val1")]); + let f2 = &pruned[1]; + assert_eq!( + f2.object_meta.location.as_ref(), + "tablepath/mypartition=val1/other=val3/file.parquet" + ); + assert_eq!(f2.partition_values, &[ScalarValue::from("val1"),]); +} + +#[tokio::test] +async fn test_pruned_partition_list_multi() { + let (store, state) = make_test_store_and_state(&[ + ("tablepath/part1=p1v1/file.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v1/file1.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v1/file2.parquet", 100), + ("tablepath/part1=p1v3/part2=p2v1/file2.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v2/file2.parquet", 100), + ]); + let filter1 = Expr::eq(col("part1"), lit("p1v2")); + let filter2 = Expr::eq(col("part2"), lit("p2v1")); + let pruned = pruned_partition_list( + state.as_ref(), + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + &[filter1, filter2], + ".parquet", + &[ + (String::from("part1"), DataType::Utf8), + (String::from("part2"), DataType::Utf8), + ], + ) + .await + .expect("partition pruning failed") + .try_collect::>() + .await + .unwrap(); + + assert_eq!(pruned.len(), 2); + let f1 = &pruned[0]; + assert_eq!( + 
f1.object_meta.location.as_ref(), + "tablepath/part1=p1v2/part2=p2v1/file1.parquet" + ); + assert_eq!( + &f1.partition_values, + &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1"),] + ); + let f2 = &pruned[1]; + assert_eq!( + f2.object_meta.location.as_ref(), + "tablepath/part1=p1v2/part2=p2v1/file2.parquet" + ); + assert_eq!( + &f2.partition_values, + &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1")] + ); +} + +#[tokio::test] +async fn test_list_partition() { + let (store, _) = make_test_store_and_state(&[ + ("tablepath/part1=p1v1/file.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v1/file1.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v1/file2.parquet", 100), + ("tablepath/part1=p1v3/part2=p2v1/file3.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v2/file4.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v2/empty.parquet", 0), + ]); + + let partitions = list_partitions( + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + 0, + None, + ) + .await + .expect("listing partitions failed"); + + assert_eq!( + &partitions + .iter() + .map(describe_partition) + .collect::>(), + &vec![ + ("tablepath", 0, vec![]), + ("tablepath/part1=p1v1", 1, vec![]), + ("tablepath/part1=p1v2", 1, vec![]), + ("tablepath/part1=p1v3", 1, vec![]), + ] + ); + + let partitions = list_partitions( + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + 1, + None, + ) + .await + .expect("listing partitions failed"); + + assert_eq!( + &partitions + .iter() + .map(describe_partition) + .collect::>(), + &vec![ + ("tablepath", 0, vec![]), + ("tablepath/part1=p1v1", 1, vec!["file.parquet"]), + ("tablepath/part1=p1v2", 1, vec![]), + ("tablepath/part1=p1v2/part2=p2v1", 2, vec![]), + ("tablepath/part1=p1v2/part2=p2v2", 2, vec![]), + ("tablepath/part1=p1v3", 1, vec![]), + ("tablepath/part1=p1v3/part2=p2v1", 2, vec![]), + ] + ); + + let partitions = list_partitions( + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + 2, + None, + ) + .await + .expect("listing partitions failed"); + + assert_eq!( + &partitions + .iter() + .map(describe_partition) + .collect::>(), + &vec![ + ("tablepath", 0, vec![]), + ("tablepath/part1=p1v1", 1, vec!["file.parquet"]), + ("tablepath/part1=p1v2", 1, vec![]), + ("tablepath/part1=p1v3", 1, vec![]), + ( + "tablepath/part1=p1v2/part2=p2v1", + 2, + vec!["file1.parquet", "file2.parquet"] + ), + ("tablepath/part1=p1v2/part2=p2v2", 2, vec!["file4.parquet"]), + ("tablepath/part1=p1v3/part2=p2v1", 2, vec!["file3.parquet"]), + ] + ); +} + +pub fn make_test_store_and_state( + files: &[(&str, u64)], +) -> (Arc, Arc) { + let memory = InMemory::new(); + + for (name, size) in files { + memory + .put(&Path::from(*name), vec![0; *size as usize].into()) + .now_or_never() + .unwrap() + .unwrap(); + } + + let state = SessionStateBuilder::new().build(); + (Arc::new(memory), Arc::new(state)) +} diff --git a/datafusion/core/tests/config_from_env.rs b/datafusion/core/tests/config_from_env.rs index 976597c8a9ac5..6375d4e25d8eb 100644 --- a/datafusion/core/tests/config_from_env.rs +++ b/datafusion/core/tests/config_from_env.rs @@ -20,35 +20,43 @@ use std::env; #[test] fn from_env() { - // Note: these must be a single test to avoid interference from concurrent execution - let env_key = "DATAFUSION_OPTIMIZER_FILTER_NULL_JOIN_KEYS"; - // valid testing in different cases - for bool_option in ["true", "TRUE", "True", "tRUe"] { - env::set_var(env_key, bool_option); - let config = ConfigOptions::from_env().unwrap(); - env::remove_var(env_key); - 
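The new pruning tests above rely on the hive-style layout in which each `key=value` directory segment between the table root and the file name carries one partition value. A small sketch of that convention, independent of the listing code (the parser is illustrative only):

    // Illustrative parser for hive-style partition paths.
    fn partition_values(path: &str) -> Vec<(String, String)> {
        path.split('/')
            .filter_map(|segment| segment.split_once('='))
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    fn main() {
        assert_eq!(
            partition_values("tablepath/part1=p1v2/part2=p2v1/file1.parquet"),
            vec![
                ("part1".to_string(), "p1v2".to_string()),
                ("part2".to_string(), "p2v1".to_string()),
            ]
        );
    }
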
assert!(config.optimizer.filter_null_join_keys); - } + unsafe { + // Note: these must be a single test to avoid interference from concurrent execution + let env_key = "DATAFUSION_OPTIMIZER_FILTER_NULL_JOIN_KEYS"; + // valid testing in different cases + for bool_option in ["true", "TRUE", "True", "tRUe"] { + env::set_var(env_key, bool_option); + let config = ConfigOptions::from_env().unwrap(); + env::remove_var(env_key); + assert!(config.optimizer.filter_null_join_keys); + } - // invalid testing - env::set_var(env_key, "ttruee"); - let err = ConfigOptions::from_env().unwrap_err().strip_backtrace(); - assert_eq!(err, "Error parsing 'ttruee' as bool\ncaused by\nExternal error: provided string was not `true` or `false`"); - env::remove_var(env_key); + // invalid testing + env::set_var(env_key, "ttruee"); + let err = ConfigOptions::from_env().unwrap_err().strip_backtrace(); + assert_eq!( + err, + "Error parsing 'ttruee' as bool\ncaused by\nExternal error: provided string was not `true` or `false`" + ); + env::remove_var(env_key); - let env_key = "DATAFUSION_EXECUTION_BATCH_SIZE"; + let env_key = "DATAFUSION_EXECUTION_BATCH_SIZE"; - // for valid testing - env::set_var(env_key, "4096"); - let config = ConfigOptions::from_env().unwrap(); - assert_eq!(config.execution.batch_size, 4096); + // for valid testing + env::set_var(env_key, "4096"); + let config = ConfigOptions::from_env().unwrap(); + assert_eq!(config.execution.batch_size, 4096); - // for invalid testing - env::set_var(env_key, "abc"); - let err = ConfigOptions::from_env().unwrap_err().strip_backtrace(); - assert_eq!(err, "Error parsing 'abc' as usize\ncaused by\nExternal error: invalid digit found in string"); + // for invalid testing + env::set_var(env_key, "abc"); + let err = ConfigOptions::from_env().unwrap_err().strip_backtrace(); + assert_eq!( + err, + "Error parsing 'abc' as usize\ncaused by\nExternal error: invalid digit found in string" + ); - env::remove_var(env_key); - let config = ConfigOptions::from_env().unwrap(); - assert_eq!(config.execution.batch_size, 8192); // set to its default value + env::remove_var(env_key); + let config = ConfigOptions::from_env().unwrap(); + assert_eq!(config.execution.batch_size, 8192); // set to its default value + } } diff --git a/datafusion/core/tests/core_integration.rs b/datafusion/core/tests/core_integration.rs index edcf039e4e704..bdbe72245323d 100644 --- a/datafusion/core/tests/core_integration.rs +++ b/datafusion/core/tests/core_integration.rs @@ -48,15 +48,15 @@ mod optimizer; /// Run all tests that are found in the `physical_optimizer` directory mod physical_optimizer; -/// Run all tests that are found in the `schema_adapter` directory -mod schema_adapter; - /// Run all tests that are found in the `serde` directory mod serde; /// Run all tests that are found in the `catalog` directory mod catalog; +/// Run all tests that are found in the `catalog_listing` directory +mod catalog_listing; + /// Run all tests that are found in the `tracing` directory mod tracing; diff --git a/datafusion/core/tests/custom_sources_cases/mod.rs b/datafusion/core/tests/custom_sources_cases/mod.rs index cbdc4a448ea41..7b6a3c5fbed75 100644 --- a/datafusion/core/tests/custom_sources_cases/mod.rs +++ b/datafusion/core/tests/custom_sources_cases/mod.rs @@ -28,11 +28,11 @@ use datafusion::datasource::{TableProvider, TableType}; use datafusion::error::Result; use datafusion::execution::context::{SessionContext, TaskContext}; use datafusion::logical_expr::{ - col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, 
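The `unsafe` block wrapping the test above reflects that `std::env::set_var` / `remove_var` are unsafe in the 2024 edition, because they mutate process-global state that other threads may read concurrently. A minimal sketch:

    use std::env;

    fn main() {
        let key = "DATAFUSION_EXECUTION_BATCH_SIZE"; // same variable the test exercises
        unsafe { env::set_var(key, "4096") };
        assert_eq!(env::var(key).unwrap(), "4096");
        unsafe { env::remove_var(key) };
        assert!(env::var(key).is_err());
    }
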
UNNAMED_TABLE, + Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE, col, }; use datafusion::physical_plan::{ - collect, ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, - RecordBatchStream, SendableRecordBatchStream, Statistics, + ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, Statistics, collect, }; use datafusion::scalar::ScalarValue; use datafusion_catalog::Session; @@ -40,9 +40,9 @@ use datafusion_common::cast::as_primitive_array; use datafusion_common::project_schema; use datafusion_common::stats::Precision; use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_plan::PlanProperties; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; -use datafusion_physical_plan::PlanProperties; use async_trait::async_trait; use futures::stream::Stream; @@ -316,6 +316,7 @@ async fn optimizers_catch_all_statistics() { assert_eq!(format!("{:?}", actual[0]), format!("{expected:?}")); } +#[expect(clippy::needless_pass_by_value)] fn contains_place_holder_exec(plan: Arc) -> bool { if plan.as_any().is::() { true diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index c80c0b4bf54ba..ca1eaa1f958ea 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -35,7 +35,7 @@ use datafusion::prelude::*; use datafusion::scalar::ScalarValue; use datafusion_catalog::Session; use datafusion_common::cast::as_primitive_array; -use datafusion_common::{internal_err, not_impl_err}; +use datafusion_common::{DataFusionError, internal_err, not_impl_err}; use datafusion_expr::expr::{BinaryExpr, Cast}; use datafusion_functions_aggregate::expr_fn::count; use datafusion_physical_expr::EquivalenceProperties; @@ -134,9 +134,19 @@ impl ExecutionPlan for CustomPlan { _partition: usize, _context: Arc, ) -> Result { + let schema_captured = self.schema().clone(); Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), - futures::stream::iter(self.batches.clone().into_iter().map(Ok)), + futures::stream::iter(self.batches.clone().into_iter().map(move |batch| { + let projection: Vec = schema_captured + .fields() + .iter() + .filter_map(|field| batch.schema().index_of(field.name()).ok()) + .collect(); + batch + .project(&projection) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) + })), ))) } diff --git a/datafusion/core/tests/custom_sources_cases/statistics.rs b/datafusion/core/tests/custom_sources_cases/statistics.rs index 403c04f1737e1..820c2a470b376 100644 --- a/datafusion/core/tests/custom_sources_cases/statistics.rs +++ b/datafusion/core/tests/custom_sources_cases/statistics.rs @@ -214,6 +214,7 @@ fn fully_defined() -> (Statistics, Schema) { min_value: Precision::Exact(ScalarValue::Int32(Some(-24))), sum_value: Precision::Exact(ScalarValue::Int64(Some(10))), null_count: Precision::Exact(0), + byte_size: Precision::Absent, }, ColumnStatistics { distinct_count: Precision::Exact(13), @@ -221,6 +222,7 @@ fn fully_defined() -> (Statistics, Schema) { min_value: Precision::Exact(ScalarValue::Int64(Some(-6783))), sum_value: Precision::Exact(ScalarValue::Int64(Some(10))), null_count: Precision::Exact(5), + byte_size: Precision::Absent, }, ], }, diff --git 
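`CustomPlan::execute` above now projects each batch onto the plan's declared schema by resolving field names to column indices. A self-contained sketch of that projection step using only the public Arrow API:

    use std::sync::Arc;

    use arrow::array::{Int32Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use arrow::error::ArrowError;
    use arrow::record_batch::RecordBatch;

    // Keep only the columns named by `target`, in `target`'s order.
    fn project_to(batch: &RecordBatch, target: &SchemaRef) -> Result<RecordBatch, ArrowError> {
        let projection: Vec<usize> = target
            .fields()
            .iter()
            .filter_map(|field| batch.schema().index_of(field.name()).ok())
            .collect();
        batch.project(&projection)
    }

    fn main() -> Result<(), ArrowError> {
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Field::new("a", DataType::Int32, false),
                Field::new("b", DataType::Utf8, false),
            ])),
            vec![
                Arc::new(Int32Array::from(vec![1, 2])),
                Arc::new(StringArray::from(vec!["x", "y"])),
            ],
        )?;
        let target: SchemaRef =
            Arc::new(Schema::new(vec![Field::new("b", DataType::Utf8, false)]));
        assert_eq!(project_to(&batch, &target)?.num_columns(), 1);
        Ok(())
    }
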
a/datafusion/core/tests/data/partitioned_table_arrow_stream/part=123/data.arrow b/datafusion/core/tests/data/partitioned_table_arrow_stream/part=123/data.arrow new file mode 100644 index 0000000000000..bad9e3de4a57f Binary files /dev/null and b/datafusion/core/tests/data/partitioned_table_arrow_stream/part=123/data.arrow differ diff --git a/datafusion/core/tests/data/partitioned_table_arrow_stream/part=456/data.arrow b/datafusion/core/tests/data/partitioned_table_arrow_stream/part=456/data.arrow new file mode 100644 index 0000000000000..4a07fbfa47f32 Binary files /dev/null and b/datafusion/core/tests/data/partitioned_table_arrow_stream/part=456/data.arrow differ diff --git a/datafusion/core/tests/data/recursive_cte/closure.csv b/datafusion/core/tests/data/recursive_cte/closure.csv new file mode 100644 index 0000000000000..a31e2bfbf36b6 --- /dev/null +++ b/datafusion/core/tests/data/recursive_cte/closure.csv @@ -0,0 +1,6 @@ +start,end +1,2 +2,3 +2,4 +2,4 +4,1 \ No newline at end of file diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 265862ff9af8a..014f356cd64cd 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{types::Int32Type, ListArray}; +use arrow::array::{ListArray, types::Int32Type}; use arrow::datatypes::SchemaRef; use arrow::datatypes::{DataType, Field, Schema}; use arrow::{ @@ -31,7 +31,7 @@ use datafusion::prelude::*; use datafusion_common::test_util::batches_to_string; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::expr::Alias; -use datafusion_expr::{table_scan, ExprSchemable, LogicalPlanBuilder}; +use datafusion_expr::{ExprSchemable, LogicalPlanBuilder, table_scan}; use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont}; use datafusion_functions_nested::map::map; use insta::assert_snapshot; @@ -313,10 +313,10 @@ async fn test_fn_arrow_typeof() -> Result<()> { +----------------------+ | arrow_typeof(test.l) | +----------------------+ - | List(nullable Int32) | - | List(nullable Int32) | - | List(nullable Int32) | - | List(nullable Int32) | + | List(Int32) | + | List(Int32) | + | List(Int32) | + | List(Int32) | +----------------------+ "); diff --git a/datafusion/core/tests/dataframe/describe.rs b/datafusion/core/tests/dataframe/describe.rs index 9bd69dfa72b4c..c61fe4fed1615 100644 --- a/datafusion/core/tests/dataframe/describe.rs +++ b/datafusion/core/tests/dataframe/describe.rs @@ -17,7 +17,7 @@ use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion_common::test_util::batches_to_string; -use datafusion_common::{test_util::parquet_test_data, Result}; +use datafusion_common::{Result, test_util::parquet_test_data}; use insta::assert_snapshot; #[tokio::test] diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 05f5a204c0963..c09db371912b0 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -20,10 +20,10 @@ mod dataframe_functions; mod describe; use arrow::array::{ - record_batch, Array, ArrayRef, BooleanArray, DictionaryArray, FixedSizeListArray, - FixedSizeListBuilder, Float32Array, Float64Array, Int32Array, Int32Builder, - Int8Array, LargeListArray, ListArray, ListBuilder, RecordBatch, StringArray, - StringBuilder, StructBuilder, 
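`ColumnStatistics` gains a `byte_size` field above, which every literal initializer now has to spell out. Where exhaustiveness is not needed, struct-update syntax can absorb future additions; a hedged sketch, assuming `ColumnStatistics::new_unknown()` keeps defaulting every field to `Precision::Absent`:

    use datafusion_common::ColumnStatistics;
    use datafusion_common::stats::Precision;

    fn main() {
        let stats = ColumnStatistics {
            null_count: Precision::Exact(0),
            ..ColumnStatistics::new_unknown()
        };
        // Fields not mentioned, including the new `byte_size`, stay Absent.
        assert!(matches!(stats.byte_size, Precision::Absent));
    }
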
UInt32Array, UInt32Builder, UnionArray, + Array, ArrayRef, BooleanArray, DictionaryArray, FixedSizeListArray, + FixedSizeListBuilder, Float32Array, Float64Array, Int8Array, Int32Array, + Int32Builder, LargeListArray, ListArray, ListBuilder, RecordBatch, StringArray, + StringBuilder, StructBuilder, UInt32Array, UInt32Builder, UnionArray, record_batch, }; use arrow::buffer::ScalarBuffer; use arrow::datatypes::{ @@ -61,13 +61,13 @@ use datafusion::prelude::{ }; use datafusion::test_util::{ parquet_test_data, populate_csv_partitions, register_aggregate_csv, test_table, - test_table_with_name, + test_table_with_cache_factory, test_table_with_name, }; use datafusion_catalog::TableProvider; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; use datafusion_common::{ - assert_contains, internal_datafusion_err, Constraint, Constraints, DFSchema, - DataFusionError, ScalarValue, TableReference, UnnestOptions, + Constraint, Constraints, DFSchema, DataFusionError, ScalarValue, SchemaError, + TableReference, UnnestOptions, assert_contains, internal_datafusion_err, }; use datafusion_common_runtime::SpawnedTask; use datafusion_datasource::file_format::format_as_file_type; @@ -76,21 +76,21 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, NullTreatment, Sort, WindowFunction}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - cast, col, create_udf, exists, in_subquery, lit, out_ref_col, placeholder, - scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, - LogicalPlanBuilder, ScalarFunctionImplementation, SortExpr, TableType, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, LogicalPlanBuilder, + ScalarFunctionImplementation, SortExpr, TableType, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, cast, col, create_udf, exists, + in_subquery, lit, out_ref_col, placeholder, scalar_subquery, when, wildcard, }; +use datafusion_physical_expr::Partitioning; use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::Partitioning; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; use datafusion_physical_plan::empty::EmptyExec; -use datafusion_physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties}; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties, displayable}; use datafusion::error::Result as DataFusionResult; use datafusion_functions_window::expr_fn::lag; @@ -305,6 +305,27 @@ async fn select_columns() -> Result<()> { Ok(()) } +#[tokio::test] +async fn select_columns_with_nonexistent_columns() -> Result<()> { + let t = test_table().await?; + let t2 = t.select_columns(&["canada", "c2", "rocks"]); + + match t2 { + Err(DataFusionError::SchemaError(boxed_err, _)) => { + // Verify it's the first invalid column + match boxed_err.as_ref() { + SchemaError::FieldNotFound { field, .. 
} => { + assert_eq!(field.name(), "canada"); + } + _ => panic!("Expected SchemaError::FieldNotFound for 'canada'"), + } + } + _ => panic!("Expected SchemaError"), + } + + Ok(()) +} + #[tokio::test] async fn select_expr() -> Result<()> { // build plan using Table API @@ -392,14 +413,14 @@ async fn select_with_periods() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +------+ | f.c1 | +------+ | 1 | | 10 | +------+ - "### + " ); Ok(()) @@ -547,14 +568,14 @@ async fn drop_with_quotes() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r#" +------+ | f"c2 | +------+ | 11 | | 2 | +------+ - "### + "# ); Ok(()) @@ -579,14 +600,14 @@ async fn drop_with_periods() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +------+ | f.c2 | +------+ | 11 | | 2 | +------+ - "### + " ); Ok(()) @@ -723,23 +744,23 @@ async fn test_aggregate_with_pk() -> Result<()> { assert_snapshot!( physical_plan_to_string(&df).await, - @r###" + @r" AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[] DataSourceExec: partitions=1, partition_sizes=[1] - "### + " ); let df_results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+------+ | id | name | +----+------+ | 1 | a | +----+------+ - "### + " ); Ok(()) @@ -766,9 +787,8 @@ async fn test_aggregate_with_pk2() -> Result<()> { physical_plan_to_string(&df).await, @r" AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[], ordering_mode=Sorted - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: id@0 = 1 AND name@1 = a - DataSourceExec: partitions=1, partition_sizes=[1] + FilterExec: id@0 = 1 AND name@1 = a + DataSourceExec: partitions=1, partition_sizes=[1] " ); @@ -778,13 +798,13 @@ async fn test_aggregate_with_pk2() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+------+ | id | name | +----+------+ | 1 | a | +----+------+ - "### + " ); Ok(()) @@ -815,9 +835,8 @@ async fn test_aggregate_with_pk3() -> Result<()> { physical_plan_to_string(&df).await, @r" AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[], ordering_mode=PartiallySorted([0]) - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: id@0 = 1 - DataSourceExec: partitions=1, partition_sizes=[1] + FilterExec: id@0 = 1 + DataSourceExec: partitions=1, partition_sizes=[1] " ); @@ -827,13 +846,13 @@ async fn test_aggregate_with_pk3() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+------+ | id | name | +----+------+ | 1 | a | +----+------+ - "### + " ); Ok(()) @@ -866,9 +885,8 @@ async fn test_aggregate_with_pk4() -> Result<()> { physical_plan_to_string(&df).await, @r" AggregateExec: mode=Single, gby=[id@0 as id], aggr=[], ordering_mode=Sorted - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: id@0 = 1 - DataSourceExec: partitions=1, partition_sizes=[1] + FilterExec: id@0 = 1 + DataSourceExec: partitions=1, partition_sizes=[1] " ); @@ -876,13 +894,13 @@ async fn test_aggregate_with_pk4() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+ | id | +----+ | 1 | +----+ - "### + " ); Ok(()) @@ -904,7 +922,7 @@ async fn test_aggregate_alias() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+ | c2 | +----+ @@ -914,7 +932,7 @@ async fn test_aggregate_alias() -> Result<()> { | 5 | | 6 | +----+ - "### + " ); Ok(()) @@ -951,7 +969,7 @@ async fn 
test_aggregate_with_union() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+------------+ | c1 | sum_result | +----+------------+ @@ -961,7 +979,7 @@ async fn test_aggregate_with_union() -> Result<()> { | d | 126 | | e | 121 | +----+------------+ - "### + " ); Ok(()) } @@ -987,7 +1005,7 @@ async fn test_aggregate_subexpr() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----------------+------+ | c2 + Int32(10) | sum | +----------------+------+ @@ -997,7 +1015,7 @@ async fn test_aggregate_subexpr() -> Result<()> { | 15 | 95 | | 16 | -146 | +----------------+------+ - "### + " ); Ok(()) @@ -1020,7 +1038,7 @@ async fn test_aggregate_name_collision() -> Result<()> { // The select expr has the same display_name as the group_expr, // but since they are different expressions, it should fail. .expect_err("Expected error"); - assert_snapshot!(df.strip_backtrace(), @r###"Schema error: No field named aggregate_test_100.c2. Valid fields are "aggregate_test_100.c2 + aggregate_test_100.c3"."###); + assert_snapshot!(df.strip_backtrace(), @r#"Schema error: No field named aggregate_test_100.c2. Valid fields are "aggregate_test_100.c2 + aggregate_test_100.c3"."#); Ok(()) } @@ -1079,33 +1097,33 @@ async fn window_using_aggregates() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df), - @r###" + @r" +-------------+----------+-----------------+---------------+--------+-----+------+----+------+ | first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 | +-------------+----------+-----------------+---------------+--------+-----+------+----+------+ | | | | | | | | 1 | -85 | - | -85 | -101 | 14 | -12 | -101 | 83 | -101 | 4 | -54 | - | -85 | -101 | 17 | -25 | -101 | 83 | -101 | 5 | -31 | - | -85 | -12 | 10 | -32 | -12 | 83 | -85 | 3 | 13 | - | -85 | -25 | 3 | -56 | -25 | -25 | -85 | 1 | -5 | - | -85 | -31 | 18 | -29 | -31 | 83 | -101 | 5 | 36 | - | -85 | -38 | 16 | -25 | -38 | 83 | -101 | 4 | 65 | + | -85 | -101 | 14 | -12 | -12 | 83 | -101 | 4 | -54 | + | -85 | -101 | 17 | -25 | -25 | 83 | -101 | 5 | -31 | + | -85 | -12 | 10 | -32 | -34 | 83 | -85 | 3 | 13 | + | -85 | -25 | 3 | -56 | -56 | -25 | -85 | 1 | -5 | + | -85 | -31 | 18 | -29 | -28 | 83 | -101 | 5 | 36 | + | -85 | -38 | 16 | -25 | -25 | 83 | -101 | 4 | 65 | | -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 | - | -85 | -48 | 6 | -35 | -48 | 83 | -85 | 2 | -43 | - | -85 | -5 | 4 | -37 | -5 | -5 | -85 | 1 | 83 | - | -85 | -54 | 15 | -17 | -54 | 83 | -101 | 4 | -38 | - | -85 | -56 | 2 | -70 | -56 | -56 | -85 | 1 | -25 | - | -85 | -72 | 9 | -43 | -72 | 83 | -85 | 3 | -12 | + | -85 | -48 | 6 | -35 | -36 | 83 | -85 | 2 | -43 | + | -85 | -5 | 4 | -37 | -40 | -5 | -85 | 1 | 83 | + | -85 | -54 | 15 | -17 | -18 | 83 | -101 | 4 | -38 | + | -85 | -56 | 2 | -70 | -70 | -56 | -85 | 1 | -25 | + | -85 | -72 | 9 | -43 | -43 | 83 | -85 | 3 | -12 | | -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 | - | -85 | 13 | 11 | -17 | 13 | 83 | -85 | 3 | 14 | - | -85 | 13 | 11 | -25 | 13 | 83 | -85 | 3 | 13 | - | -85 | 14 | 12 | -12 | 14 | 83 | -85 | 3 | 17 | - | -85 | 17 | 13 | -11 | 17 | 83 | -85 | 4 | -101 | - | -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 | - | -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 | - | -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 | + | -85 | 13 | 11 | -17 | -18 | 83 | -85 | 3 | 14 | + | -85 | 13 | 11 | -25 | -25 | 83 | -85 | 3 | 13 | + | -85 | 14 | 12 | -12 | -12 | 83 | -85 | 3 | 17 | + | -85 | 17 | 13 | -11 | -8 | 83 | -85 | 4 | 
-101 | + | -85 | 45 | 8 | -34 | -34 | 83 | -85 | 3 | -72 | + | -85 | 65 | 17 | -17 | -18 | 83 | -101 | 5 | -101 | + | -85 | 83 | 5 | -25 | -25 | 83 | -85 | 2 | -48 | +-------------+----------+-----------------+---------------+--------+-----+------+----+------+ - "### + " ); Ok(()) @@ -1172,7 +1190,7 @@ async fn window_aggregates_with_filter() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +---------+---------+---------+---------+---------+----+-----+ | sum_pos | avg_pos | min_pos | max_pos | cnt_pos | ts | val | +---------+---------+---------+---------+---------+----+-----+ @@ -1182,7 +1200,7 @@ async fn window_aggregates_with_filter() -> Result<()> { | 5 | 2.5 | 1 | 4 | 2 | 4 | 4 | | 5 | 2.5 | 1 | 4 | 2 | 5 | -1 | +---------+---------+---------+---------+---------+----+-----+ - "### + " ); Ok(()) @@ -1238,7 +1256,7 @@ async fn test_distinct_sort_by() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+ | c1 | +----+ @@ -1248,7 +1266,7 @@ async fn test_distinct_sort_by() -> Result<()> { | d | | e | +----+ - "### + " ); Ok(()) @@ -1286,7 +1304,7 @@ async fn test_distinct_on() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+ | c1 | +----+ @@ -1296,7 +1314,7 @@ async fn test_distinct_on() -> Result<()> { | d | | e | +----+ - "### + " ); Ok(()) @@ -1321,7 +1339,7 @@ async fn test_distinct_on_sort_by() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+ | c1 | +----+ @@ -1331,7 +1349,7 @@ async fn test_distinct_on_sort_by() -> Result<()> { | d | | e | +----+ - "### + " ); Ok(()) @@ -1395,13 +1413,13 @@ async fn join_coercion_unnamed() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----+------+ | id | name | +----+------+ | 10 | d | +----+------+ - "### + " ); Ok(()) } @@ -1420,13 +1438,13 @@ async fn join_on() -> Result<()> { [col("a.c1").not_eq(col("b.c1")), col("a.c2").eq(col("b.c2"))], )?; - assert_snapshot!(join.logical_plan(), @r###" + assert_snapshot!(join.logical_plan(), @r" Inner Join: Filter: a.c1 != b.c1 AND a.c2 = b.c2 Projection: a.c1, a.c2 TableScan: a Projection: b.c1, b.c2 TableScan: b - "###); + "); Ok(()) } @@ -1449,7 +1467,11 @@ async fn join_on_filter_datatype() -> Result<()> { let err = join.into_optimized_plan().unwrap_err(); assert_snapshot!( err.strip_backtrace(), - @"type_coercion\ncaused by\nError during planning: Join condition must be boolean type, but got Utf8" + @r" + type_coercion + caused by + Error during planning: Join condition must be boolean type, but got Utf8 + " ); Ok(()) } @@ -1627,7 +1649,9 @@ async fn register_table() -> Result<()> { let df_impl = DataFrame::new(ctx.state(), df.logical_plan().clone()); // register a dataframe as a table - ctx.register_table("test_table", df_impl.clone().into_view())?; + let table_provider = df_impl.clone().into_view(); + assert_eq!(table_provider.table_type(), TableType::View); + ctx.register_table("test_table", table_provider)?; // pull the table out let table = ctx.table("test_table").await?; @@ -1644,7 +1668,7 @@ async fn register_table() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+-----------------------------+ | c1 | sum(aggregate_test_100.c12) | +----+-----------------------------+ @@ -1654,13 +1678,13 @@ async fn register_table() -> Result<()> { | d | 8.793968289758968 | | e | 10.206140546981722 | +----+-----------------------------+ - "### + " ); // the results are the same as 
the results from the view, modulo the leaf table name assert_snapshot!( batches_to_sort_string(table_results), - @r###" + @r" +----+---------------------+ | c1 | sum(test_table.c12) | +----+---------------------+ @@ -1670,7 +1694,7 @@ async fn register_table() -> Result<()> { | d | 8.793968289758968 | | e | 10.206140546981722 | +----+---------------------+ - "### + " ); Ok(()) } @@ -1719,7 +1743,7 @@ async fn with_column() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+-----+-----+ | c1 | c2 | c3 | sum | +----+----+-----+-----+ @@ -1730,7 +1754,7 @@ async fn with_column() -> Result<()> { | a | 3 | 14 | 17 | | a | 3 | 17 | 20 | +----+----+-----+-----+ - "### + " ); // check that col with the same name overwritten @@ -1742,7 +1766,7 @@ async fn with_column() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results_overwrite), - @r###" + @r" +-----+----+-----+-----+ | c1 | c2 | c3 | sum | +-----+----+-----+-----+ @@ -1753,7 +1777,7 @@ async fn with_column() -> Result<()> { | 17 | 3 | 14 | 17 | | 20 | 3 | 17 | 20 | +-----+----+-----+-----+ - "### + " ); // check that col with the same name overwritten using same name as reference @@ -1765,7 +1789,7 @@ async fn with_column() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results_overwrite_self), - @r###" + @r" +----+----+-----+-----+ | c1 | c2 | c3 | sum | +----+----+-----+-----+ @@ -1776,7 +1800,7 @@ async fn with_column() -> Result<()> { | a | 4 | 14 | 17 | | a | 4 | 17 | 20 | +----+----+-----+-----+ - "### + " ); Ok(()) @@ -1804,14 +1828,14 @@ async fn test_window_function_with_column() -> Result<()> { let df_results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+-----+-----+---+ | c1 | c2 | c3 | s | r | +----+----+-----+-----+---+ | c | 2 | 1 | 3 | 1 | | d | 5 | -40 | -35 | 2 | +----+----+-----+-----+---+ - "### + " ); Ok(()) @@ -1846,13 +1870,13 @@ async fn with_column_join_same_columns() -> Result<()> { let df_results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+ | c1 | c1 | +----+----+ | a | a | +----+----+ - "### + " ); let df_with_column = df.clone().with_column("new_column", lit(true))?; @@ -1875,7 +1899,7 @@ async fn with_column_join_same_columns() -> Result<()> { assert_snapshot!( df_with_column.clone().into_optimized_plan().unwrap(), - @r###" + @r" Projection: t1.c1, t2.c1, Boolean(true) AS new_column Sort: t1.c1 ASC NULLS FIRST, fetch=1 Inner Join: t1.c1 = t2.c1 @@ -1883,20 +1907,20 @@ async fn with_column_join_same_columns() -> Result<()> { TableScan: aggregate_test_100 projection=[c1] SubqueryAlias: t2 TableScan: aggregate_test_100 projection=[c1] - "### + " ); let df_results = df_with_column.collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+------------+ | c1 | c1 | new_column | +----+----+------------+ | a | a | true | +----+----+------------+ - "### + " ); Ok(()) @@ -1946,13 +1970,13 @@ async fn with_column_renamed() -> Result<()> { assert_snapshot!( batches_to_sort_string(batches), - @r###" + @r" +-----+-----+-----+-------+ | one | two | c3 | total | +-----+-----+-----+-------+ | a | 3 | -72 | -69 | +-----+-----+-----+-------+ - "### + " ); Ok(()) @@ -2017,13 +2041,13 @@ async fn with_column_renamed_join() -> Result<()> { let df_results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+-----+----+----+-----+ | 
c1 | c2 | c3 | c1 | c2 | c3 | +----+----+-----+----+----+-----+ | a | 1 | -85 | a | 1 | -85 | +----+----+-----+----+----+-----+ - "### + " ); let df_renamed = df.clone().with_column_renamed("t1.c1", "AAA")?; @@ -2046,7 +2070,7 @@ async fn with_column_renamed_join() -> Result<()> { assert_snapshot!( df_renamed.clone().into_optimized_plan().unwrap(), - @r###" + @r" Projection: t1.c1 AS AAA, t1.c2, t1.c3, t2.c1, t2.c2, t2.c3 Sort: t1.c1 ASC NULLS FIRST, t1.c2 ASC NULLS FIRST, t1.c3 ASC NULLS FIRST, t2.c1 ASC NULLS FIRST, t2.c2 ASC NULLS FIRST, t2.c3 ASC NULLS FIRST, fetch=1 Inner Join: t1.c1 = t2.c1 @@ -2054,20 +2078,20 @@ async fn with_column_renamed_join() -> Result<()> { TableScan: aggregate_test_100 projection=[c1, c2, c3] SubqueryAlias: t2 TableScan: aggregate_test_100 projection=[c1, c2, c3] - "### + " ); let df_results = df_renamed.collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +-----+----+-----+----+----+-----+ | AAA | c2 | c3 | c1 | c2 | c3 | +-----+----+-----+----+----+-----+ | a | 1 | -85 | a | 1 | -85 | +-----+----+-----+----+----+-----+ - "### + " ); Ok(()) @@ -2102,13 +2126,13 @@ async fn with_column_renamed_case_sensitive() -> Result<()> { assert_snapshot!( batches_to_sort_string(res), - @r###" + @r" +---------+ | CoLuMn1 | +---------+ | a | +---------+ - "### + " ); let df_renamed = df_renamed @@ -2118,13 +2142,13 @@ async fn with_column_renamed_case_sensitive() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_renamed), - @r###" + @r" +----+ | c1 | +----+ | a | +----+ - "### + " ); Ok(()) @@ -2162,19 +2186,19 @@ async fn describe_lookup_via_quoted_identifier() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&describe_result.clone().collect().await?), - @r###" - +------------+--------------+ - | describe | CoLu.Mn["1"] | - +------------+--------------+ - | count | 1 | - | max | a | - | mean | null | - | median | null | - | min | a | - | null_count | 0 | - | std | null | - +------------+--------------+ - "### + @r#" + +------------+--------------+ + | describe | CoLu.Mn["1"] | + +------------+--------------+ + | count | 1 | + | max | a | + | mean | null | + | median | null | + | min | a | + | null_count | 0 | + | std | null | + +------------+--------------+ + "# ); Ok(()) @@ -2192,13 +2216,13 @@ async fn cast_expr_test() -> Result<()> { df.clone().show().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+-----+ | c2 | c3 | sum | +----+----+-----+ | 2 | 1 | 3 | +----+----+-----+ - "### + " ); Ok(()) @@ -2214,12 +2238,14 @@ async fn row_writer_resize_test() -> Result<()> { let data = RecordBatch::try_new( schema, - vec![ - Arc::new(StringArray::from(vec![ - Some("2a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), - Some("3a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800"), - ])) - ], + vec![Arc::new(StringArray::from(vec![ + Some( + "2a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + ), + Some( + "3a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800", + ), + ]))], )?; let ctx = SessionContext::new(); @@ -2258,14 +2284,14 @@ async fn with_column_name() -> Result<()> { assert_snapshot!( 
batches_to_sort_string(&df_results), - @r###" + @r" +------+-------+ | f.c1 | f.c2 | +------+-------+ | 1 | hello | | 10 | hello | +------+-------+ - "### + " ); Ok(()) @@ -2301,13 +2327,13 @@ async fn cache_test() -> Result<()> { let cached_df_results = cached_df.collect().await?; assert_snapshot!( batches_to_sort_string(&cached_df_results), - @r###" + @r" +----+----+-----+ | c2 | c3 | sum | +----+----+-----+ | 2 | 1 | 3 | +----+----+-----+ - "### + " ); assert_eq!(&df_results, &cached_df_results); @@ -2315,6 +2341,29 @@ async fn cache_test() -> Result<()> { Ok(()) } +#[tokio::test] +async fn cache_producer_test() -> Result<()> { + let df = test_table_with_cache_factory() + .await? + .select_columns(&["c2", "c3"])? + .limit(0, Some(1))? + .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?; + + let cached_df = df.clone().cache().await?; + + assert_snapshot!( + cached_df.clone().into_optimized_plan().unwrap(), + @r" + CacheNode + Projection: aggregate_test_100.c2, aggregate_test_100.c3, CAST(CAST(aggregate_test_100.c2 AS Int64) + CAST(aggregate_test_100.c3 AS Int64) AS Int64) AS sum + Projection: aggregate_test_100.c2, aggregate_test_100.c3 + Limit: skip=0, fetch=1 + TableScan: aggregate_test_100, fetch=1 + " + ); + Ok(()) +} + #[tokio::test] async fn partition_aware_union() -> Result<()> { let left = test_table().await?.select_columns(&["c1", "c2"])?; @@ -2584,13 +2633,13 @@ async fn filtered_aggr_with_param_values() -> Result<()> { let df_results = df?.collect().await?; assert_snapshot!( batches_to_string(&df_results), - @r###" + @r" +------------------------------------------------+ | count(table1.c2) FILTER (WHERE table1.c3 > $1) | +------------------------------------------------+ | 54 | +------------------------------------------------+ - "### + " ); Ok(()) @@ -2638,7 +2687,7 @@ async fn write_parquet_with_order() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +---+---+ | a | b | +---+---+ @@ -2648,7 +2697,7 @@ async fn write_parquet_with_order() -> Result<()> { | 5 | 3 | | 7 | 4 | +---+---+ - "### + " ); Ok(()) @@ -2696,7 +2745,7 @@ async fn write_csv_with_order() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +---+---+ | a | b | +---+---+ @@ -2706,7 +2755,7 @@ async fn write_csv_with_order() -> Result<()> { | 5 | 3 | | 7 | 4 | +---+---+ - "### + " ); Ok(()) } @@ -2753,7 +2802,7 @@ async fn write_json_with_order() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +---+---+ | a | b | +---+---+ @@ -2763,7 +2812,7 @@ async fn write_json_with_order() -> Result<()> { | 5 | 3 | | 7 | 4 | +---+---+ - "### + " ); Ok(()) } @@ -2812,7 +2861,7 @@ async fn write_table_with_order() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----------+ | tablecol1 | +-----------+ @@ -2822,7 +2871,7 @@ async fn write_table_with_order() -> Result<()> { | x | | z | +-----------+ - "### + " ); Ok(()) } @@ -2849,7 +2898,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), - @r###" + @r" +---------------+------------------------------------------------------------------------------------------------------------+ | plan_type | plan | +---------------+------------------------------------------------------------------------------------------------------------+ @@ -2863,36 +2912,32 @@ async fn test_count_wildcard_on_sort() -> Result<()> { | | SortExec: expr=[count(*)@1 ASC NULLS LAST], 
preserve_partitioning=[true] | | | ProjectionExec: expr=[b@0 as b, count(Int64(1))@1 as count(*), count(Int64(1))@1 as count(Int64(1))] | | | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=4 | - | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | - | | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(Int64(1))] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=1 | + | | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(Int64(1))] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | | | | +---------------+------------------------------------------------------------------------------------------------------------+ - "### + " ); assert_snapshot!( pretty_format_batches(&df_results).unwrap(), - @r###" - +---------------+--------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+--------------------------------------------------------------------------------+ - | logical_plan | Sort: count(*) ASC NULLS LAST | - | | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1)) AS count(*)]] | - | | TableScan: t1 projection=[b] | - | physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] | - | | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] | - | | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(*)] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=4 | - | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | - | | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(*)] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+--------------------------------------------------------------------------------+ - "### + @r" + +---------------+----------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------+ + | logical_plan | Sort: count(*) ASC NULLS LAST | + | | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1)) AS count(*)]] | + | | TableScan: t1 projection=[b] | + | physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] | + | | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] | + | | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(*)] | + | | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=1 | + | | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(*)] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------+ + " ); Ok(()) } @@ -2910,23 +2955,22 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), @r" - +---------------+------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) | - | | TableScan: t1 
projection=[a, b] | - | | SubqueryAlias: __correlated_sq_1 | - | | Projection: count(Int64(1)) AS count(*) | - | | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] | - | | TableScan: t2 projection=[] | - | physical_plan | CoalesceBatchesExec: target_batch_size=8192 | - | | HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] | - | | ProjectionExec: expr=[4 as count(*)] | - | | PlaceholderRowExec | - | | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+------------------------------------------------------------------------------------------------------------------------+ + +---------------+----------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------------------------------------------------+ + | logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) | + | | TableScan: t1 projection=[a, b] | + | | SubqueryAlias: __correlated_sq_1 | + | | Projection: count(Int64(1)) AS count(*) | + | | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] | + | | TableScan: t2 projection=[] | + | physical_plan | HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] | + | | ProjectionExec: expr=[4 as count(*)] | + | | PlaceholderRowExec | + | | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------------------------------------------------+ " ); @@ -2956,22 +3000,21 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { assert_snapshot!( pretty_format_batches(&df_results).unwrap(), @r" - +---------------+------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) | - | | TableScan: t1 projection=[a, b] | - | | SubqueryAlias: __correlated_sq_1 | - | | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] | - | | TableScan: t2 projection=[] | - | physical_plan | CoalesceBatchesExec: target_batch_size=8192 | - | | HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] | - | | ProjectionExec: expr=[4 as count(*)] | - | | PlaceholderRowExec | - | | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+------------------------------------------------------------------------------------------------------------------------+ + +---------------+----------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------------------------------------------------+ + | logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = 
__correlated_sq_1.count(*) | + | | TableScan: t1 projection=[a, b] | + | | SubqueryAlias: __correlated_sq_1 | + | | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] | + | | TableScan: t2 projection=[] | + | physical_plan | HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] | + | | ProjectionExec: expr=[4 as count(*)] | + | | PlaceholderRowExec | + | | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------------------------------------------------+ " ); @@ -3077,15 +3120,17 @@ async fn test_count_wildcard_on_window() -> Result<()> { let df_results = ctx .table("t1") .await? - .select(vec![count_all_window() - .order_by(vec![Sort::new(col("a"), false, true)]) - .window_frame(WindowFrame::new_bounds( - WindowFrameUnits::Range, - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), - WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - )) - .build() - .unwrap()])? + .select(vec![ + count_all_window() + .order_by(vec![Sort::new(col("a"), false, true)]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), + WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + )) + .build() + .unwrap(), + ])? .explain(false, false)? .collect() .await?; @@ -3113,30 +3158,29 @@ async fn test_count_wildcard_on_window() -> Result<()> { #[tokio::test] // Test with `repartition_sorts` disabled, causing a full resort of the data -async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_false( -) -> Result<()> { +async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_false() +-> Result<()> { assert_snapshot!( union_with_mix_of_presorted_and_explicitly_resorted_inputs_impl(false).await?, - @r#" + @r" AggregateExec: mode=Final, gby=[id@0 as id], aggr=[], ordering_mode=Sorted - SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] - CoalescePartitionsExec - AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[] - UnionExec - DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet + SortPreservingMergeExec: [id@0 ASC NULLS LAST] + AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[], ordering_mode=Sorted + UnionExec + DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet + SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], file_type=parquet - "#); + "); Ok(()) } -#[ignore] // See https://github.com/apache/datafusion/issues/18380 #[tokio::test] // Test with `repartition_sorts` enabled to preserve pre-sorted partitions and avoid resorting -async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_true( -) -> Result<()> { +async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_true() +-> Result<()> { assert_snapshot!( union_with_mix_of_presorted_and_explicitly_resorted_inputs_impl(true).await?, - @r#" + @r" AggregateExec: mode=Final, gby=[id@0 as id], aggr=[], ordering_mode=Sorted 
SortPreservingMergeExec: [id@0 ASC NULLS LAST] AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[], ordering_mode=Sorted @@ -3144,53 +3188,7 @@ async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_reparti DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], file_type=parquet - "#); - - // 💥 Doesn't pass, and generates this plan: - // - // AggregateExec: mode=Final, gby=[id@0 as id], aggr=[], ordering_mode=Sorted - // SortPreservingMergeExec: [id@0 ASC NULLS LAST] - // SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] - // AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[] - // UnionExec - // DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet - // DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], file_type=parquet - // - // - // === Excerpt from the verbose explain === - // - // +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - // | plan_type | plan | - // +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - // | initial_physical_plan | AggregateExec: mode=Final, gby=[id@0 as id], aggr=[], ordering_mode=Sorted | - // | | AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[], ordering_mode=Sorted | - // | | UnionExec | - // | | DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet | - // | | SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] | - // | | DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], file_type=parquet | - // ... 
- // | physical_plan after EnforceDistribution | OutputRequirementExec: order_by=[], dist_by=Unspecified | - // | | AggregateExec: mode=Final, gby=[id@0 as id], aggr=[], ordering_mode=Sorted | - // | | SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] | - // | | CoalescePartitionsExec | - // | | AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[], ordering_mode=Sorted | - // | | UnionExec | - // | | DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet | - // | | SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] | - // | | DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], file_type=parquet | - // | | | - // | physical_plan after CombinePartialFinalAggregate | SAME TEXT AS ABOVE - // | | | - // | physical_plan after EnforceSorting | OutputRequirementExec: order_by=[], dist_by=Unspecified | - // | | AggregateExec: mode=Final, gby=[id@0 as id], aggr=[], ordering_mode=Sorted | - // | | SortPreservingMergeExec: [id@0 ASC NULLS LAST] | - // | | SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] | - // | | AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[] | - // | | UnionExec | - // | | DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet | - // | | DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], file_type=parquet | - // ... - // +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + "); Ok(()) } @@ -3275,7 +3273,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), - @r###" + @r" +---------------+-----------------------------------------------------+ | plan_type | plan | +---------------+-----------------------------------------------------+ @@ -3286,7 +3284,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { | | PlaceholderRowExec | | | | +---------------+-----------------------------------------------------+ - "### + " ); // add `.select(vec![count_wildcard()])?` to make sure we can analyze all node instead of just top node. 
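// Editor's note, not part of this diff: the hunks above and below only adjust inline-snapshot
// formatting and expected plans for the `count(*)` DataFrame tests. Below is a minimal,
// self-contained sketch of the DataFrame-API pattern they exercise. The table `t1` and column
// `b` mirror the fixtures these tests register, but the row values here are illustrative
// assumptions only, not taken from this diff.
async fn count_wildcard_sketch() -> datafusion::error::Result<()> {
    use datafusion::arrow::array::{ArrayRef, Int32Array};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::functions_aggregate::expr_fn::count;
    use datafusion::prelude::*;
    use std::sync::Arc;

    let ctx = SessionContext::new();
    // Illustrative data only; the real tests build their own fixture tables.
    let batch = RecordBatch::try_from_iter(vec![(
        "b",
        Arc::new(Int32Array::from(vec![1, 1, 2])) as ArrayRef,
    )])?;
    let _ = ctx.register_batch("t1", batch)?;

    let df = ctx
        .table("t1")
        .await?
        // `count(lit(1))` is the long-hand spelling of `count(*)`, matching the
        // `count(Int64(1)) AS count(*)` seen in the logical plans above
        .aggregate(vec![col("b")], vec![count(lit(1)).alias("count(*)")])?
        // re-selecting the aggregate output (as the comment above suggests) makes the
        // analyzer walk every node rather than only the top-level projection
        .select(vec![col("b"), col("count(*)")])?;
    df.show().await?;
    Ok(())
}
// A sketch like this could be driven by `#[tokio::test]`, as the surrounding tests are.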
@@ -3301,7 +3299,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { assert_snapshot!( pretty_format_batches(&df_results).unwrap(), - @r###" + @r" +---------------+---------------------------------------------------------------+ | plan_type | plan | +---------------+---------------------------------------------------------------+ @@ -3311,7 +3309,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { | | PlaceholderRowExec | | | | +---------------+---------------------------------------------------------------+ - "### + " ); Ok(()) @@ -3331,32 +3329,31 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), @r" - +---------------+---------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+---------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | Projection: t1.a, t1.b | - | | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) | - | | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true | - | | Left Join: t1.a = __scalar_sq_1.a | - | | TableScan: t1 projection=[a, b] | - | | SubqueryAlias: __scalar_sq_1 | - | | Projection: count(Int64(1)) AS count(*), t2.a, Boolean(true) AS __always_true | - | | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1))]] | - | | TableScan: t2 projection=[a] | - | physical_plan | CoalesceBatchesExec: target_batch_size=8192 | - | | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | HashJoinExec: mode=CollectLeft, join_type=Left, on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | ProjectionExec: expr=[count(Int64(1))@1 as count(*), a@0 as a, true as __always_true] | - | | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(Int64(1))] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 | - | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | - | | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(Int64(1))] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+---------------------------------------------------------------------------------------------------------------------------+ + +---------------+----------------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------------------------------------------------------+ + | logical_plan | Projection: t1.a, t1.b | + | | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) | + | | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true | + | | Left Join: t1.a = __scalar_sq_1.a | + | | TableScan: t1 projection=[a, b] | + | | SubqueryAlias: __scalar_sq_1 | + | | Projection: count(Int64(1)) AS count(*), t2.a, Boolean(true) AS __always_true | + | | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1))]] | + | | TableScan: t2 projection=[a] | + | physical_plan | 
FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] | + | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | + | | ProjectionExec: expr=[a@2 as a, b@3 as b, count(*)@0 as count(*), __always_true@1 as __always_true] | + | | HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@1, a@0)], projection=[count(*)@0, __always_true@2, a@3, b@4] | + | | CoalescePartitionsExec | + | | ProjectionExec: expr=[count(Int64(1))@1 as count(*), a@0 as a, true as __always_true] | + | | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(Int64(1))] | + | | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 | + | | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(Int64(1))] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------------------------------------------------------+ " ); @@ -3388,32 +3385,31 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { assert_snapshot!( pretty_format_batches(&df_results).unwrap(), @r" - +---------------+---------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+---------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | Projection: t1.a, t1.b | - | | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) | - | | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true | - | | Left Join: t1.a = __scalar_sq_1.a | - | | TableScan: t1 projection=[a, b] | - | | SubqueryAlias: __scalar_sq_1 | - | | Projection: count(*), t2.a, Boolean(true) AS __always_true | - | | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1)) AS count(*)]] | - | | TableScan: t2 projection=[a] | - | physical_plan | CoalesceBatchesExec: target_batch_size=8192 | - | | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | HashJoinExec: mode=CollectLeft, join_type=Left, on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | ProjectionExec: expr=[count(*)@1 as count(*), a@0 as a, true as __always_true] | - | | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(*)] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 | - | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | - | | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(*)] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+---------------------------------------------------------------------------------------------------------------------------+ + +---------------+----------------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------------------------------------------------------+ + | logical_plan | Projection: t1.a, t1.b | + | | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE 
__scalar_sq_1.count(*) END > Int64(0) | + | | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true | + | | Left Join: t1.a = __scalar_sq_1.a | + | | TableScan: t1 projection=[a, b] | + | | SubqueryAlias: __scalar_sq_1 | + | | Projection: count(*), t2.a, Boolean(true) AS __always_true | + | | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1)) AS count(*)]] | + | | TableScan: t2 projection=[a] | + | physical_plan | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] | + | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | + | | ProjectionExec: expr=[a@2 as a, b@3 as b, count(*)@0 as count(*), __always_true@1 as __always_true] | + | | HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@1, a@0)], projection=[count(*)@0, __always_true@2, a@3, b@4] | + | | CoalescePartitionsExec | + | | ProjectionExec: expr=[count(*)@1 as count(*), a@0 as a, true as __always_true] | + | | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(*)] | + | | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 | + | | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(*)] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------------------------------------------------------+ " ); @@ -3498,7 +3494,7 @@ async fn sort_on_unprojected_columns() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----+ | a | +-----+ @@ -3507,7 +3503,7 @@ async fn sort_on_unprojected_columns() -> Result<()> { | 10 | | 1 | +-----+ - "### + " ); Ok(()) @@ -3545,7 +3541,7 @@ async fn sort_on_distinct_columns() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----+ | a | +-----+ @@ -3553,7 +3549,7 @@ async fn sort_on_distinct_columns() -> Result<()> { | 10 | | 1 | +-----+ - "### + " ); Ok(()) } @@ -3684,14 +3680,14 @@ async fn filter_with_alias_overwrite() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------+ | a | +------+ | true | | true | +------+ - "### + " ); Ok(()) @@ -3720,7 +3716,7 @@ async fn select_with_alias_overwrite() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-------+ | a | +-------+ @@ -3729,7 +3725,7 @@ async fn select_with_alias_overwrite() -> Result<()> { | true | | false | +-------+ - "### + " ); Ok(()) @@ -3755,7 +3751,7 @@ async fn test_grouping_sets() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----------+-----+---------------+ | a | b | count(test.a) | +-----------+-----+---------------+ @@ -3771,7 +3767,7 @@ async fn test_grouping_sets() -> Result<()> { | 123AbcDef | | 1 | | 123AbcDef | 100 | 1 | +-----------+-----+---------------+ - "### + " ); Ok(()) @@ -3798,7 +3794,7 @@ async fn test_grouping_sets_count() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +----+----+-----------------+ | c1 | c2 | count(Int32(1)) | +----+----+-----------------+ @@ -3813,7 +3809,7 @@ async fn test_grouping_sets_count() -> Result<()> { | b | | 19 | | a | | 21 | +----+----+-----------------+ - "### + " ); Ok(()) @@ -3847,7 +3843,7 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +----+----+--------+---------------------+ | c1 | c2 | sum_c3 | avg_c3 | 
+----+----+--------+---------------------+ @@ -3887,7 +3883,7 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { | a | 2 | -46 | -15.333333333333334 | | a | 1 | -88 | -17.6 | +----+----+--------+---------------------+ - "### + " ); Ok(()) @@ -3924,25 +3920,25 @@ async fn join_with_alias_filter() -> Result<()> { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" Projection: t1.a, t2.a, t1.b, t1.c, t2.b, t2.c [a:UInt32, a:UInt32, b:Utf8, c:Int32, b:Utf8, c:Int32] Inner Join: t1.a + UInt32(3) = t2.a + UInt32(1) [a:UInt32, b:Utf8, c:Int32, a:UInt32, b:Utf8, c:Int32] TableScan: t1 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32] TableScan: t2 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32] - "### + " ); let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----+----+---+----+---+---+ | a | a | b | c | b | c | +----+----+---+----+---+---+ | 1 | 3 | a | 10 | a | 1 | | 11 | 13 | c | 30 | c | 3 | +----+----+---+----+---+---+ - "### + " ); Ok(()) @@ -3969,27 +3965,27 @@ async fn right_semi_with_alias_filter() -> Result<()> { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" RightSemi Join: t1.a = t2.a [a:UInt32, b:Utf8, c:Int32] Projection: t1.a [a:UInt32] Filter: t1.c > Int32(1) [a:UInt32, c:Int32] TableScan: t1 projection=[a, c] [a:UInt32, c:Int32] Filter: t2.c > Int32(1) [a:UInt32, b:Utf8, c:Int32] TableScan: t2 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32] - "### + " ); let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +-----+---+---+ | a | b | c | +-----+---+---+ | 10 | b | 2 | | 100 | d | 4 | +-----+---+---+ - "### + " ); Ok(()) @@ -4016,26 +4012,26 @@ async fn right_anti_filter_push_down() -> Result<()> { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" RightAnti Join: t1.a = t2.a Filter: t2.c > Int32(1) [a:UInt32, b:Utf8, c:Int32] Projection: t1.a [a:UInt32] Filter: t1.c > Int32(1) [a:UInt32, c:Int32] TableScan: t1 projection=[a, c] [a:UInt32, c:Int32] TableScan: t2 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32] - "### + " ); let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----+---+---+ | a | b | c | +----+---+---+ | 13 | c | 3 | | 3 | a | 1 | +----+---+---+ - "### + " ); Ok(()) @@ -4048,37 +4044,37 @@ async fn unnest_columns() -> Result<()> { let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" - +----------+---------------------------------+--------------------------+ - | shape_id | points | tags | - +----------+---------------------------------+--------------------------+ - | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | [tag1] | - | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | [tag1] | - | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | [tag1, tag2, tag3, tag4] | - | 4 | | [tag1, tag2, tag3] | - +----------+---------------------------------+--------------------------+ - "###); + @r" + +----------+---------------------------------+--------------------------+ + | shape_id | points | tags | + +----------+---------------------------------+--------------------------+ + | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | [tag1] | + | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | [tag1] | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | [tag1, tag2, tag3, tag4] | + | 4 | | [tag1, tag2, tag3] | + +----------+---------------------------------+--------------------------+ + "); // Unnest tags let df = table_with_nested_types(NUM_ROWS).await?; let results = 
df.unnest_columns(&["tags"])?.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" - +----------+---------------------------------+------+ - | shape_id | points | tags | - +----------+---------------------------------+------+ - | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | tag1 | - | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | tag1 | - | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag1 | - | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag2 | - | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag3 | - | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag4 | - | 4 | | tag1 | - | 4 | | tag2 | - | 4 | | tag3 | - +----------+---------------------------------+------+ - "###); + @r" + +----------+---------------------------------+------+ + | shape_id | points | tags | + +----------+---------------------------------+------+ + | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | tag1 | + | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | tag1 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag1 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag2 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag3 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag4 | + | 4 | | tag1 | + | 4 | | tag2 | + | 4 | | tag3 | + +----------+---------------------------------+------+ + "); // Test aggregate results for tags. let df = table_with_nested_types(NUM_ROWS).await?; @@ -4090,19 +4086,19 @@ async fn unnest_columns() -> Result<()> { let results = df.unnest_columns(&["points"])?.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" - +----------+----------------+--------------------------+ - | shape_id | points | tags | - +----------+----------------+--------------------------+ - | 1 | {x: -3, y: -4} | [tag1] | - | 1 | {x: 5, y: -8} | [tag1] | - | 2 | {x: -2, y: -8} | [tag1] | - | 2 | {x: 6, y: 2} | [tag1] | - | 3 | {x: -2, y: 5} | [tag1, tag2, tag3, tag4] | - | 3 | {x: -9, y: -7} | [tag1, tag2, tag3, tag4] | - | 4 | | [tag1, tag2, tag3] | - +----------+----------------+--------------------------+ - "###); + @r" + +----------+----------------+--------------------------+ + | shape_id | points | tags | + +----------+----------------+--------------------------+ + | 1 | {x: -3, y: -4} | [tag1] | + | 1 | {x: 5, y: -8} | [tag1] | + | 2 | {x: -2, y: -8} | [tag1] | + | 2 | {x: 6, y: 2} | [tag1] | + | 3 | {x: -2, y: 5} | [tag1, tag2, tag3, tag4] | + | 3 | {x: -9, y: -7} | [tag1, tag2, tag3, tag4] | + | 4 | | [tag1, tag2, tag3] | + +----------+----------------+--------------------------+ + "); // Test aggregate results for points. 
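    // Editor's note, not part of this diff: in the snapshots above,
    // `unnest_columns(&["points"])` expands each element of the `points` list into its
    // own row while repeating the remaining columns, which is why shapes 1-3 contribute
    // two rows each and shape 4 (a NULL list) is kept as a single row with a NULL
    // `points` value. The `unnest_column_nulls` test further down shows the alternative:
    // `UnnestOptions::new().with_preserve_nulls(false)` drops NULL-list rows entirely
    // instead of keeping them.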
let df = table_with_nested_types(NUM_ROWS).await?; @@ -4118,27 +4114,27 @@ async fn unnest_columns() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" - +----------+----------------+------+ - | shape_id | points | tags | - +----------+----------------+------+ - | 1 | {x: -3, y: -4} | tag1 | - | 1 | {x: 5, y: -8} | tag1 | - | 2 | {x: -2, y: -8} | tag1 | - | 2 | {x: 6, y: 2} | tag1 | - | 3 | {x: -2, y: 5} | tag1 | - | 3 | {x: -2, y: 5} | tag2 | - | 3 | {x: -2, y: 5} | tag3 | - | 3 | {x: -2, y: 5} | tag4 | - | 3 | {x: -9, y: -7} | tag1 | - | 3 | {x: -9, y: -7} | tag2 | - | 3 | {x: -9, y: -7} | tag3 | - | 3 | {x: -9, y: -7} | tag4 | - | 4 | | tag1 | - | 4 | | tag2 | - | 4 | | tag3 | - +----------+----------------+------+ - "###); + @r" + +----------+----------------+------+ + | shape_id | points | tags | + +----------+----------------+------+ + | 1 | {x: -3, y: -4} | tag1 | + | 1 | {x: 5, y: -8} | tag1 | + | 2 | {x: -2, y: -8} | tag1 | + | 2 | {x: 6, y: 2} | tag1 | + | 3 | {x: -2, y: 5} | tag1 | + | 3 | {x: -2, y: 5} | tag2 | + | 3 | {x: -2, y: 5} | tag3 | + | 3 | {x: -2, y: 5} | tag4 | + | 3 | {x: -9, y: -7} | tag1 | + | 3 | {x: -9, y: -7} | tag2 | + | 3 | {x: -9, y: -7} | tag3 | + | 3 | {x: -9, y: -7} | tag4 | + | 4 | | tag1 | + | 4 | | tag2 | + | 4 | | tag3 | + +----------+----------------+------+ + "); // Test aggregate results for points and tags. let df = table_with_nested_types(NUM_ROWS).await?; @@ -4178,7 +4174,7 @@ async fn unnest_dict_encoded_columns() -> Result<()> { let results = df.collect().await.unwrap(); assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----------------+---------+ | make_array_expr | column1 | +-----------------+---------+ @@ -4186,7 +4182,7 @@ async fn unnest_dict_encoded_columns() -> Result<()> { | y | y | | z | z | +-----------------+---------+ - "### + " ); // make_array(dict_encoded_string,literal string) @@ -4206,7 +4202,7 @@ async fn unnest_dict_encoded_columns() -> Result<()> { let results = df.collect().await.unwrap(); assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----------------+---------+ | make_array_expr | column1 | +-----------------+---------+ @@ -4217,7 +4213,7 @@ async fn unnest_dict_encoded_columns() -> Result<()> { | z | z | | fixed_string | z | +-----------------+---------+ - "### + " ); Ok(()) } @@ -4228,7 +4224,7 @@ async fn unnest_column_nulls() -> Result<()> { let results = df.clone().collect().await?; assert_snapshot!( batches_to_string(&results), - @r###" + @r" +--------+----+ | list | id | +--------+----+ @@ -4237,7 +4233,7 @@ async fn unnest_column_nulls() -> Result<()> { | [] | C | | [3] | D | +--------+----+ - "### + " ); // Unnest, preserving nulls (row with B is preserved) @@ -4250,7 +4246,7 @@ async fn unnest_column_nulls() -> Result<()> { .await?; assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------+----+ | list | id | +------+----+ @@ -4259,7 +4255,7 @@ async fn unnest_column_nulls() -> Result<()> { | | B | | 3 | D | +------+----+ - "### + " ); let options = UnnestOptions::new().with_preserve_nulls(false); @@ -4269,7 +4265,7 @@ async fn unnest_column_nulls() -> Result<()> { .await?; assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------+----+ | list | id | +------+----+ @@ -4277,7 +4273,7 @@ async fn unnest_column_nulls() -> Result<()> { | 2 | A | | 3 | D | +------+----+ - "### + " ); Ok(()) @@ -4294,7 +4290,7 @@ async fn unnest_fixed_list() -> Result<()> { let results = df.clone().collect().await?; assert_snapshot!( 
batches_to_sort_string(&results), - @r###" + @r" +----------+----------------+ | shape_id | tags | +----------+----------------+ @@ -4305,7 +4301,7 @@ async fn unnest_fixed_list() -> Result<()> { | 5 | [tag51, tag52] | | 6 | [tag61, tag62] | +----------+----------------+ - "### + " ); let options = UnnestOptions::new().with_preserve_nulls(true); @@ -4316,7 +4312,7 @@ async fn unnest_fixed_list() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+-------+ | shape_id | tags | +----------+-------+ @@ -4331,7 +4327,7 @@ async fn unnest_fixed_list() -> Result<()> { | 6 | tag61 | | 6 | tag62 | +----------+-------+ - "### + " ); Ok(()) @@ -4348,7 +4344,7 @@ async fn unnest_fixed_list_drop_nulls() -> Result<()> { let results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+----------------+ | shape_id | tags | +----------+----------------+ @@ -4359,7 +4355,7 @@ async fn unnest_fixed_list_drop_nulls() -> Result<()> { | 5 | [tag51, tag52] | | 6 | [tag61, tag62] | +----------+----------------+ - "### + " ); let options = UnnestOptions::new().with_preserve_nulls(false); @@ -4370,7 +4366,7 @@ async fn unnest_fixed_list_drop_nulls() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+-------+ | shape_id | tags | +----------+-------+ @@ -4383,7 +4379,7 @@ async fn unnest_fixed_list_drop_nulls() -> Result<()> { | 6 | tag61 | | 6 | tag62 | +----------+-------+ - "### + " ); Ok(()) @@ -4419,7 +4415,7 @@ async fn unnest_fixed_list_non_null() -> Result<()> { let results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+----------------+ | shape_id | tags | +----------+----------------+ @@ -4430,7 +4426,7 @@ async fn unnest_fixed_list_non_null() -> Result<()> { | 5 | [tag51, tag52] | | 6 | [tag61, tag62] | +----------+----------------+ - "### + " ); let options = UnnestOptions::new().with_preserve_nulls(true); @@ -4440,7 +4436,7 @@ async fn unnest_fixed_list_non_null() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+-------+ | shape_id | tags | +----------+-------+ @@ -4457,7 +4453,7 @@ async fn unnest_fixed_list_non_null() -> Result<()> { | 6 | tag61 | | 6 | tag62 | +----------+-------+ - "### + " ); Ok(()) @@ -4471,17 +4467,17 @@ async fn unnest_aggregate_columns() -> Result<()> { let results = df.select_columns(&["tags"])?.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" - +--------------------------+ - | tags | - +--------------------------+ - | [tag1, tag2, tag3, tag4] | - | [tag1, tag2, tag3] | - | [tag1, tag2] | - | [tag1] | - | [tag1] | - +--------------------------+ - "### + @r" + +--------------------------+ + | tags | + +--------------------------+ + | [tag1, tag2, tag3, tag4] | + | [tag1, tag2, tag3] | + | [tag1, tag2] | + | [tag1] | + | [tag1] | + +--------------------------+ + " ); let df = table_with_nested_types(NUM_ROWS).await?; @@ -4492,13 +4488,13 @@ async fn unnest_aggregate_columns() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +-------------+ | count(tags) | +-------------+ | 11 | +-------------+ - "### + " ); Ok(()) @@ -4571,7 +4567,7 @@ async fn unnest_array_agg() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+--------+ | shape_id | tag_id | +----------+--------+ @@ -4585,7 +4581,7 @@ 
async fn unnest_array_agg() -> Result<()> { | 3 | 32 | | 3 | 33 | +----------+--------+ - "### + " ); // Doing an `array_agg` by `shape_id` produces: @@ -4599,7 +4595,7 @@ async fn unnest_array_agg() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+--------------+ | shape_id | tag_id | +----------+--------------+ @@ -4607,7 +4603,7 @@ async fn unnest_array_agg() -> Result<()> { | 2 | [21, 22, 23] | | 3 | [31, 32, 33] | +----------+--------------+ - "### + " ); // Unnesting again should produce the original batch. @@ -4623,7 +4619,7 @@ async fn unnest_array_agg() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+--------+ | shape_id | tag_id | +----------+--------+ @@ -4637,7 +4633,7 @@ async fn unnest_array_agg() -> Result<()> { | 3 | 32 | | 3 | 33 | +----------+--------+ - "### + " ); Ok(()) @@ -4667,7 +4663,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { let results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+--------+ | shape_id | tag_id | +----------+--------+ @@ -4681,7 +4677,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { | 3 | 32 | | 3 | 33 | +----------+--------+ - "### + " ); // Doing an `array_agg` by `shape_id` produces: @@ -4711,7 +4707,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+ | shape_id | +----------+ @@ -4725,7 +4721,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { | 3 | | 3 | +----------+ - "### + " ); Ok(()) @@ -4766,7 +4762,7 @@ async fn unnest_multiple_columns() -> Result<()> { // string: a, b, c, d assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------+------------+------------+--------+ | list | large_list | fixed_list | string | +------+------------+------------+--------+ @@ -4780,7 +4776,7 @@ async fn unnest_multiple_columns() -> Result<()> { | | | 4 | c | | | | | d | +------+------------+------------+--------+ - "### + " ); // Test with `preserve_nulls = false`` @@ -4797,7 +4793,7 @@ async fn unnest_multiple_columns() -> Result<()> { // string: a, b, c, d assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------+------------+------------+--------+ | list | large_list | fixed_list | string | +------+------------+------------+--------+ @@ -4810,7 +4806,7 @@ async fn unnest_multiple_columns() -> Result<()> { | | | 3 | c | | | | 4 | c | +------+------------+------------+--------+ - "### + " ); Ok(()) @@ -4839,7 +4835,7 @@ async fn unnest_non_nullable_list() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +----+ | c1 | +----+ @@ -4847,7 +4843,7 @@ async fn unnest_non_nullable_list() -> Result<()> { | 2 | | | +----+ - "### + " ); Ok(()) @@ -4892,7 +4888,7 @@ async fn test_read_batches() -> Result<()> { let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----+--------+ | id | number | +----+--------+ @@ -4905,7 +4901,7 @@ async fn test_read_batches() -> Result<()> { | 5 | 3.33 | | 5 | 6.66 | +----+--------+ - "### + " ); Ok(()) } @@ -4926,10 +4922,10 @@ async fn test_read_batches_empty() -> Result<()> { let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" ++ ++ - "### + " ); Ok(()) } @@ -4978,14 +4974,14 @@ async fn consecutive_projection_same_schema() -> Result<()> { let 
results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----+----+----+ | id | t | t2 | +----+----+----+ | 0 | | | | 1 | 10 | 10 | +----+----+----+ - "### + " ); Ok(()) @@ -5299,13 +5295,13 @@ async fn test_array_agg() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-------------------------------------+ | array_agg(test.a) | +-------------------------------------+ | [abcDEF, abc123, CBAdef, 123AbcDef] | +-------------------------------------+ - "### + " ); Ok(()) @@ -5373,10 +5369,10 @@ async fn test_dataframe_placeholder_missing_param_values() -> Result<()> { // N.B., the test is basically `SELECT 1 as a WHERE a = 3;` which returns no results. assert_snapshot!( batches_to_string(&df.collect().await.unwrap()), - @r###" + @r" ++ ++ - "### + " ); Ok(()) @@ -5425,20 +5421,20 @@ async fn test_dataframe_placeholder_column_parameter() -> Result<()> { assert_snapshot!( actual, @r" - Projection: Int32(3) AS $1 [$1:Null;N] + Projection: Int32(3) AS $1 [$1:Int32] EmptyRelation: rows=1 [] " ); assert_snapshot!( batches_to_string(&df.collect().await.unwrap()), - @r###" + @r" +----+ | $1 | +----+ | 3 | +----+ - "### + " ); Ok(()) @@ -5505,13 +5501,13 @@ async fn test_dataframe_placeholder_like_expression() -> Result<()> { assert_snapshot!( batches_to_string(&df.collect().await.unwrap()), - @r###" + @r" +-----+ | a | +-----+ | foo | +-----+ - "### + " ); Ok(()) @@ -5569,13 +5565,13 @@ async fn write_partitioned_parquet_results() -> Result<()> { let results = filter_df.collect().await?; assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----+ | c1 | +-----+ | abc | +-----+ - "### + " ); // Read the entire set of parquet files @@ -5591,14 +5587,14 @@ async fn write_partitioned_parquet_results() -> Result<()> { let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +-----+-----+ | c1 | c2 | +-----+-----+ | abc | 123 | | def | 456 | +-----+-----+ - "### + " ); Ok(()) @@ -5755,7 +5751,7 @@ async fn sparse_union_is_null() { // view_all assert_snapshot!( batches_to_sort_string(&df.clone().collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5766,14 +5762,14 @@ async fn sparse_union_is_null() { | {C=a} | | {C=} | +----------+ - "### + " ); // filter where is null let result_df = df.clone().filter(col("my_union").is_null()).unwrap(); assert_snapshot!( batches_to_sort_string(&result_df.collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5781,14 +5777,14 @@ async fn sparse_union_is_null() { | {B=} | | {C=} | +----------+ - "### + " ); // filter where is not null let result_df = df.filter(col("my_union").is_not_null()).unwrap(); assert_snapshot!( batches_to_sort_string(&result_df.collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5796,7 +5792,7 @@ async fn sparse_union_is_null() { | {B=3.2} | | {C=a} | +----------+ - "### + " ); } @@ -5838,7 +5834,7 @@ async fn dense_union_is_null() { // view_all assert_snapshot!( batches_to_sort_string(&df.clone().collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5849,14 +5845,14 @@ async fn dense_union_is_null() { | {C=a} | | {C=} | +----------+ - "### + " ); // filter where is null let result_df = df.clone().filter(col("my_union").is_null()).unwrap(); assert_snapshot!( batches_to_sort_string(&result_df.collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5864,14 +5860,14 @@ async fn 
dense_union_is_null() { | {B=} | | {C=} | +----------+ - "### + " ); // filter where is not null let result_df = df.filter(col("my_union").is_not_null()).unwrap(); assert_snapshot!( batches_to_sort_string(&result_df.collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5879,7 +5875,7 @@ async fn dense_union_is_null() { | {B=3.2} | | {C=a} | +----------+ - "### + " ); } @@ -5911,7 +5907,7 @@ async fn boolean_dictionary_as_filter() { // view_all assert_snapshot!( batches_to_string(&df.clone().collect().await.unwrap()), - @r###" + @r" +---------+ | my_dict | +---------+ @@ -5923,14 +5919,14 @@ async fn boolean_dictionary_as_filter() { | true | | false | +---------+ - "### + " ); let result_df = df.clone().filter(col("my_dict")).unwrap(); assert_snapshot!( batches_to_string(&result_df.collect().await.unwrap()), - @r###" + @r" +---------+ | my_dict | +---------+ @@ -5938,7 +5934,7 @@ async fn boolean_dictionary_as_filter() { | true | | true | +---------+ - "### + " ); // test nested dictionary @@ -5969,26 +5965,26 @@ async fn boolean_dictionary_as_filter() { // view_all assert_snapshot!( batches_to_string(&df.clone().collect().await.unwrap()), - @r###" + @r" +----------------+ | my_nested_dict | +----------------+ | true | | false | +----------------+ - "### + " ); let result_df = df.clone().filter(col("my_nested_dict")).unwrap(); assert_snapshot!( batches_to_string(&result_df.collect().await.unwrap()), - @r###" + @r" +----------------+ | my_nested_dict | +----------------+ | true | +----------------+ - "### + " ); } @@ -6066,11 +6062,11 @@ async fn test_alias() -> Result<()> { .into_unoptimized_plan() .display_indent_schema() .to_string(); - assert_snapshot!(plan, @r###" + assert_snapshot!(plan, @r" SubqueryAlias: table_alias [a:Utf8, b:Int32, one:Int32] Projection: test.a, test.b, Int32(1) AS one [a:Utf8, b:Int32, one:Int32] TableScan: test [a:Utf8, b:Int32] - "###); + "); // Select over the aliased DataFrame let df = df.select(vec![ @@ -6079,7 +6075,7 @@ async fn test_alias() -> Result<()> { ])?; assert_snapshot!( batches_to_sort_string(&df.collect().await.unwrap()), - @r###" + @r" +-----------+---------------------------------+ | a | table_alias.b + table_alias.one | +-----------+---------------------------------+ @@ -6088,7 +6084,7 @@ async fn test_alias() -> Result<()> { | abc123 | 11 | | abcDEF | 2 | +-----------+---------------------------------+ - "### + " ); Ok(()) } @@ -6118,7 +6114,7 @@ async fn test_alias_self_join() -> Result<()> { let joined = left.join(right, JoinType::Full, &["a"], &["a"], None)?; assert_snapshot!( batches_to_sort_string(&joined.collect().await.unwrap()), - @r###" + @r" +-----------+-----+-----------+-----+ | a | b | a | b | +-----------+-----+-----------+-----+ @@ -6127,7 +6123,7 @@ async fn test_alias_self_join() -> Result<()> { | abc123 | 10 | abc123 | 10 | | abcDEF | 1 | abcDEF | 1 | +-----------+-----+-----------+-----+ - "### + " ); Ok(()) } @@ -6140,14 +6136,14 @@ async fn test_alias_empty() -> Result<()> { .into_unoptimized_plan() .display_indent_schema() .to_string(); - assert_snapshot!(plan, @r###" + assert_snapshot!(plan, @r" SubqueryAlias: [a:Utf8, b:Int32] TableScan: test [a:Utf8, b:Int32] - "###); + "); assert_snapshot!( batches_to_sort_string(&df.select(vec![col("a"), col("b")])?.collect().await.unwrap()), - @r###" + @r" +-----------+-----+ | a | b | +-----------+-----+ @@ -6156,7 +6152,7 @@ async fn test_alias_empty() -> Result<()> { | abc123 | 10 | | abcDEF | 1 | +-----------+-----+ - "### + " ); Ok(()) @@ 
-6175,12 +6171,12 @@ async fn test_alias_nested() -> Result<()> { .into_optimized_plan()? .display_indent_schema() .to_string(); - assert_snapshot!(plan, @r###" + assert_snapshot!(plan, @r" SubqueryAlias: alias2 [a:Utf8, b:Int32, one:Int32] SubqueryAlias: alias1 [a:Utf8, b:Int32, one:Int32] Projection: test.a, test.b, Int32(1) AS one [a:Utf8, b:Int32, one:Int32] TableScan: test projection=[a, b] [a:Utf8, b:Int32] - "###); + "); // Select over the aliased DataFrame let select1 = df @@ -6189,7 +6185,7 @@ async fn test_alias_nested() -> Result<()> { assert_snapshot!( batches_to_sort_string(&select1.collect().await.unwrap()), - @r###" + @r" +-----------+-----------------------+ | a | alias2.b + alias2.one | +-----------+-----------------------+ @@ -6198,7 +6194,7 @@ async fn test_alias_nested() -> Result<()> { | abc123 | 11 | | abcDEF | 2 | +-----------+-----------------------+ - "### + " ); // Only the outermost alias is visible @@ -6318,7 +6314,10 @@ async fn test_insert_into_checking() -> Result<()> { .await .unwrap_err(); - assert_contains!(e.to_string(), "Inserting query schema mismatch: Expected table field 'a' with type Int64, but got 'column1' with type Utf8"); + assert_contains!( + e.to_string(), + "Inserting query schema mismatch: Expected table field 'a' with type Int64, but got 'column1' with type Utf8" + ); Ok(()) } @@ -6365,7 +6364,7 @@ async fn test_fill_null() -> Result<()> { let results = df_filled.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +---+---------+ | a | b | +---+---------+ @@ -6373,7 +6372,7 @@ async fn test_fill_null() -> Result<()> { | 1 | x | | 3 | z | +---+---------+ - "### + " ); Ok(()) @@ -6393,7 +6392,7 @@ async fn test_fill_null_all_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +---+---------+ | a | b | +---+---------+ @@ -6401,7 +6400,7 @@ async fn test_fill_null_all_columns() -> Result<()> { | 1 | x | | 3 | z | +---+---------+ - "### + " ); // Fill column "a" null values with a value that cannot be cast to Int32. @@ -6410,7 +6409,7 @@ async fn test_fill_null_all_columns() -> Result<()> { let results = df_filled.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +---+---------+ | a | b | +---+---------+ @@ -6418,7 +6417,7 @@ async fn test_fill_null_all_columns() -> Result<()> { | 1 | x | | 3 | z | +---+---------+ - "### + " ); Ok(()) } @@ -6450,7 +6449,10 @@ async fn test_insert_into_casting_support() -> Result<()> { .await .unwrap_err(); - assert_contains!(e.to_string(), "Inserting query schema mismatch: Expected table field 'a' with type Float16, but got 'a' with type Utf8."); + assert_contains!( + e.to_string(), + "Inserting query schema mismatch: Expected table field 'a' with type Float16, but got 'a' with type Utf8." + ); // Testing case2: // Inserting query schema mismatch: Expected table field 'a' with type Utf8View, but got 'a' with type Utf8. 
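// A minimal sketch, not part of the patch: most hunks above only swap the inline-snapshot
// delimiter from `@r###"…"###` to `@r"…"`. Both are ordinary Rust raw string literals handed
// to insta's `assert_snapshot!`; the hash-less form is sufficient whenever the snapshot text
// contains no `"` character, which holds for these ASCII table renderings. The module and
// test names below are illustrative only.
#[cfg(test)]
mod inline_snapshot_delimiter_sketch {
    #[test]
    fn hashless_and_hashed_delimiters_hold_the_same_snapshot() {
        // Identical snapshot text, two delimiter styles; both assertions pass.
        insta::assert_snapshot!("hello world", @r"hello world");
        insta::assert_snapshot!("hello world", @r###"hello world"###);
    }
}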
@@ -6488,14 +6490,14 @@ async fn test_insert_into_casting_support() -> Result<()> { assert_snapshot!( batches_to_string(&res), - @r###" + @r" +------+ | a | +------+ | a123 | | b456 | +------+ - "### + " ); Ok(()) } @@ -6631,13 +6633,13 @@ async fn test_copy_to_preserves_order() -> Result<()> { // Expect that input to the DataSinkExec is sorted correctly assert_snapshot!( physical_plan_format, - @r###" + @r" UnionExec DataSinkExec: sink=CsvSink(file_groups=[]) SortExec: expr=[column1@0 DESC], preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[1] DataSourceExec: partitions=1, partition_sizes=[1] - "### + " ); Ok(()) } diff --git a/datafusion/core/tests/datasource/object_store_access.rs b/datafusion/core/tests/datasource/object_store_access.rs index f89ca9e049147..2e1b1484076d9 100644 --- a/datafusion/core/tests/datasource/object_store_access.rs +++ b/datafusion/core/tests/datasource/object_store_access.rs @@ -98,6 +98,59 @@ async fn create_multi_file_csv_file() { ); } +#[tokio::test] +async fn multi_query_multi_file_csv_file() { + let test = Test::new().with_multi_file_csv().await; + assert_snapshot!( + test.query("select * from csv_table").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.0 | 0.0 | true | + | 0.00003 | 5e-12 | false | + | 0.00001 | 1e-12 | true | + | 0.00003 | 5e-12 | false | + | 0.00002 | 2e-12 | true | + | 0.00003 | 5e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 4 + - LIST prefix=data + - GET (opts) path=data/file_0.csv + - GET (opts) path=data/file_1.csv + - GET (opts) path=data/file_2.csv + " + ); + + // the second query should re-use the cached LIST results and should not reissue LIST + assert_snapshot!( + test.query("select * from csv_table").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.0 | 0.0 | true | + | 0.00003 | 5e-12 | false | + | 0.00001 | 1e-12 | true | + | 0.00003 | 5e-12 | false | + | 0.00002 | 2e-12 | true | + | 0.00003 | 5e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 4 + - LIST prefix=data + - GET (opts) path=data/file_0.csv + - GET (opts) path=data/file_1.csv + - GET (opts) path=data/file_2.csv + " + ); +} + #[tokio::test] async fn query_multi_csv_file() { let test = Test::new().with_multi_file_csv().await; @@ -145,17 +198,8 @@ async fn query_partitioned_csv_file() { +---------+-------+-------+---+----+-----+ ------- Object Store Request Summary ------- RequestCountingObjectStore() - Total Requests: 13 - - LIST (with delimiter) prefix=data - - LIST (with delimiter) prefix=data/a=1 - - LIST (with delimiter) prefix=data/a=2 - - LIST (with delimiter) prefix=data/a=3 - - LIST (with delimiter) prefix=data/a=1/b=10 - - LIST (with delimiter) prefix=data/a=2/b=20 - - LIST (with delimiter) prefix=data/a=3/b=30 - - LIST (with delimiter) prefix=data/a=1/b=10/c=100 - - LIST (with delimiter) prefix=data/a=2/b=20/c=200 - - LIST (with delimiter) prefix=data/a=3/b=30/c=300 + Total Requests: 4 + - LIST prefix=data - GET (opts) path=data/a=1/b=10/c=100/file_1.csv - GET (opts) path=data/a=2/b=20/c=200/file_2.csv - GET (opts) path=data/a=3/b=30/c=300/file_3.csv @@ -174,10 +218,8 @@ async fn query_partitioned_csv_file() { +---------+-------+-------+---+----+-----+ ------- Object Store 
Request Summary ------- RequestCountingObjectStore() - Total Requests: 4 - - LIST (with delimiter) prefix=data/a=2 - - LIST (with delimiter) prefix=data/a=2/b=20 - - LIST (with delimiter) prefix=data/a=2/b=20/c=200 + Total Requests: 2 + - LIST prefix=data/a=2 - GET (opts) path=data/a=2/b=20/c=200/file_2.csv " ); @@ -194,17 +236,8 @@ async fn query_partitioned_csv_file() { +---------+-------+-------+---+----+-----+ ------- Object Store Request Summary ------- RequestCountingObjectStore() - Total Requests: 11 - - LIST (with delimiter) prefix=data - - LIST (with delimiter) prefix=data/a=1 - - LIST (with delimiter) prefix=data/a=2 - - LIST (with delimiter) prefix=data/a=3 - - LIST (with delimiter) prefix=data/a=1/b=10 - - LIST (with delimiter) prefix=data/a=2/b=20 - - LIST (with delimiter) prefix=data/a=3/b=30 - - LIST (with delimiter) prefix=data/a=1/b=10/c=100 - - LIST (with delimiter) prefix=data/a=2/b=20/c=200 - - LIST (with delimiter) prefix=data/a=3/b=30/c=300 + Total Requests: 2 + - LIST prefix=data - GET (opts) path=data/a=2/b=20/c=200/file_2.csv " ); @@ -221,17 +254,8 @@ async fn query_partitioned_csv_file() { +---------+-------+-------+---+----+-----+ ------- Object Store Request Summary ------- RequestCountingObjectStore() - Total Requests: 11 - - LIST (with delimiter) prefix=data - - LIST (with delimiter) prefix=data/a=1 - - LIST (with delimiter) prefix=data/a=2 - - LIST (with delimiter) prefix=data/a=3 - - LIST (with delimiter) prefix=data/a=1/b=10 - - LIST (with delimiter) prefix=data/a=2/b=20 - - LIST (with delimiter) prefix=data/a=3/b=30 - - LIST (with delimiter) prefix=data/a=1/b=10/c=100 - - LIST (with delimiter) prefix=data/a=2/b=20/c=200 - - LIST (with delimiter) prefix=data/a=3/b=30/c=300 + Total Requests: 2 + - LIST prefix=data - GET (opts) path=data/a=2/b=20/c=200/file_2.csv " ); @@ -248,9 +272,8 @@ async fn query_partitioned_csv_file() { +---------+-------+-------+---+----+-----+ ------- Object Store Request Summary ------- RequestCountingObjectStore() - Total Requests: 3 - - LIST (with delimiter) prefix=data/a=2/b=20 - - LIST (with delimiter) prefix=data/a=2/b=20/c=200 + Total Requests: 2 + - LIST prefix=data/a=2/b=20 - GET (opts) path=data/a=2/b=20/c=200/file_2.csv " ); @@ -267,17 +290,8 @@ async fn query_partitioned_csv_file() { +---------+-------+-------+---+----+-----+ ------- Object Store Request Summary ------- RequestCountingObjectStore() - Total Requests: 11 - - LIST (with delimiter) prefix=data - - LIST (with delimiter) prefix=data/a=1 - - LIST (with delimiter) prefix=data/a=2 - - LIST (with delimiter) prefix=data/a=3 - - LIST (with delimiter) prefix=data/a=1/b=10 - - LIST (with delimiter) prefix=data/a=2/b=20 - - LIST (with delimiter) prefix=data/a=3/b=30 - - LIST (with delimiter) prefix=data/a=1/b=10/c=100 - - LIST (with delimiter) prefix=data/a=2/b=20/c=200 - - LIST (with delimiter) prefix=data/a=3/b=30/c=300 + Total Requests: 2 + - LIST prefix=data - GET (opts) path=data/a=1/b=10/c=100/file_1.csv " ); diff --git a/datafusion/core/tests/execution/coop.rs b/datafusion/core/tests/execution/coop.rs index b6f406e967509..27dacf598c2c0 100644 --- a/datafusion/core/tests/execution/coop.rs +++ b/datafusion/core/tests/execution/coop.rs @@ -22,25 +22,25 @@ use datafusion::common::NullEquality; use datafusion::functions_aggregate::sum; use datafusion::physical_expr::aggregate::AggregateExprBuilder; use datafusion::physical_plan; +use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; 
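// A hedged sketch, not from this patch: the drop from 11-13 requests to 2-4 in the
// object_store_access.rs snapshots above appears to come down to the two listing modes of
// the `object_store` crate. `list` streams every object under a prefix in one recursive
// listing, while `list_with_delimiter` returns a single directory level, so a three-level
// Hive partition tree (`a=*/b=*/c=*`) costs roughly one delimited call per directory.
// The function below is illustrative only and uses the in-memory store for brevity.
use futures::TryStreamExt;
use object_store::{ObjectStore, memory::InMemory, path::Path};

async fn compare_listing_modes(store: &InMemory) -> object_store::Result<()> {
    // One recursive listing covers data/a=1/b=10/c=100/file_1.csv and all of its siblings.
    let recursive: Vec<_> = store.list(Some(&Path::from("data"))).try_collect().await?;

    // A delimited listing of the same prefix only yields the first-level "directories".
    let one_level = store.list_with_delimiter(Some(&Path::from("data"))).await?;

    println!(
        "recursive objects: {}, first-level prefixes: {}",
        recursive.len(),
        one_level.common_prefixes.len()
    );
    Ok(())
}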
use datafusion::physical_plan::execution_plan::Boundedness; -use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; -use datafusion_common::{exec_datafusion_err, DataFusionError, JoinType, ScalarValue}; +use datafusion_common::{DataFusionError, JoinType, ScalarValue, exec_datafusion_err}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr_common::operator::Operator; use datafusion_expr_common::operator::Operator::{Divide, Eq, Gt, Modulo}; use datafusion_functions_aggregate::min_max; +use datafusion_physical_expr::Partitioning; use datafusion_physical_expr::expressions::{ - binary, col, lit, BinaryExpr, Column, Literal, + BinaryExpr, Column, Literal, binary, col, lit, }; -use datafusion_physical_expr::Partitioning; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_optimizer::ensure_coop::EnsureCooperative; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::ensure_coop::EnsureCooperative; use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::coop::make_cooperative; use datafusion_physical_plan::filter::FilterExec; @@ -64,13 +64,14 @@ use std::time::Duration; use tokio::runtime::{Handle, Runtime}; use tokio::select; -#[derive(Debug)] +#[derive(Debug, Clone)] struct RangeBatchGenerator { schema: SchemaRef, value_range: Range, boundedness: Boundedness, batch_size: usize, poll_count: usize, + original_range: Range, } impl std::fmt::Display for RangeBatchGenerator { @@ -110,6 +111,13 @@ impl LazyBatchGenerator for RangeBatchGenerator { RecordBatch::try_new(Arc::clone(&self.schema), vec![Arc::new(array)])?; Ok(Some(batch)) } + + fn reset_state(&self) -> Arc> { + let mut new = self.clone(); + new.poll_count = 0; + new.value_range = new.original_range.clone(); + Arc::new(RwLock::new(new)) + } } fn make_lazy_exec(column_name: &str, pretend_infinite: bool) -> LazyMemoryExec { @@ -136,16 +144,17 @@ fn make_lazy_exec_with_range( }; // Instantiate the generator with the batch and limit - let gen = RangeBatchGenerator { + let batch_gen = RangeBatchGenerator { schema: Arc::clone(&schema), boundedness, - value_range: range, + value_range: range.clone(), batch_size: 8192, poll_count: 0, + original_range: range, }; // Wrap the generator in a trait object behind Arc> - let generator: Arc> = Arc::new(RwLock::new(gen)); + let generator: Arc> = Arc::new(RwLock::new(batch_gen)); // Create a LazyMemoryExec with one partition using our generator let mut exec = LazyMemoryExec::try_new(schema, vec![generator]).unwrap(); @@ -170,7 +179,7 @@ async fn agg_no_grouping_yields( let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); let aggr = Arc::new(AggregateExec::try_new( AggregateMode::Single, - PhysicalGroupBy::new(vec![], vec![], vec![]), + PhysicalGroupBy::new(vec![], vec![], vec![], false), vec![Arc::new( AggregateExprBuilder::new( sum::sum_udaf(), @@ -204,7 +213,7 @@ async fn agg_grouping_yields( let aggr = Arc::new(AggregateExec::try_new( AggregateMode::Single, - PhysicalGroupBy::new(vec![(group, "group".to_string())], vec![], vec![]), + PhysicalGroupBy::new(vec![(group, "group".to_string())], vec![], vec![], false), vec![Arc::new( AggregateExprBuilder::new(sum::sum_udaf(), vec![value_col.clone()]) .schema(inf.schema()) @@ -240,6 +249,7 @@ async fn agg_grouped_topk_yields( vec![(group, "group".to_string())], vec![], 
vec![vec![false]], + false, ), vec![Arc::new( AggregateExprBuilder::new(min_max::max_udaf(), vec![value_col.clone()]) @@ -545,6 +555,7 @@ async fn interleave_then_aggregate_yields( vec![], // no GROUP BY columns vec![], // no GROUP BY expressions vec![], // no GROUP BY physical expressions + false, ), vec![Arc::new(aggregate_expr)], vec![None], // no “distinct” flags @@ -653,7 +664,7 @@ async fn join_agg_yields( let proj_expr = vec![ProjectionExpr::new( Arc::new(Column::new_with_schema("value", &input_schema)?) as _, - "value".to_string(), + "value", )]; let projection = Arc::new(ProjectionExec::try_new(proj_expr, join)?); @@ -676,7 +687,7 @@ async fn join_agg_yields( let aggr = Arc::new(AggregateExec::try_new( AggregateMode::Single, - PhysicalGroupBy::new(vec![], vec![], vec![]), + PhysicalGroupBy::new(vec![], vec![], vec![], false), vec![Arc::new(aggregate_expr)], vec![None], projection, diff --git a/datafusion/core/tests/execution/datasource_split.rs b/datafusion/core/tests/execution/datasource_split.rs index 0b90c6f326168..370249cd8044e 100644 --- a/datafusion/core/tests/execution/datasource_split.rs +++ b/datafusion/core/tests/execution/datasource_split.rs @@ -22,7 +22,7 @@ use arrow::{ }; use datafusion_datasource::memory::MemorySourceConfig; use datafusion_execution::TaskContext; -use datafusion_physical_plan::{common::collect, ExecutionPlan}; +use datafusion_physical_plan::{ExecutionPlan, common::collect}; use std::sync::Arc; /// Helper function to create a memory source with the given batch size and collect all batches diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs index ef2e263f2c467..3eaa3fb2ed5e6 100644 --- a/datafusion/core/tests/execution/logical_plan.rs +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -20,7 +20,7 @@ use arrow::array::Int64Array; use arrow::datatypes::{DataType, Field, Schema}; -use datafusion::datasource::{provider_as_source, ViewTable}; +use datafusion::datasource::{ViewTable, provider_as_source}; use datafusion::execution::session_state::SessionStateBuilder; use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, ScalarValue, Spans}; use datafusion_execution::TaskContext; diff --git a/datafusion/core/tests/execution/mod.rs b/datafusion/core/tests/execution/mod.rs index 8770b2a201051..f33ef87aa3023 100644 --- a/datafusion/core/tests/execution/mod.rs +++ b/datafusion/core/tests/execution/mod.rs @@ -18,3 +18,4 @@ mod coop; mod datasource_split; mod logical_plan; +mod register_arrow; diff --git a/datafusion/core/tests/execution/register_arrow.rs b/datafusion/core/tests/execution/register_arrow.rs new file mode 100644 index 0000000000000..4ce16dc0906c1 --- /dev/null +++ b/datafusion/core/tests/execution/register_arrow.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration tests for register_arrow API + +use datafusion::{execution::options::ArrowReadOptions, prelude::*}; +use datafusion_common::Result; + +#[tokio::test] +async fn test_register_arrow_auto_detects_format() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_arrow( + "file_format", + "../../datafusion/datasource-arrow/tests/data/example.arrow", + ArrowReadOptions::default(), + ) + .await?; + + ctx.register_arrow( + "stream_format", + "../../datafusion/datasource-arrow/tests/data/example_stream.arrow", + ArrowReadOptions::default(), + ) + .await?; + + let file_result = ctx.sql("SELECT * FROM file_format ORDER BY f0").await?; + let stream_result = ctx.sql("SELECT * FROM stream_format ORDER BY f0").await?; + + let file_batches = file_result.collect().await?; + let stream_batches = stream_result.collect().await?; + + assert_eq!(file_batches.len(), stream_batches.len()); + assert_eq!(file_batches[0].schema(), stream_batches[0].schema()); + + let file_rows: usize = file_batches.iter().map(|b| b.num_rows()).sum(); + let stream_rows: usize = stream_batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(file_rows, stream_rows); + + Ok(()) +} + +#[tokio::test] +async fn test_register_arrow_join_file_and_stream() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_arrow( + "file_table", + "../../datafusion/datasource-arrow/tests/data/example.arrow", + ArrowReadOptions::default(), + ) + .await?; + + ctx.register_arrow( + "stream_table", + "../../datafusion/datasource-arrow/tests/data/example_stream.arrow", + ArrowReadOptions::default(), + ) + .await?; + + let result = ctx + .sql( + "SELECT a.f0, a.f1, b.f0, b.f1 + FROM file_table a + JOIN stream_table b ON a.f0 = b.f0 + WHERE a.f0 <= 2 + ORDER BY a.f0", + ) + .await?; + let batches = result.collect().await?; + + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 2); + + Ok(()) +} diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 84e644480a4fd..90c1b96749b3c 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -16,17 +16,17 @@ // under the License. use arrow::array::{ - builder::{ListBuilder, StringBuilder}, ArrayRef, Int64Array, RecordBatch, StringArray, StructArray, + builder::{ListBuilder, StringBuilder}, }; use arrow::datatypes::{DataType, Field}; use arrow::util::pretty::{pretty_format_batches, pretty_format_columns}; use datafusion::prelude::*; use datafusion_common::{DFSchema, ScalarValue}; +use datafusion_expr::ExprFunctionExt; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::NullTreatment; use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::ExprFunctionExt; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_functions_aggregate::first_last::first_value_udaf; use datafusion_functions_aggregate::sum::sum_udaf; @@ -36,6 +36,7 @@ use datafusion_optimizer::simplify_expressions::ExprSimplifier; use std::sync::{Arc, LazyLock}; mod parse_sql_expr; +#[expect(clippy::needless_pass_by_value)] mod simplification; #[test] @@ -384,6 +385,7 @@ async fn evaluate_agg_test(expr: Expr, expected_lines: Vec<&str>) { /// Converts the `Expr` to a `PhysicalExpr`, evaluates it against the provided /// `RecordBatch` and compares the result to the expected result. 
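// A hedged sketch, not from this patch: many hunks in these test files (coop.rs,
// datasource_split.rs, logical_plan.rs, expr_api/mod.rs, and the fuzz_cases files below)
// only re-sort `use` groups, e.g. `{ExecutionPlan, common::collect}` instead of
// `{common::collect, ExecutionPlan}` and `{Rng, rng}` instead of `{rng, Rng}`. This is
// consistent with rustfmt's 2024 style edition, which sorts items roughly byte-wise so
// uppercase type names come ahead of lowercase module paths; that motivation is an
// assumption on my part, not something stated in the patch.
//
// use rand::{rng, seq::SliceRandom, Rng};   // pre-2024 ordering, as in the removed lines
use rand::{Rng, rng, seq::SliceRandom}; // 2024-style ordering, as in the added lines

// Tiny usage so the sketch compiles against rand 0.9, the API style used in these tests.
fn shuffled_plus_noise(values: &mut [u32]) -> Option<u32> {
    let mut r = rng();
    values.shuffle(&mut r);
    values.first().map(|v| v.wrapping_add(r.random::<u32>() % 3))
}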
+#[expect(clippy::needless_pass_by_value)] fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) { let batch = &TEST_BATCH; let df_schema = DFSchema::try_from(batch.schema()).unwrap(); diff --git a/datafusion/core/tests/expr_api/parse_sql_expr.rs b/datafusion/core/tests/expr_api/parse_sql_expr.rs index 92c18204324f7..b0d8b3a349ae2 100644 --- a/datafusion/core/tests/expr_api/parse_sql_expr.rs +++ b/datafusion/core/tests/expr_api/parse_sql_expr.rs @@ -19,9 +19,9 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion::prelude::{CsvReadOptions, SessionContext}; use datafusion_common::DFSchema; use datafusion_common::{DFSchemaRef, Result, ToDFSchema}; +use datafusion_expr::Expr; use datafusion_expr::col; use datafusion_expr::lit; -use datafusion_expr::Expr; use datafusion_sql::unparser::Unparser; /// A schema like: /// diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 46c36c6abdacc..a42dfc951da0d 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -24,15 +24,15 @@ use arrow::array::{ArrayRef, Int32Array}; use arrow::datatypes::{DataType, Field, Schema}; use chrono::{DateTime, TimeZone, Utc}; use datafusion::{error::Result, execution::context::ExecutionProps, prelude::*}; -use datafusion_common::cast::as_int32_array; use datafusion_common::ScalarValue; +use datafusion_common::cast::as_int32_array; use datafusion_common::{DFSchemaRef, ToDFSchema}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::builder::table_scan_with_filters; use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::{ - table_scan, Cast, ColumnarValue, ExprSchemable, LogicalPlan, LogicalPlanBuilder, - ScalarUDF, Volatility, + Cast, ColumnarValue, ExprSchemable, LogicalPlan, LogicalPlanBuilder, ScalarUDF, + Volatility, table_scan, }; use datafusion_functions::math; use datafusion_optimizer::optimizer::Optimizer; @@ -243,10 +243,10 @@ fn to_timestamp_expr_folded() -> Result<()> { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r#" Projection: TimestampNanosecond(1599566400000000000, None) AS to_timestamp(Utf8("2020-09-08T12:00:00+00:00")) TableScan: test - "### + "# ); Ok(()) } @@ -273,10 +273,10 @@ fn now_less_than_timestamp() -> Result<()> { assert_snapshot!( actual, - @r###" + @r" Filter: Boolean(true) TableScan: test - "### + " ); Ok(()) } @@ -312,10 +312,10 @@ fn select_date_plus_interval() -> Result<()> { assert_snapshot!( actual, - @r###" + @r#" Projection: Date32("2021-01-09") AS to_timestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 0 }") TableScan: test - "### + "# ); Ok(()) } @@ -334,10 +334,10 @@ fn simplify_project_scalar_fn() -> Result<()> { let actual = formatter.trim(); assert_snapshot!( actual, - @r###" + @r" Projection: test.f AS power(test.f,Float64(1)) TableScan: test - "### + " ); Ok(()) } diff --git a/datafusion/core/tests/fifo/mod.rs b/datafusion/core/tests/fifo/mod.rs index 141a3f3b75586..36cc769417dbc 100644 --- a/datafusion/core/tests/fifo/mod.rs +++ b/datafusion/core/tests/fifo/mod.rs @@ -22,21 +22,21 @@ mod unix_test { use std::fs::File; use std::path::PathBuf; - use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use arrow::array::Array; use arrow::csv::ReaderBuilder; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use 
datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::datasource::TableProvider; + use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::{ prelude::{CsvReadOptions, SessionConfig, SessionContext}, test_util::{aggr_test_schema, arrow_test_data}, }; use datafusion_common::instant::Instant; - use datafusion_common::{exec_err, Result}; + use datafusion_common::{Result, exec_err}; use datafusion_expr::SortExpr; use futures::StreamExt; @@ -44,7 +44,7 @@ mod unix_test { use nix::unistd; use tempfile::TempDir; use tokio::io::AsyncWriteExt; - use tokio::task::{spawn_blocking, JoinHandle}; + use tokio::task::{JoinHandle, spawn_blocking}; /// Makes a TableProvider for a fifo file fn fifo_table( diff --git a/datafusion/core/tests/fuzz.rs b/datafusion/core/tests/fuzz.rs index 92646e8b37636..5e94f12b5805d 100644 --- a/datafusion/core/tests/fuzz.rs +++ b/datafusion/core/tests/fuzz.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. -/// Run all tests that are found in the `fuzz_cases` directory +/// Run all tests that are found in the `fuzz_cases` directory. +/// Fuzz tests are slow and gated behind the `extended_tests` feature. +/// Run with: cargo test --features extended_tests +#[cfg(feature = "extended_tests")] mod fuzz_cases; #[cfg(test)] diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 4e04da26f70b6..97d1db5728cf3 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -24,37 +24,37 @@ use crate::fuzz_cases::aggregation_fuzzer::{ }; use arrow::array::{ - types::Int64Type, Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch, - StringArray, + Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch, StringArray, + types::Int64Type, }; use arrow::compute::concat_batches; use arrow::datatypes::DataType; use arrow::util::pretty::pretty_format_batches; use arrow_schema::{Field, Schema, SchemaRef}; +use datafusion::datasource::MemTable; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; -use datafusion::datasource::MemTable; use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; use datafusion_common::{HashMap, Result}; use datafusion_common_runtime::JoinSet; use datafusion_functions_aggregate::sum::sum_udaf; -use datafusion_physical_expr::expressions::{col, lit, Column}; use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_expr::expressions::{Column, col, lit}; use datafusion_physical_plan::InputOrderMode; -use test_utils::{add_empty_batches, StringBatchGenerator}; +use test_utils::{StringBatchGenerator, add_empty_batches}; +use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::RuntimeEnvBuilder; -use datafusion_execution::TaskContext; use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; use datafusion_physical_plan::metrics::MetricValue; -use datafusion_physical_plan::{collect, displayable, ExecutionPlan}; +use datafusion_physical_plan::{ExecutionPlan, collect, displayable}; use rand::rngs::StdRng; -use rand::{random, rng, Rng, SeedableRng}; +use rand::{Rng, 
SeedableRng, random, rng}; // ======================================================================== // The new aggregation fuzz tests based on [`AggregationFuzzer`] @@ -326,15 +326,14 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str .unwrap(), ); - let aggregate_expr = - vec![ - AggregateExprBuilder::new(sum_udaf(), vec![col("d", &schema).unwrap()]) - .schema(Arc::clone(&schema)) - .alias("sum1") - .build() - .map(Arc::new) - .unwrap(), - ]; + let aggregate_expr = vec![ + AggregateExprBuilder::new(sum_udaf(), vec![col("d", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("sum1") + .build() + .map(Arc::new) + .unwrap(), + ]; let expr = group_by_columns .iter() .map(|elem| (col(elem, &schema).unwrap(), (*elem).to_string())) @@ -650,7 +649,9 @@ pub(crate) fn assert_spill_count_metric( if expect_spill && spill_count == 0 { panic!("Expected spill but SpillCount metric not found or SpillCount was 0."); } else if !expect_spill && spill_count > 0 { - panic!("Expected no spill but found SpillCount metric with value greater than 0."); + panic!( + "Expected no spill but found SpillCount metric with value greater than 0." + ); } spill_count diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs index fa8ea0b31c023..bf71053d6c852 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs @@ -25,7 +25,7 @@ use datafusion_catalog::TableProvider; use datafusion_common::ScalarValue; use datafusion_common::{error::Result, utils::get_available_parallelism}; use datafusion_expr::col; -use rand::{rng, Rng}; +use rand::{Rng, rng}; use crate::fuzz_cases::aggregation_fuzzer::data_generator::Dataset; diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs index aaf2d1b9bad4f..e49cffa89b04e 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs @@ -18,7 +18,7 @@ use arrow::array::RecordBatch; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; +use datafusion_physical_expr::{PhysicalSortExpr, expressions::col}; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::sorts::sort::sort_batch; use test_utils::stagger_batch; @@ -209,8 +209,8 @@ mod test { sort_keys_set: vec![vec!["b".to_string()]], }; - let mut gen = DatasetGenerator::new(config); - let datasets = gen.generate().unwrap(); + let mut data_gen = DatasetGenerator::new(config); + let datasets = data_gen.generate().unwrap(); // Should Generate 2 datasets assert_eq!(datasets.len(), 2); diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs index 1a8ef278cc299..430762b1c28db 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -19,9 +19,9 @@ use std::sync::Arc; use arrow::array::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use datafusion_common::{internal_datafusion_err, Result}; +use datafusion_common::{Result, internal_datafusion_err}; use datafusion_common_runtime::JoinSet; -use rand::{rng, Rng}; +use 
rand::{Rng, rng}; use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder; use crate::fuzz_cases::aggregation_fuzzer::{ diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs index 766e2bedd74c2..0d04e98536f2a 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs @@ -17,7 +17,7 @@ use std::{collections::HashSet, str::FromStr}; -use rand::{rng, seq::SliceRandom, Rng}; +use rand::{Rng, rng, seq::SliceRandom}; /// Random aggregate query builder /// diff --git a/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs index 3049631d4b3fe..92adda200d1a5 100644 --- a/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs @@ -19,7 +19,7 @@ use std::sync::Arc; -use arrow::array::{cast::AsArray, Array, OffsetSizeTrait, RecordBatch}; +use arrow::array::{Array, OffsetSizeTrait, RecordBatch, cast::AsArray}; use datafusion::datasource::MemTable; use datafusion_common_runtime::JoinSet; diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs index 171839b390ffa..a57095066ee12 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs @@ -16,19 +16,19 @@ // under the License. use crate::fuzz_cases::equivalence::utils::{ - create_random_schema, create_test_params, create_test_schema_2, + TestScalarUDF, create_random_schema, create_test_params, create_test_schema_2, generate_table_for_eq_properties, generate_table_for_orderings, - is_table_same_after_sort, TestScalarUDF, + is_table_same_after_sort, }; use arrow::compute::SortOptions; -use datafusion_common::config::ConfigOptions; use datafusion_common::Result; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{Operator, ScalarUDF}; +use datafusion_physical_expr::ScalarFunctionExpr; use datafusion_physical_expr::equivalence::{ convert_to_orderings, convert_to_sort_exprs, }; -use datafusion_physical_expr::expressions::{col, BinaryExpr}; -use datafusion_physical_expr::ScalarFunctionExpr; +use datafusion_physical_expr::expressions::{BinaryExpr, col}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use itertools::Itertools; diff --git a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs index a72a1558b2e41..2f67e211ce915 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs @@ -16,15 +16,15 @@ // under the License. 
use crate::fuzz_cases::equivalence::utils::{ - apply_projection, create_random_schema, generate_table_for_eq_properties, - is_table_same_after_sort, TestScalarUDF, + TestScalarUDF, apply_projection, create_random_schema, + generate_table_for_eq_properties, is_table_same_after_sort, }; use arrow::compute::SortOptions; -use datafusion_common::config::ConfigOptions; use datafusion_common::Result; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr::equivalence::ProjectionMapping; -use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr::expressions::{BinaryExpr, col}; use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; diff --git a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs index 382c4da943219..1490eb08a0291 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs @@ -18,13 +18,13 @@ use std::sync::Arc; use crate::fuzz_cases::equivalence::utils::{ - create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, - TestScalarUDF, + TestScalarUDF, create_random_schema, generate_table_for_eq_properties, + is_table_same_after_sort, }; use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; -use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr::expressions::{BinaryExpr, col}; use datafusion_physical_expr::{LexOrdering, ScalarFunctionExpr}; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs index be35ddca8f02d..580a226721083 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -20,21 +20,21 @@ use std::cmp::Ordering; use std::sync::Arc; use arrow::array::{ArrayRef, Float32Array, Float64Array, RecordBatch, UInt32Array}; -use arrow::compute::{lexsort_to_indices, take_record_batch, SortColumn, SortOptions}; +use arrow::compute::{SortColumn, SortOptions, lexsort_to_indices, take_record_batch}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::utils::{compare_rows, get_row_at_idx}; -use datafusion_common::{exec_err, internal_datafusion_err, plan_err, Result}; +use datafusion_common::{Result, exec_err, internal_datafusion_err, plan_err}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_physical_expr::equivalence::{ - convert_to_orderings, EquivalenceClass, ProjectionMapping, + EquivalenceClass, ProjectionMapping, convert_to_orderings, }; use datafusion_physical_expr::{ConstExpr, EquivalenceProperties}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_plan::expressions::{col, Column}; +use datafusion_physical_plan::expressions::{Column, col}; use itertools::izip; use rand::prelude::*; diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 
e8ff1ccf06704..ce422494db101 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -38,8 +38,8 @@ use datafusion::physical_plan::joins::{ }; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{NullEquality, ScalarValue}; -use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::PhysicalExprRef; +use datafusion_physical_expr::expressions::Literal; use itertools::Itertools; use rand::Rng; @@ -91,484 +91,564 @@ fn col_lt_col_filter(schema1: Arc, schema2: Arc) -> JoinFilter { #[tokio::test] async fn test_inner_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::Inner, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Inner, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_inner_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::Inner, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Inner, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::Left, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Left, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::Left, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Left, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::Right, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Right, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::Right, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, 
true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Right, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_full_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::Full, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Full, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_full_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::Full, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[NljHj, HjSmj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Full, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[NljHj, HjSmj], false) + .await + } } #[tokio::test] async fn test_left_semi_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::LeftSemi, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_semi_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::LeftSemi, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_semi_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::RightSemi, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_semi_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::RightSemi, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } 
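// A minimal sketch, not from the patch: the join fuzz tests above now wrap each case in a
// loop over three (left_extra, right_extra) column layouts rather than building a single
// symmetric pair of inputs. The helper below is a hypothetical stand-in showing only that
// driver shape; `build_and_check` plays the role of constructing a JoinFuzzTestCase and
// awaiting run_test on it.
async fn exercise_column_layouts<F, Fut>(mut build_and_check: F)
where
    F: FnMut(bool, bool) -> Fut,
    Fut: std::future::Future<Output = ()>,
{
    for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] {
        build_and_check(left_extra, right_extra).await;
    }
}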
#[tokio::test] async fn test_left_anti_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::LeftAnti, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_anti_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::LeftAnti, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_anti_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::RightAnti, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_anti_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::RightAnti, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_mark_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::LeftMark, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftMark, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_mark_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::LeftMark, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } // todo: add JoinTestType::HjSmj after Right mark SortMergeJoin support #[tokio::test] async fn test_right_mark_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::RightMark, - None, 
- ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightMark, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_mark_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000), - make_staggered_batches_i32(1000), - JoinType::RightMark, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_inner_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::Inner, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Inner, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_inner_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::Inner, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Inner, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::Left, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Left, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::Left, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Left, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::Right, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + 
make_staggered_batches_binary(1000, right_extra), + JoinType::Right, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::Right, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Right, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_full_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::Full, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Full, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_full_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::Full, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[NljHj, HjSmj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Full, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[NljHj, HjSmj], false) + .await + } } #[tokio::test] async fn test_left_semi_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::LeftSemi, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_semi_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::LeftSemi, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_semi_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::RightSemi, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn 
test_right_semi_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::RightSemi, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_anti_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::LeftAnti, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_anti_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::LeftAnti, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_anti_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::RightAnti, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_anti_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::RightAnti, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_mark_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::LeftMark, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftMark, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_mark_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - 
make_staggered_batches_binary(1000), - JoinType::LeftMark, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } // todo: add JoinTestType::HjSmj after Right mark SortMergeJoin support #[tokio::test] async fn test_right_mark_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::RightMark, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightMark, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_mark_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000), - make_staggered_batches_binary(1000), - JoinType::RightMark, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } type JoinFilterBuilder = Box, Arc) -> JoinFilter>; @@ -841,7 +921,9 @@ impl JoinFuzzTestCase { std::fs::remove_dir_all(fuzz_debug).unwrap_or(()); std::fs::create_dir_all(fuzz_debug).unwrap(); let out_dir_name = &format!("{fuzz_debug}/batch_size_{batch_size}"); - println!("Test result data mismatch found. HJ rows {hj_rows}, SMJ rows {smj_rows}, NLJ rows {nlj_rows}"); + println!( + "Test result data mismatch found. HJ rows {hj_rows}, SMJ rows {smj_rows}, NLJ rows {nlj_rows}" + ); println!("The debug is ON. 
Input data will be saved to {out_dir_name}"); Self::save_partitioned_batches_as_parquet( @@ -892,10 +974,18 @@ impl JoinFuzzTestCase { } if join_tests.contains(&NljHj) { - let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {batch_size}"); + let err_msg_rowcnt = format!( + "NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {batch_size}" + ); assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str()); + if nlj_rows == 0 && hj_rows == 0 { + // both joins returned no rows, skip content comparison + continue; + } - let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {batch_size}"); + let err_msg_contents = format!( + "NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {batch_size}" + ); // row level compare if any of joins returns the result // the reason is different formatting when there is no rows for (i, (nlj_line, hj_line)) in nlj_formatted_sorted @@ -913,10 +1003,16 @@ impl JoinFuzzTestCase { } if join_tests.contains(&HjSmj) { - let err_msg_row_cnt = format!("HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", &batch_size); + let err_msg_row_cnt = format!( + "HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", + &batch_size + ); assert_eq!(hj_rows, smj_rows, "{}", err_msg_row_cnt.as_str()); - let err_msg_contents = format!("SortMergeJoinExec and HashJoinExec produced different results, batch_size: {}", &batch_size); + let err_msg_contents = format!( + "SortMergeJoinExec and HashJoinExec produced different results, batch_size: {}", + &batch_size + ); // row level compare if any of joins returns the result // the reason is different formatting when there is no rows if smj_rows > 0 || hj_rows > 0 { @@ -1031,7 +1127,7 @@ impl JoinFuzzTestCase { /// Return randomly sized record batches with: /// two sorted int32 columns 'a', 'b' ranged from 0..99 as join columns /// two random int32 columns 'x', 'y' as other columns -fn make_staggered_batches_i32(len: usize) -> Vec { +fn make_staggered_batches_i32(len: usize, with_extra_column: bool) -> Vec { let mut rng = rand::rng(); let mut input12: Vec<(i32, i32)> = vec![(0, 0); len]; let mut input3: Vec = vec![0; len]; @@ -1047,14 +1143,18 @@ fn make_staggered_batches_i32(len: usize) -> Vec { let input3 = Int32Array::from_iter_values(input3); let input4 = Int32Array::from_iter_values(input4); - // split into several record batches - let batch = RecordBatch::try_from_iter(vec![ + let mut columns = vec![ ("a", Arc::new(input1) as ArrayRef), ("b", Arc::new(input2) as ArrayRef), ("x", Arc::new(input3) as ArrayRef), - ("y", Arc::new(input4) as ArrayRef), - ]) - .unwrap(); + ]; + + if with_extra_column { + columns.push(("y", Arc::new(input4) as ArrayRef)); + } + + // split into several record batches + let batch = RecordBatch::try_from_iter(columns).unwrap(); // use a random number generator to pick a random sized output stagger_batch_with_seed(batch, 42) @@ -1070,7 +1170,10 @@ fn rand_bytes(rng: &mut R, min: usize, max: usize) -> Vec { /// Return randomly sized record batches with: /// two sorted binary columns 'a', 'b' (lexicographically) as join columns /// two random binary columns 'x', 'y' as other columns -fn make_staggered_batches_binary(len: usize) -> Vec { +fn make_staggered_batches_binary( + len: usize, + with_extra_column: bool, +) -> Vec { let mut rng = rand::rng(); // produce (a,b) pairs then sort lexicographically so 
SMJ has naturally sorted keys @@ -1088,13 +1191,17 @@ fn make_staggered_batches_binary(len: usize) -> Vec { let x = BinaryArray::from_iter_values(input3.iter()); let y = BinaryArray::from_iter_values(input4.iter()); - let batch = RecordBatch::try_from_iter(vec![ + let mut columns = vec![ ("a", Arc::new(a) as ArrayRef), ("b", Arc::new(b) as ArrayRef), ("x", Arc::new(x) as ArrayRef), - ("y", Arc::new(y) as ArrayRef), - ]) - .unwrap(); + ]; + + if with_extra_column { + columns.push(("y", Arc::new(y) as ArrayRef)); + } + + let batch = RecordBatch::try_from_iter(columns).unwrap(); // preserve your existing randomized partitioning stagger_batch_with_seed(batch, 42) diff --git a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs index 4c5ebf0402414..1c5741e7a21b3 100644 --- a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs @@ -24,7 +24,7 @@ use arrow::util::pretty::pretty_format_batches; use datafusion::datasource::MemTable; use datafusion::prelude::SessionContext; use datafusion_common::assert_contains; -use rand::{rng, Rng}; +use rand::{Rng, rng}; use std::sync::Arc; use test_utils::stagger_batch; diff --git a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs index b92dec64e3f19..59430a98cc4b4 100644 --- a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs @@ -27,7 +27,7 @@ use arrow::{ use datafusion::datasource::memory::MemorySourceConfig; use datafusion::physical_plan::{ collect, - expressions::{col, PhysicalSortExpr}, + expressions::{PhysicalSortExpr, col}, sorts::sort_preserving_merge::SortPreservingMergeExec, }; use datafusion::prelude::{SessionConfig, SessionContext}; diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index 9e2fd170f7f0c..edb53df382c62 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -15,20 +15,26 @@ // specific language governing permissions and limitations // under the License. 
+#[expect(clippy::needless_pass_by_value)] mod aggregate_fuzz; mod distinct_count_string_fuzz; +#[expect(clippy::needless_pass_by_value)] mod join_fuzz; mod merge_fuzz; +#[expect(clippy::needless_pass_by_value)] mod sort_fuzz; +#[expect(clippy::needless_pass_by_value)] mod sort_query_fuzz; mod topk_filter_pushdown; mod aggregation_fuzzer; +#[expect(clippy::needless_pass_by_value)] mod equivalence; mod pruning; mod limit_fuzz; +#[expect(clippy::needless_pass_by_value)] mod sort_preserving_repartition_fuzz; mod window_fuzz; diff --git a/datafusion/core/tests/fuzz_cases/pruning.rs b/datafusion/core/tests/fuzz_cases/pruning.rs index f8bd4dbc1a768..8a84e4c5d1814 100644 --- a/datafusion/core/tests/fuzz_cases/pruning.rs +++ b/datafusion/core/tests/fuzz_cases/pruning.rs @@ -29,9 +29,9 @@ use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_plan::{collect, filter::FilterExec, ExecutionPlan}; +use datafusion_physical_plan::{ExecutionPlan, collect, filter::FilterExec}; use itertools::Itertools; -use object_store::{memory::InMemory, path::Path, ObjectStore, PutPayload}; +use object_store::{ObjectStore, PutPayload, memory::InMemory, path::Path}; use parquet::{ arrow::ArrowWriter, file::properties::{EnabledStatistics, WriterProperties}, @@ -276,13 +276,12 @@ async fn execute_with_predicate( ctx: &SessionContext, ) -> Vec { let parquet_source = if prune_stats { - ParquetSource::default().with_predicate(predicate.clone()) + ParquetSource::new(schema.clone()).with_predicate(predicate.clone()) } else { - ParquetSource::default() + ParquetSource::new(schema.clone()) }; let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("memory://").unwrap(), - schema.clone(), Arc::new(parquet_source), ) .with_file_group( diff --git a/datafusion/core/tests/fuzz_cases/record_batch_generator.rs b/datafusion/core/tests/fuzz_cases/record_batch_generator.rs index 45dba5f7864b1..22b145f5095a7 100644 --- a/datafusion/core/tests/fuzz_cases/record_batch_generator.rs +++ b/datafusion/core/tests/fuzz_cases/record_batch_generator.rs @@ -19,23 +19,23 @@ use std::sync::Arc; use arrow::array::{ArrayRef, DictionaryArray, PrimitiveArray, RecordBatch}; use arrow::datatypes::{ - ArrowPrimitiveType, BooleanType, DataType, Date32Type, Date64Type, Decimal128Type, - Decimal256Type, Decimal32Type, Decimal64Type, DurationMicrosecondType, + ArrowPrimitiveType, BooleanType, DataType, Date32Type, Date64Type, Decimal32Type, + Decimal64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, DurationSecondType, Field, - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Schema, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, + TimestampNanosecondType, TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, + UInt64Type, }; use arrow_schema::{ - DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, - DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, - DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, + 
DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, + DECIMAL64_MAX_SCALE, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; -use datafusion_common::{arrow_datafusion_err, DataFusionError, Result}; -use rand::{rng, rngs::StdRng, Rng, SeedableRng}; +use datafusion_common::{Result, arrow_datafusion_err}; +use rand::{Rng, SeedableRng, rng, rngs::StdRng}; use test_utils::array_gen::{ BinaryArrayGenerator, BooleanArrayGenerator, DecimalArrayGenerator, PrimitiveArrayGenerator, StringArrayGenerator, diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs index 28d28a6622a76..0d8a066d432dd 100644 --- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::{ - array::{as_string_array, ArrayRef, Int32Array, StringArray}, + array::{ArrayRef, Int32Array, StringArray, as_string_array}, compute::SortOptions, record_batch::RecordBatch, }; @@ -28,7 +28,7 @@ use datafusion::datasource::memory::MemorySourceConfig; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::sorts::sort::SortExec; -use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::physical_plan::{ExecutionPlan, collect}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::cast::as_int32_array; use datafusion_execution::memory_pool::GreedyMemoryPool; diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index 99b20790fc46b..c424a314270c6 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -20,34 +20,33 @@ mod sp_repartition_fuzz_tests { use std::sync::Arc; use arrow::array::{ArrayRef, Int64Array, RecordBatch, UInt64Array}; - use arrow::compute::{concat_batches, lexsort, SortColumn, SortOptions}; + use arrow::compute::{SortColumn, SortOptions, concat_batches, lexsort}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; use datafusion::physical_plan::{ - collect, + ExecutionPlan, Partitioning, collect, metrics::{BaselineMetrics, ExecutionPlanMetricsSet}, repartition::RepartitionExec, sorts::sort_preserving_merge::SortPreservingMergeExec, sorts::streaming_merge::StreamingMergeBuilder, stream::RecordBatchStreamAdapter, - ExecutionPlan, Partitioning, }; use datafusion::prelude::SessionContext; use datafusion_common::Result; use datafusion_execution::{config::SessionConfig, memory_pool::MemoryConsumer}; + use datafusion_physical_expr::ConstExpr; use datafusion_physical_expr::equivalence::{ EquivalenceClass, EquivalenceProperties, }; - use datafusion_physical_expr::expressions::{col, Column}; - use datafusion_physical_expr::ConstExpr; + use datafusion_physical_expr::expressions::{Column, col}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use test_utils::add_empty_batches; use itertools::izip; - use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; + use rand::{Rng, SeedableRng, rngs::StdRng, seq::SliceRandom}; // Generate a schema which consists of 6 columns (a, b, c, d, 
e, f) fn create_test_schema() -> Result { diff --git a/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs index 2ce7db3ea4bc7..376306f3e0659 100644 --- a/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs @@ -24,24 +24,22 @@ use arrow::array::RecordBatch; use arrow_schema::SchemaRef; use datafusion::datasource::MemTable; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::{instant::Instant, Result}; +use datafusion_common::{Result, human_readable_size, instant::Instant}; use datafusion_execution::disk_manager::DiskManagerBuilder; -use datafusion_execution::memory_pool::{ - human_readable_size, MemoryPool, UnboundedMemoryPool, -}; +use datafusion_execution::memory_pool::{MemoryPool, UnboundedMemoryPool}; use datafusion_expr::display_schema; use datafusion_physical_plan::spill::get_record_batch_memory_size; use std::time::Duration; use datafusion_execution::{memory_pool::FairSpillPool, runtime_env::RuntimeEnvBuilder}; -use rand::prelude::IndexedRandom; use rand::Rng; -use rand::{rngs::StdRng, SeedableRng}; +use rand::prelude::IndexedRandom; +use rand::{SeedableRng, rngs::StdRng}; use crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches; use super::aggregation_fuzzer::ColumnDescr; -use super::record_batch_generator::{get_supported_types_columns, RecordBatchGenerator}; +use super::record_batch_generator::{RecordBatchGenerator, get_supported_types_columns}; /// Entry point for executing the sort query fuzzer. /// @@ -177,16 +175,16 @@ impl SortQueryFuzzer { n_round: usize, n_query: usize, ) -> bool { - if let Some(time_limit) = self.time_limit { - if Instant::now().duration_since(start_time) > time_limit { - println!( - "[SortQueryFuzzer] Time limit reached: {} queries ({} random configs each) in {} rounds", - n_round * self.queries_per_round + n_query, - self.config_variations_per_query, - n_round - ); - return true; - } + if let Some(time_limit) = self.time_limit + && Instant::now().duration_since(start_time) > time_limit + { + println!( + "[SortQueryFuzzer] Time limit reached: {} queries ({} random configs each) in {} rounds", + n_round * self.queries_per_round + n_query, + self.config_variations_per_query, + n_round + ); + return true; } false } diff --git a/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs index 6c1bd316cdd39..16481516e0bed 100644 --- a/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs +++ b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs @@ -27,18 +27,18 @@ use arrow::{array::StringArray, compute::SortOptions, record_batch::RecordBatch} use arrow_schema::{DataType, Field, Schema}; use datafusion::common::Result; use datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::sorts::sort::SortExec; -use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionConfig; -use datafusion_execution::memory_pool::units::{KB, MB}; +use datafusion_common::units::{KB, MB}; use datafusion_execution::memory_pool::{ FairSpillPool, MemoryConsumer, MemoryReservation, }; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_functions_aggregate::array_agg::array_agg_udaf; use 
datafusion_physical_expr::aggregate::AggregateExprBuilder; -use datafusion_physical_expr::expressions::{col, Column}; +use datafusion_physical_expr::expressions::{Column, col}; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, @@ -80,9 +80,9 @@ async fn test_sort_with_limited_memory() -> Result<()> { let total_spill_files_size = spill_count * record_batch_size; assert!( - total_spill_files_size > pool_size, - "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", - ); + total_spill_files_size > pool_size, + "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", + ); Ok(()) } @@ -126,8 +126,8 @@ async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch() -> } #[tokio::test] -async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation( -) -> Result<()> { +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation() +-> Result<()> { let record_batch_size = 8192; let pool_size = 2 * MB as usize; let task_ctx = { @@ -164,8 +164,8 @@ async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_c } #[tokio::test] -async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory( -) -> Result<()> { +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory() +-> Result<()> { let record_batch_size = 8192; let pool_size = 2 * MB as usize; let task_ctx = { @@ -356,16 +356,16 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory() -> Result<() let total_spill_files_size = spill_count * record_batch_size; assert!( - total_spill_files_size > pool_size, - "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", - ); + total_spill_files_size > pool_size, + "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", + ); Ok(()) } #[tokio::test] -async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch( -) -> Result<()> { +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch() +-> Result<()> { let record_batch_size = 8192; let pool_size = 2 * MB as usize; let task_ctx = { @@ -398,8 +398,8 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_ } #[tokio::test] -async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation( -) -> Result<()> { +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation() +-> Result<()> { let record_batch_size = 8192; let pool_size = 2 * MB as usize; let task_ctx = { @@ -432,8 +432,8 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_ } #[tokio::test] -async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory( -) -> Result<()> { +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory() +-> Result<()> { let record_batch_size = 8192; let pool_size = 2 * MB as usize; let task_ctx = { @@ -466,8 +466,8 @@ async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_ } 
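The spilling tests touched above share one setup idea: the memory pool is capped at 2 * MB while record_batch_size is 8192, so the sort or aggregate is forced to spill, and the assertion then checks spill_count * record_batch_size > pool_size. A rough sketch of that constrained setup, assuming the RuntimeEnvBuilder, FairSpillPool and SessionConfig APIs imported in this diff (the tests themselves build a TaskContext directly; the helper below is illustrative only):

// Sketch only: builds a context whose memory pool is far smaller than the data,
// mirroring `let pool_size = 2 * MB as usize;` used by the tests above.
fn memory_constrained_ctx(pool_size: usize, batch_size: usize) -> SessionContext {
    let runtime = RuntimeEnvBuilder::new()
        .with_memory_pool(Arc::new(FairSpillPool::new(pool_size)))
        .build()
        .unwrap();
    let config = SessionConfig::new().with_batch_size(batch_size);
    // Same constructor as used by repartition_mem_limit.rs later in this diff.
    SessionContext::new_with_config_rt(config, Arc::new(runtime))
}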
#[tokio::test] -async fn test_aggregate_with_high_cardinality_with_limited_memory_and_large_record_batch( -) -> Result<()> { +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_large_record_batch() +-> Result<()> { let record_batch_size = 8192; let pool_size = 2 * MB as usize; let task_ctx = { diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 65a41d39d3c54..2ecfcd84aba98 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -18,19 +18,19 @@ use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array, StringArray}; -use arrow::compute::{concat_batches, SortOptions}; +use arrow::compute::{SortOptions, concat_batches}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; use datafusion::functions_window::row_number::row_number_udwf; +use datafusion::physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted}; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::windows::{ - create_window_expr, schema_add_window_field, BoundedWindowAggExec, WindowAggExec, + BoundedWindowAggExec, WindowAggExec, create_window_expr, schema_add_window_field, }; -use datafusion::physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted}; -use datafusion::physical_plan::{collect, InputOrderMode}; +use datafusion::physical_plan::{InputOrderMode, collect}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::HashMap; use datafusion_common::{Result, ScalarValue}; @@ -445,14 +445,14 @@ fn get_random_function( let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; let (window_fn, args) = window_fn_map.values().collect::>()[rand_fn_idx]; let mut args = args.clone(); - if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn { - if !args.is_empty() { - // Do type coercion first argument - let a = args[0].clone(); - let dt = a.return_field(schema.as_ref()).unwrap(); - let coerced = fields_with_aggregate_udf(&[dt], udf).unwrap(); - args[0] = cast(a, schema, coerced[0].data_type().clone()).unwrap(); - } + if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn + && !args.is_empty() + { + // Do type coercion first argument + let a = args[0].clone(); + let dt = a.return_field(schema.as_ref()).unwrap(); + let coerced = fields_with_aggregate_udf(&[dt], udf).unwrap(); + args[0] = cast(a, schema, coerced[0].data_type().clone()).unwrap(); } (window_fn.clone(), args, (*fn_name).to_string()) @@ -569,10 +569,11 @@ fn convert_bound_to_current_row_if_applicable( ) { match bound { WindowFrameBound::Preceding(value) | WindowFrameBound::Following(value) => { - if let Ok(zero) = ScalarValue::new_zero(&value.data_type()) { - if value == &zero && rng.random_range(0..2) == 0 { - *bound = WindowFrameBound::CurrentRow; - } + if let Ok(zero) = ScalarValue::new_zero(&value.data_type()) + && value == &zero + && rng.random_range(0..2) == 0 + { + *bound = WindowFrameBound::CurrentRow; } } _ => {} @@ -644,10 +645,8 @@ async fn run_window_test( ) as _; // Table is ordered according to ORDER BY a, b, c In linear test we use PARTITION BY b, ORDER BY a // For WindowAggExec to produce correct result it need table to be ordered by b,a. Hence add a sort. 
- if is_linear { - if let Some(ordering) = LexOrdering::new(sort_keys) { - exec1 = Arc::new(SortExec::new(ordering, exec1)) as _; - } + if is_linear && let Some(ordering) = LexOrdering::new(sort_keys) { + exec1 = Arc::new(SortExec::new(ordering, exec1)) as _; } let extended_schema = schema_add_window_field(&args, &schema, &window_fn, &fn_name)?; @@ -699,7 +698,9 @@ async fn run_window_test( // BoundedWindowAggExec should produce more chunk than the usual WindowAggExec. // Otherwise it means that we cannot generate result in running mode. - let err_msg = format!("Inconsistent result for window_frame: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, random_seed: {random_seed:?}, search_mode: {search_mode:?}, partition_by_columns:{partition_by_columns:?}, orderby_columns: {orderby_columns:?}"); + let err_msg = format!( + "Inconsistent result for window_frame: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, random_seed: {random_seed:?}, search_mode: {search_mode:?}, partition_by_columns:{partition_by_columns:?}, orderby_columns: {orderby_columns:?}" + ); // Below check makes sure that, streaming execution generates more chunks than the bulk execution. // Since algorithms and operators works on sliding windows in the streaming execution. // However, in the current test setup for some random generated window frame clauses: It is not guaranteed @@ -731,8 +732,12 @@ async fn run_window_test( .enumerate() { if !usual_line.eq(running_line) { - println!("Inconsistent result for window_frame at line:{i:?}: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, pb_cols:{partition_by_columns:?}, ob_cols:{orderby_columns:?}, search_mode:{search_mode:?}"); - println!("--------usual_formatted_sorted----------------running_formatted_sorted--------"); + println!( + "Inconsistent result for window_frame at line:{i:?}: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, pb_cols:{partition_by_columns:?}, ob_cols:{orderby_columns:?}, search_mode:{search_mode:?}" + ); + println!( + "--------usual_formatted_sorted----------------running_formatted_sorted--------" + ); for (line1, line2) in usual_formatted_sorted.iter().zip(running_formatted_sorted) { diff --git a/datafusion/core/tests/macro_hygiene/mod.rs b/datafusion/core/tests/macro_hygiene/mod.rs index c9f33f6fdf0f4..48f0103113cf6 100644 --- a/datafusion/core/tests/macro_hygiene/mod.rs +++ b/datafusion/core/tests/macro_hygiene/mod.rs @@ -85,6 +85,7 @@ mod config_field { impl std::error::Error for E {} #[allow(dead_code)] + #[derive(Default)] struct S; impl std::str::FromStr for S { diff --git a/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs b/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs index 7b157b707a6de..2c9fae20c8606 100644 --- a/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs +++ b/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs @@ -16,16 +16,14 @@ // under the License. 
use datafusion_common_runtime::SpawnedTask; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use sysinfo::{ProcessRefreshKind, ProcessesToUpdate, System}; -use tokio::time::{interval, Duration}; +use tokio::time::{Duration, interval}; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_execution::{ - memory_pool::{human_readable_size, FairSpillPool}, - runtime_env::RuntimeEnvBuilder, -}; +use datafusion_common::human_readable_size; +use datafusion_execution::{memory_pool::FairSpillPool, runtime_env::RuntimeEnvBuilder}; /// Measures the maximum RSS (in bytes) during the execution of an async task. RSS /// will be sampled every 7ms. @@ -40,7 +38,7 @@ use datafusion_execution::{ async fn measure_max_rss(f: F) -> (T, usize) where F: FnOnce() -> Fut, - Fut: std::future::Future, + Fut: Future, { // Initialize system information let mut system = System::new_all(); diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 5d8a1d24181cb..c28d23ba0602b 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -39,19 +39,19 @@ use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::streaming::PartitionStream; use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_catalog::streaming::StreamingTable; use datafusion_catalog::Session; -use datafusion_common::{assert_contains, Result}; +use datafusion_catalog::streaming::StreamingTable; +use datafusion_common::{Result, assert_contains}; +use datafusion_execution::TaskContext; use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; use datafusion_execution::memory_pool::{ FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool, }; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_execution::TaskContext; use datafusion_expr::{Expr, TableType}; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_optimizer::join_selection::JoinSelection; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::join_selection::JoinSelection; use datafusion_physical_plan::collect as collect_batches; use datafusion_physical_plan::common::collect; use datafusion_physical_plan::spill::get_record_batch_memory_size; @@ -604,8 +604,8 @@ async fn test_disk_spill_limit_reached() -> Result<()> { let err = df.collect().await.unwrap_err(); assert_contains!( - err.to_string(), - "The used disk space during the spilling process has exceeded the allowable limit" + err.to_string(), + "The used disk space during the spilling process has exceeded the allowable limit" ); Ok(()) @@ -977,11 +977,13 @@ impl Scenario { descending: false, nulls_first: false, }; - let sort_information = vec![[ - PhysicalSortExpr::new(col("a", &schema).unwrap(), options), - PhysicalSortExpr::new(col("b", &schema).unwrap(), options), - ] - .into()]; + let sort_information = vec![ + [ + PhysicalSortExpr::new(col("a", &schema).unwrap(), options), + PhysicalSortExpr::new(col("b", &schema).unwrap(), options), + ] + .into(), + ]; let table = SortedTableProvider::new(batches, sort_information); Arc::new(table) @@ -1057,7 +1059,7 @@ fn make_dict_batches() -> Vec { let batch_size = 50; let mut i = 0; - let gen = std::iter::from_fn(move || { + let batch_gen = std::iter::from_fn(move || { // 
create values like // 0000000001 // 0000000002 @@ -1080,7 +1082,7 @@ fn make_dict_batches() -> Vec { let num_batches = 5; - let batches: Vec<_> = gen.take(num_batches).collect(); + let batches: Vec<_> = batch_gen.take(num_batches).collect(); batches.iter().enumerate().for_each(|(i, batch)| { println!("Dict batch[{i}] size is: {}", batch.get_array_memory_size()); diff --git a/datafusion/core/tests/memory_limit/repartition_mem_limit.rs b/datafusion/core/tests/memory_limit/repartition_mem_limit.rs index a7af2f01d1cc9..b21bffebaf95e 100644 --- a/datafusion/core/tests/memory_limit/repartition_mem_limit.rs +++ b/datafusion/core/tests/memory_limit/repartition_mem_limit.rs @@ -25,7 +25,7 @@ use datafusion::{ use datafusion_catalog::MemTable; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_execution::runtime_env::RuntimeEnvBuilder; -use datafusion_physical_plan::{repartition::RepartitionExec, ExecutionPlanProperties}; +use datafusion_physical_plan::{ExecutionPlanProperties, repartition::RepartitionExec}; use futures::TryStreamExt; use itertools::Itertools; @@ -45,11 +45,14 @@ async fn test_repartition_memory_limit() { .with_batch_size(32) .with_target_partitions(2); let ctx = SessionContext::new_with_config_rt(config, Arc::new(runtime)); - let batches = vec![RecordBatch::try_from_iter(vec![( - "c1", - Arc::new(Int32Array::from_iter_values((0..10).cycle().take(100_000))) as ArrayRef, - )]) - .unwrap()]; + let batches = vec![ + RecordBatch::try_from_iter(vec![( + "c1", + Arc::new(Int32Array::from_iter_values((0..10).cycle().take(100_000))) + as ArrayRef, + )]) + .unwrap(), + ]; let table = Arc::new(MemTable::try_new(batches[0].schema(), vec![batches]).unwrap()); ctx.register_table("t", table).unwrap(); let plan = ctx diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 9b2a5596827d0..6466e9ad96d17 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -27,17 +27,16 @@ use arrow::datatypes::{ DataType, Field, Fields, Schema, SchemaBuilder, SchemaRef, TimeUnit, }; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{TransformedResult, TreeNode}; -use datafusion_common::{plan_err, DFSchema, Result, ScalarValue, TableReference}; +use datafusion_common::tree_node::TransformedResult; +use datafusion_common::{DFSchema, Result, ScalarValue, TableReference, plan_err}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{ - col, lit, AggregateUDF, BinaryExpr, Expr, ExprSchemable, LogicalPlan, Operator, - ScalarUDF, TableSource, WindowUDF, + AggregateUDF, BinaryExpr, Expr, ExprSchemable, LogicalPlan, Operator, ScalarUDF, + TableSource, WindowUDF, col, lit, }; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; -use datafusion_optimizer::simplify_expressions::GuaranteeRewriter; use datafusion_optimizer::{OptimizerConfig, OptimizerContext}; use datafusion_sql::planner::{ContextProvider, SqlToRel}; use datafusion_sql::sqlparser::ast::Statement; @@ -45,6 +44,7 @@ use datafusion_sql::sqlparser::dialect::GenericDialect; use datafusion_sql::sqlparser::parser::Parser; use chrono::DateTime; +use datafusion_expr::expr_rewriter::rewrite_with_guarantees; use datafusion_functions::datetime; #[cfg(test)] @@ -304,8 +304,6 @@ fn test_inequalities_non_null_bounded() { ), ]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - // 
(original_expr, expected_simplification) let simplified_cases = &[ (col("x").lt(lit(0)), false), @@ -337,7 +335,7 @@ fn test_inequalities_non_null_bounded() { ), ]; - validate_simplified_cases(&mut rewriter, simplified_cases); + validate_simplified_cases(&guarantees, simplified_cases); let unchanged_cases = &[ col("x").gt(lit(2)), @@ -348,16 +346,20 @@ fn test_inequalities_non_null_bounded() { col("x").not_between(lit(3), lit(10)), ]; - validate_unchanged_cases(&mut rewriter, unchanged_cases); + validate_unchanged_cases(&guarantees, unchanged_cases); } -fn validate_simplified_cases(rewriter: &mut GuaranteeRewriter, cases: &[(Expr, T)]) -where +fn validate_simplified_cases( + guarantees: &[(Expr, NullableInterval)], + cases: &[(Expr, T)], +) where ScalarValue: From, T: Clone, { for (expr, expected_value) in cases { - let output = expr.clone().rewrite(rewriter).data().unwrap(); + let output = rewrite_with_guarantees(expr.clone(), guarantees) + .data() + .unwrap(); let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, @@ -365,9 +367,11 @@ where ); } } -fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { +fn validate_unchanged_cases(guarantees: &[(Expr, NullableInterval)], cases: &[Expr]) { for expr in cases { - let output = expr.clone().rewrite(rewriter).data().unwrap(); + let output = rewrite_with_guarantees(expr.clone(), guarantees) + .data() + .unwrap(); assert_eq!( &output, expr, "{expr} was simplified to {output}, but expected it to be unchanged" diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index 3a1f06656236c..31ec6efd19510 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -20,7 +20,7 @@ use std::ops::Range; use std::sync::Arc; use std::time::SystemTime; -use arrow::array::{ArrayRef, Int64Array, Int8Array, StringArray}; +use arrow::array::{ArrayRef, Int8Array, Int64Array, StringArray}; use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; use datafusion::datasource::listing::PartitionedFile; @@ -31,8 +31,8 @@ use datafusion::datasource::physical_plan::{ use datafusion::physical_plan::collect; use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::prelude::SessionContext; -use datafusion_common::test_util::batches_to_sort_string; use datafusion_common::Result; +use datafusion_common::test_util::batches_to_sort_string; use bytes::Bytes; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; @@ -44,9 +44,9 @@ use insta::assert_snapshot; use object_store::memory::InMemory; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; +use parquet::arrow::ArrowWriter; use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::async_reader::AsyncFileReader; -use parquet::arrow::ArrowWriter; use parquet::errors::ParquetError; use parquet::file::metadata::ParquetMetaData; @@ -80,7 +80,7 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { .collect(); let source = Arc::new( - ParquetSource::default() + ParquetSource::new(file_schema.clone()) // prepare the scan .with_parquet_file_reader_factory(Arc::new( InMemoryParquetFileReaderFactory(Arc::clone(&in_memory_object_store)), @@ -89,7 +89,6 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { let base_config = FileScanConfigBuilder::new( // just any url that doesn't point to in memory object store 
ObjectStoreUrl::local_filesystem(), - file_schema, source, ) .with_file_group(file_group) diff --git a/datafusion/core/tests/parquet/encryption.rs b/datafusion/core/tests/parquet/encryption.rs index 09b93f06ce85d..8b3170e367457 100644 --- a/datafusion/core/tests/parquet/encryption.rs +++ b/datafusion/core/tests/parquet/encryption.rs @@ -25,11 +25,11 @@ use datafusion::dataframe::DataFrameWriteOptions; use datafusion::datasource::listing::ListingOptions; use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion_common::config::{EncryptionFactoryOptions, TableParquetOptions}; -use datafusion_common::{assert_batches_sorted_eq, exec_datafusion_err, DataFusionError}; +use datafusion_common::{DataFusionError, assert_batches_sorted_eq, exec_datafusion_err}; use datafusion_datasource_parquet::ParquetFormat; use datafusion_execution::parquet_encryption::EncryptionFactory; -use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions}; use parquet::arrow::ArrowWriter; +use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions}; use parquet::encryption::decrypt::FileDecryptionProperties; use parquet::encryption::encrypt::FileEncryptionProperties; use parquet::file::column_crypto_metadata::ColumnCryptoMetaData; @@ -54,6 +54,7 @@ async fn read_parquet_test_data<'a, T: Into>( .unwrap() } +#[expect(clippy::needless_pass_by_value)] pub fn write_batches( path: PathBuf, props: WriterProperties, diff --git a/datafusion/core/tests/parquet/expr_adapter.rs b/datafusion/core/tests/parquet/expr_adapter.rs new file mode 100644 index 0000000000000..515422ed750ef --- /dev/null +++ b/datafusion/core/tests/parquet/expr_adapter.rs @@ -0,0 +1,466 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use arrow::array::{RecordBatch, record_batch}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use bytes::{BufMut, BytesMut}; +use datafusion::assert_batches_eq; +use datafusion::common::Result; +use datafusion::datasource::listing::{ + ListingTable, ListingTableConfig, ListingTableConfigExt, +}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_common::DataFusionError; +use datafusion_common::ScalarValue; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_datasource::ListingTableUrl; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::expressions::{self, Column}; +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, + PhysicalExprAdapterFactory, +}; +use object_store::{ObjectStore, memory::InMemory, path::Path}; +use parquet::arrow::ArrowWriter; + +async fn write_parquet(batch: RecordBatch, store: Arc, path: &str) { + let mut out = BytesMut::new().writer(); + { + let mut writer = ArrowWriter::try_new(&mut out, batch.schema(), None).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + } + let data = out.into_inner().freeze(); + store.put(&Path::from(path), data.into()).await.unwrap(); +} + +// Implement a custom PhysicalExprAdapterFactory that fills in missing columns with +// the default value for the field type: +// - Int64 columns are filled with `1` +// - Utf8 columns are filled with `'b'` +#[derive(Debug)] +struct CustomPhysicalExprAdapterFactory; + +impl PhysicalExprAdapterFactory for CustomPhysicalExprAdapterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Arc { + Arc::new(CustomPhysicalExprAdapter { + logical_file_schema: Arc::clone(&logical_file_schema), + physical_file_schema: Arc::clone(&physical_file_schema), + inner: Arc::new(DefaultPhysicalExprAdapter::new( + logical_file_schema, + physical_file_schema, + )), + }) + } +} + +#[derive(Debug, Clone)] +struct CustomPhysicalExprAdapter { + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + inner: Arc, +} + +impl PhysicalExprAdapter for CustomPhysicalExprAdapter { + fn rewrite(&self, mut expr: Arc) -> Result> { + expr = expr + .transform(|expr| { + if let Some(column) = expr.as_any().downcast_ref::() { + let field_name = column.name(); + if self + .physical_file_schema + .field_with_name(field_name) + .ok() + .is_none() + { + let field = self + .logical_file_schema + .field_with_name(field_name) + .map_err(|_| { + DataFusionError::Plan(format!( + "Field '{field_name}' not found in logical file schema", + )) + })?; + // If the field does not exist, create a default value expression + // Note that we use slightly different logic here to create a default value so that we can see different behavior in tests + let default_value = match field.data_type() { + DataType::Int64 => ScalarValue::Int64(Some(1)), + DataType::Utf8 => ScalarValue::Utf8(Some("b".to_string())), + _ => unimplemented!( + "Unsupported data type: {}", + field.data_type() + ), + }; + return Ok(Transformed::yes(Arc::new( + expressions::Literal::new(default_value), + ))); + } + } + + Ok(Transformed::no(expr)) + }) + .data()?; + self.inner.rewrite(expr) + } +} + +#[tokio::test] +async fn test_custom_schema_adapter_and_custom_expression_adapter() { + let batch = + record_batch!(("extra", Int64, [1, 2, 3]), ("c1", Int32, [1, 
2, 3])).unwrap(); + + let store = Arc::new(InMemory::new()) as Arc; + let store_url = ObjectStoreUrl::parse("memory://").unwrap(); + let path = "test.parquet"; + write_parquet(batch, store.clone(), path).await; + + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int64, false), + Field::new("c2", DataType::Utf8, true), + ])); + + let mut cfg = SessionConfig::new() + // Disable statistics collection for this test otherwise early pruning makes it hard to demonstrate data adaptation + .with_collect_statistics(false) + .with_parquet_pruning(false) + .with_parquet_page_index_pruning(false); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); + assert!( + !ctx.state() + .config_mut() + .options_mut() + .execution + .collect_statistics + ); + assert!(!ctx.state().config().collect_statistics()); + + // Test with DefaultPhysicalExprAdapterFactory - missing columns are filled with NULL + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::new(DefaultPhysicalExprAdapterFactory)); + + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.register_table("t", Arc::new(table)).unwrap(); + + let batches = ctx + .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 IS NULL") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let expected = [ + "+----+----+", + "| c2 | c1 |", + "+----+----+", + "| | 2 |", + "+----+----+", + ]; + assert_batches_eq!(expected, &batches); + + // Test with a custom physical expr adapter + // PhysicalExprAdapterFactory now handles both predicates AND projections + // CustomPhysicalExprAdapterFactory fills missing columns with 'b' for Utf8 + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::new(CustomPhysicalExprAdapterFactory)); + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.deregister_table("t").unwrap(); + ctx.register_table("t", Arc::new(table)).unwrap(); + let batches = ctx + .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'b'") + .await + .unwrap() + .collect() + .await + .unwrap(); + // With CustomPhysicalExprAdapterFactory, missing column c2 is filled with 'b' + // in both the predicate (c2 = 'b' becomes 'b' = 'b' -> true) and the projection + let expected = [ + "+----+----+", + "| c2 | c1 |", + "+----+----+", + "| b | 2 |", + "+----+----+", + ]; + assert_batches_eq!(expected, &batches); +} + +/// Test demonstrating how to implement a custom PhysicalExprAdapterFactory +/// that fills missing columns with non-null default values. +/// +/// PhysicalExprAdapterFactory rewrites expressions to use literals for +/// missing columns, handling schema evolution efficiently at planning time. 
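// The doc comment above describes the rewrite mechanism; below is a minimal sketch of
// that step in isolation, using only the factory/adapter APIs exercised by the tests in
// this file. The function name is illustrative, and the typed-NULL result is an
// assumption consistent with the DefaultPhysicalExprAdapterFactory test above, not a
// statement of the exact output shape.
fn rewrite_missing_column_sketch() -> datafusion_common::Result<()> {
    use std::sync::Arc;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr_adapter::{
        DefaultPhysicalExprAdapterFactory, PhysicalExprAdapterFactory,
    };

    // Logical (table) schema has c1 and c2; the physical (file) schema only has c1.
    let logical = Arc::new(Schema::new(vec![
        Field::new("c1", DataType::Int64, false),
        Field::new("c2", DataType::Utf8, true),
    ]));
    let physical =
        Arc::new(Schema::new(vec![Field::new("c1", DataType::Int64, false)]));

    let adapter = DefaultPhysicalExprAdapterFactory.create(logical, physical);

    // `c2` is column index 1 in the logical schema but absent from the file, so the
    // default adapter replaces the column reference with a literal (NULL of the
    // logical type); any predicate built on top of it, e.g. `c2 IS NULL`, is then
    // evaluated against that literal instead of a non-existent file column.
    let rewritten = adapter.rewrite(Arc::new(Column::new("c2", 1)))?;
    println!("rewritten: {rewritten}");
    Ok(())
}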
+#[tokio::test] +async fn test_physical_expr_adapter_with_non_null_defaults() { + // File only has c1 column + let batch = record_batch!(("c1", Int32, [10, 20, 30])).unwrap(); + + let store = Arc::new(InMemory::new()) as Arc; + let store_url = ObjectStoreUrl::parse("memory://").unwrap(); + write_parquet(batch, store.clone(), "defaults_test.parquet").await; + + // Table schema has additional columns c2 (Utf8) and c3 (Int64) that don't exist in file + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int64, false), // type differs from file (Int32 vs Int64) + Field::new("c2", DataType::Utf8, true), // missing from file + Field::new("c3", DataType::Int64, true), // missing from file + ])); + + let mut cfg = SessionConfig::new() + .with_collect_statistics(false) + .with_parquet_pruning(false); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); + + // CustomPhysicalExprAdapterFactory fills: + // - missing Utf8 columns with 'b' + // - missing Int64 columns with 1 + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::new(CustomPhysicalExprAdapterFactory)); + + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.register_table("t", Arc::new(table)).unwrap(); + + // Query all columns - missing columns should have default values + let batches = ctx + .sql("SELECT c1, c2, c3 FROM t ORDER BY c1") + .await + .unwrap() + .collect() + .await + .unwrap(); + + // c1 is cast from Int32 to Int64, c2 defaults to 'b', c3 defaults to 1 + let expected = [ + "+----+----+----+", + "| c1 | c2 | c3 |", + "+----+----+----+", + "| 10 | b | 1 |", + "| 20 | b | 1 |", + "| 30 | b | 1 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + + // Verify predicates work with default values + // c3 = 1 should match all rows since default is 1 + let batches = ctx + .sql("SELECT c1 FROM t WHERE c3 = 1 ORDER BY c1") + .await + .unwrap() + .collect() + .await + .unwrap(); + + #[rustfmt::skip] + let expected = [ + "+----+", + "| c1 |", + "+----+", + "| 10 |", + "| 20 |", + "| 30 |", + "+----+", + ]; + assert_batches_eq!(expected, &batches); + + // c3 = 999 should match no rows + let batches = ctx + .sql("SELECT c1 FROM t WHERE c3 = 999") + .await + .unwrap() + .collect() + .await + .unwrap(); + + #[rustfmt::skip] + let expected = [ + "++", + "++", + ]; + assert_batches_eq!(expected, &batches); +} + +/// Test demonstrating that a single PhysicalExprAdapterFactory instance can be +/// reused across multiple ListingTable instances. +/// +/// This addresses the concern: "This is important for ListingTable. A test for +/// ListingTable would add assurance that the functionality is retained [i.e. 
we +/// can re-use a PhysicalExprAdapterFactory]" +#[tokio::test] +async fn test_physical_expr_adapter_factory_reuse_across_tables() { + // Create two different parquet files with different schemas + // File 1: has column c1 only + let batch1 = record_batch!(("c1", Int32, [1, 2, 3])).unwrap(); + // File 2: has column c1 only but different data + let batch2 = record_batch!(("c1", Int32, [10, 20, 30])).unwrap(); + + let store = Arc::new(InMemory::new()) as Arc; + let store_url = ObjectStoreUrl::parse("memory://").unwrap(); + + // Write files to different paths + write_parquet(batch1, store.clone(), "table1/data.parquet").await; + write_parquet(batch2, store.clone(), "table2/data.parquet").await; + + // Table schema has additional columns that don't exist in files + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int64, false), + Field::new("c2", DataType::Utf8, true), // missing from files + ])); + + let mut cfg = SessionConfig::new() + .with_collect_statistics(false) + .with_parquet_pruning(false); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); + + // Create ONE factory instance wrapped in Arc - this will be REUSED + let factory: Arc = + Arc::new(CustomPhysicalExprAdapterFactory); + + // Create ListingTable 1 using the shared factory + let listing_table_config1 = + ListingTableConfig::new(ListingTableUrl::parse("memory:///table1/").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::clone(&factory)); // Clone the Arc, not create new factory + + let table1 = ListingTable::try_new(listing_table_config1).unwrap(); + ctx.register_table("t1", Arc::new(table1)).unwrap(); + + // Create ListingTable 2 using the SAME factory instance + let listing_table_config2 = + ListingTableConfig::new(ListingTableUrl::parse("memory:///table2/").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::clone(&factory)); // Reuse same factory + + let table2 = ListingTable::try_new(listing_table_config2).unwrap(); + ctx.register_table("t2", Arc::new(table2)).unwrap(); + + // Verify table 1 works correctly with the shared factory + // CustomPhysicalExprAdapterFactory fills missing Utf8 columns with 'b' + let batches = ctx + .sql("SELECT c1, c2 FROM t1 ORDER BY c1") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let expected = [ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 1 | b |", + "| 2 | b |", + "| 3 | b |", + "+----+----+", + ]; + assert_batches_eq!(expected, &batches); + + // Verify table 2 also works correctly with the SAME shared factory + let batches = ctx + .sql("SELECT c1, c2 FROM t2 ORDER BY c1") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let expected = [ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 10 | b |", + "| 20 | b |", + "| 30 | b |", + "+----+----+", + ]; + assert_batches_eq!(expected, &batches); + + // Verify predicates work on both tables with the shared factory + let batches = ctx + .sql("SELECT c1 FROM t1 WHERE c2 = 'b' ORDER BY c1") + .await + .unwrap() + .collect() + .await + .unwrap(); + + #[rustfmt::skip] + let expected = [ + "+----+", + "| c1 |", + "+----+", + "| 1 |", + "| 2 |", + "| 3 |", + "+----+", + ]; + assert_batches_eq!(expected, &batches); + + let batches = ctx + .sql("SELECT c1 FROM t2 WHERE c2 = 'b' ORDER BY c1") + 
.await + .unwrap() + .collect() + .await + .unwrap(); + + #[rustfmt::skip] + let expected = [ + "+----+", + "| c1 |", + "+----+", + "| 10 |", + "| 20 |", + "| 30 |", + "+----+", + ]; + assert_batches_eq!(expected, &batches); +} diff --git a/datafusion/core/tests/parquet/external_access_plan.rs b/datafusion/core/tests/parquet/external_access_plan.rs index 5135f956852c3..0c02c8fe523dc 100644 --- a/datafusion/core/tests/parquet/external_access_plan.rs +++ b/datafusion/core/tests/parquet/external_access_plan.rs @@ -21,7 +21,7 @@ use std::path::Path; use std::sync::Arc; use crate::parquet::utils::MetricsFinder; -use crate::parquet::{create_data_batch, Scenario}; +use crate::parquet::{Scenario, create_data_batch}; use arrow::datatypes::SchemaRef; use arrow::util::pretty::pretty_format_batches; @@ -29,17 +29,17 @@ use datafusion::common::Result; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::ParquetSource; use datafusion::prelude::SessionContext; -use datafusion_common::{assert_contains, DFSchema}; +use datafusion_common::{DFSchema, assert_contains}; use datafusion_datasource_parquet::{ParquetAccessPlan, RowGroupAccess}; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_expr::{col, lit, Expr}; -use datafusion_physical_plan::metrics::{MetricValue, MetricsSet}; +use datafusion_expr::{Expr, col, lit}; use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::metrics::{MetricValue, MetricsSet}; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; -use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; use parquet::arrow::ArrowWriter; +use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; use parquet::file::properties::WriterProperties; use tempfile::NamedTempFile; @@ -257,7 +257,10 @@ async fn bad_selection() { .await .unwrap_err(); let err_string = err.to_string(); - assert_contains!(&err_string, "Internal error: Invalid ParquetAccessPlan Selection. Row group 0 has 5 rows but selection only specifies 4 rows"); + assert_contains!( + &err_string, + "Row group 0 has 5 rows but selection only specifies 4 rows." 
+ ); } /// Return a RowSelection of 1 rows from a row group of 5 rows @@ -355,11 +358,11 @@ impl TestFull { let source = if let Some(predicate) = predicate { let df_schema = DFSchema::try_from(schema.clone())?; let predicate = ctx.create_physical_expr(predicate, &df_schema)?; - Arc::new(ParquetSource::default().with_predicate(predicate)) + Arc::new(ParquetSource::new(schema.clone()).with_predicate(predicate)) } else { - Arc::new(ParquetSource::default()) + Arc::new(ParquetSource::new(schema.clone())) }; - let config = FileScanConfigBuilder::new(object_store_url, schema.clone(), source) + let config = FileScanConfigBuilder::new(object_store_url, source) .with_file(partitioned_file) .build(); diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 64ee92eda2545..fdefdafa00aa4 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -18,31 +18,30 @@ use std::fs; use std::sync::Arc; +use datafusion::datasource::TableProvider; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; use datafusion::datasource::source::DataSourceExec; -use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::SessionContext; -use datafusion_common::stats::Precision; use datafusion_common::DFSchema; +use datafusion_common::stats::Precision; +use datafusion_execution::cache::DefaultListFilesCache; use datafusion_execution::cache::cache_manager::CacheManagerConfig; -use datafusion_execution::cache::cache_unit::{ - DefaultFileStatisticsCache, DefaultListFilesCache, -}; +use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; -use datafusion_expr::{col, lit, Expr}; +use datafusion_expr::{Expr, col, lit}; use datafusion::datasource::physical_plan::FileScanConfig; use datafusion_common::config::ConfigOptions; -use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; use datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::filter::FilterExec; use tempfile::tempdir; #[tokio::test] @@ -127,8 +126,9 @@ async fn load_table_stats_with_session_level_cache() { ); assert_eq!( exec1.partition_statistics(None).unwrap().total_byte_size, - // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - Precision::Exact(671), + // Byte size is absent because we cannot estimate the output size + // of the Arrow data since there are variable length columns. 
+ Precision::Absent, ); assert_eq!(get_static_cache_size(&state1), 1); @@ -142,8 +142,8 @@ async fn load_table_stats_with_session_level_cache() { ); assert_eq!( exec2.partition_statistics(None).unwrap().total_byte_size, - // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - Precision::Exact(671), + // Absent because the data contains variable length columns + Precision::Absent, ); assert_eq!(get_static_cache_size(&state2), 1); @@ -157,8 +157,8 @@ async fn load_table_stats_with_session_level_cache() { ); assert_eq!( exec3.partition_statistics(None).unwrap().total_byte_size, - // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - Precision::Exact(671), + // Absent because the data contains variable length columns + Precision::Absent, ); // List same file no increase assert_eq!(get_static_cache_size(&state1), 1); diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index 966f251613979..e3a191ee9ade2 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -31,7 +31,7 @@ use arrow::record_batch::RecordBatch; use datafusion::physical_plan::collect; use datafusion::physical_plan::metrics::{MetricValue, MetricsSet}; use datafusion::prelude::{ - col, lit, lit_timestamp_nano, Expr, ParquetReadOptions, SessionContext, + Expr, ParquetReadOptions, SessionContext, col, lit, lit_timestamp_nano, }; use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; use datafusion_expr::utils::{conjunction, disjunction, split_conjunction}; @@ -636,6 +636,27 @@ async fn predicate_cache_pushdown_default() -> datafusion_common::Result<()> { config.options_mut().execution.parquet.pushdown_filters = true; let ctx = SessionContext::new_with_config(config); // The cache is on by default, and used when filter pushdown is enabled + PredicateCacheTest { + expected_inner_records: 8, + expected_records: 7, // reads more than necessary from the cache as then another bitmap is applied + } + .run(&ctx) + .await +} + +#[tokio::test] +async fn predicate_cache_pushdown_default_selections_only() +-> datafusion_common::Result<()> { + let mut config = SessionConfig::new(); + config.options_mut().execution.parquet.pushdown_filters = true; + // forcing filter selections minimizes the number of rows read from the cache + config + .options_mut() + .execution + .parquet + .force_filter_selections = true; + let ctx = SessionContext::new_with_config(config); + // The cache is on by default, and used when filter pushdown is enabled PredicateCacheTest { expected_inner_records: 8, expected_records: 4, diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 097600e45eadd..35b5918d9e8bf 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -19,12 +19,12 @@ use crate::parquet::utils::MetricsFinder; use arrow::{ array::{ - make_array, Array, ArrayRef, BinaryArray, Date32Array, Date64Array, - Decimal128Array, DictionaryArray, FixedSizeBinaryArray, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, - StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, + Array, ArrayRef, BinaryArray, Date32Array, Date64Array, Decimal128Array, + DictionaryArray, FixedSizeBinaryArray, Float64Array, Int8Array, Int16Array, + Int32Array, 
Int64Array, LargeBinaryArray, LargeStringArray, StringArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, UInt8Array, UInt16Array, UInt32Array, UInt64Array, + make_array, }, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, @@ -32,7 +32,7 @@ use arrow::{ }; use chrono::{Datelike, Duration, TimeDelta}; use datafusion::{ - datasource::{provider_as_source, TableProvider}, + datasource::{TableProvider, provider_as_source}, physical_plan::metrics::MetricsSet, prelude::{ParquetReadOptions, SessionConfig, SessionContext}, }; @@ -46,13 +46,13 @@ use tempfile::NamedTempFile; mod custom_reader; #[cfg(feature = "parquet_encryption")] mod encryption; +mod expr_adapter; mod external_access_plan; mod file_statistics; mod filter_pushdown; mod page_pruning; mod row_group_pruning; mod schema; -mod schema_adapter; mod schema_coercion; mod utils; @@ -147,15 +147,14 @@ impl TestOutput { for metric in self.parquet_metrics.iter() { let metric = metric.as_ref(); - if metric.value().name() == metric_name { - if let MetricValue::PruningMetrics { + if metric.value().name() == metric_name + && let MetricValue::PruningMetrics { pruning_metrics, .. } = metric.value() - { - total_pruned += pruning_metrics.pruned(); - total_matched += pruning_metrics.matched(); - found = true; - } + { + total_pruned += pruning_metrics.pruned(); + total_matched += pruning_metrics.matched(); + found = true; } } @@ -652,6 +651,7 @@ fn make_date_batch(offset: Duration) -> RecordBatch { /// of the column. It is *not* a table named service.name /// /// name | service.name +#[expect(clippy::needless_pass_by_value)] fn make_bytearray_batch( name: &str, string_values: Vec<&str>, @@ -707,6 +707,7 @@ fn make_bytearray_batch( /// of the column. 
It is *not* a table named service.name /// /// name | service.name +#[expect(clippy::needless_pass_by_value)] fn make_names_batch(name: &str, service_name_values: Vec<&str>) -> RecordBatch { let num_rows = service_name_values.len(); let name: StringArray = std::iter::repeat_n(Some(name), num_rows).collect(); @@ -791,6 +792,7 @@ fn make_utf8_batch(value: Vec>) -> RecordBatch { .unwrap() } +#[expect(clippy::needless_pass_by_value)] fn make_dictionary_batch(strings: Vec<&str>, integers: Vec) -> RecordBatch { let keys = Int32Array::from_iter(0..strings.len() as i32); let small_keys = Int16Array::from_iter(0..strings.len() as i16); @@ -839,6 +841,7 @@ fn make_dictionary_batch(strings: Vec<&str>, integers: Vec) -> RecordBatch .unwrap() } +#[expect(clippy::needless_pass_by_value)] fn create_data_batch(scenario: Scenario) -> Vec { match scenario { Scenario::Timestamps => { diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index 27bee10234b57..17392974b63a8 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -21,25 +21,25 @@ use crate::parquet::Unit::Page; use crate::parquet::{ContextWithParquet, Scenario}; use arrow::array::RecordBatch; -use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::file_format::FileFormat; +use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::ParquetSource; use datafusion::datasource::source::DataSourceExec; use datafusion::execution::context::SessionState; -use datafusion::physical_plan::metrics::MetricValue; use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::metrics::MetricValue; use datafusion::prelude::SessionContext; use datafusion_common::{ScalarValue, ToDFSchema}; use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::{col, lit, Expr}; +use datafusion_expr::{Expr, col, lit}; use datafusion_physical_expr::create_physical_expr; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use futures::StreamExt; -use object_store::path::Path; use object_store::ObjectMeta; +use object_store::path::Path; async fn get_parquet_exec( state: &SessionState, @@ -81,12 +81,12 @@ async fn get_parquet_exec( let predicate = create_physical_expr(&filter, &df_schema, &execution_props).unwrap(); let source = Arc::new( - ParquetSource::default() + ParquetSource::new(schema.clone()) .with_predicate(predicate) .with_enable_page_index(true) .with_pushdown_filters(pushdown_filters), ); - let base_config = FileScanConfigBuilder::new(object_store_url, schema, source) + let base_config = FileScanConfigBuilder::new(object_store_url, source) .with_file(partitioned_file) .build(); diff --git a/datafusion/core/tests/parquet/schema_adapter.rs b/datafusion/core/tests/parquet/schema_adapter.rs deleted file mode 100644 index 40fc6176e212b..0000000000000 --- a/datafusion/core/tests/parquet/schema_adapter.rs +++ /dev/null @@ -1,553 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::sync::Arc; - -use arrow::array::{record_batch, RecordBatch, RecordBatchOptions}; -use arrow::compute::{cast_with_options, CastOptions}; -use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef}; -use bytes::{BufMut, BytesMut}; -use datafusion::assert_batches_eq; -use datafusion::common::Result; -use datafusion::datasource::listing::{ - ListingTable, ListingTableConfig, ListingTableConfigExt, -}; -use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::DataFusionError; -use datafusion_common::{ColumnStatistics, ScalarValue}; -use datafusion_datasource::file::FileSource; -use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use datafusion_datasource::schema_adapter::{ - DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, SchemaMapper, -}; -use datafusion_datasource::ListingTableUrl; -use datafusion_datasource_parquet::source::ParquetSource; -use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_physical_expr::expressions::{self, Column}; -use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_expr_adapter::{ - DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, - PhysicalExprAdapterFactory, -}; -use itertools::Itertools; -use object_store::{memory::InMemory, path::Path, ObjectStore}; -use parquet::arrow::ArrowWriter; - -async fn write_parquet(batch: RecordBatch, store: Arc, path: &str) { - let mut out = BytesMut::new().writer(); - { - let mut writer = ArrowWriter::try_new(&mut out, batch.schema(), None).unwrap(); - writer.write(&batch).unwrap(); - writer.finish().unwrap(); - } - let data = out.into_inner().freeze(); - store.put(&Path::from(path), data.into()).await.unwrap(); -} - -#[derive(Debug)] -struct CustomSchemaAdapterFactory; - -impl SchemaAdapterFactory for CustomSchemaAdapterFactory { - fn create( - &self, - projected_table_schema: SchemaRef, - _table_schema: SchemaRef, - ) -> Box { - Box::new(CustomSchemaAdapter { - logical_file_schema: projected_table_schema, - }) - } -} - -#[derive(Debug)] -struct CustomSchemaAdapter { - logical_file_schema: SchemaRef, -} - -impl SchemaAdapter for CustomSchemaAdapter { - fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { - for (idx, field) in file_schema.fields().iter().enumerate() { - if field.name() == self.logical_file_schema.field(index).name() { - return Some(idx); - } - } - None - } - - fn map_schema( - &self, - file_schema: &Schema, - ) -> Result<(Arc, Vec)> { - let projection = (0..file_schema.fields().len()).collect_vec(); - Ok(( - Arc::new(CustomSchemaMapper { - logical_file_schema: Arc::clone(&self.logical_file_schema), - }), - projection, - )) - } -} - -#[derive(Debug)] -struct CustomSchemaMapper { - logical_file_schema: SchemaRef, -} - -impl SchemaMapper for CustomSchemaMapper { - fn map_batch(&self, batch: RecordBatch) -> Result { - let mut output_columns = - Vec::with_capacity(self.logical_file_schema.fields().len()); - for field in self.logical_file_schema.fields() { - if let 
Some(array) = batch.column_by_name(field.name()) { - output_columns.push(cast_with_options( - array, - field.data_type(), - &CastOptions::default(), - )?); - } else { - // Create a new array with the default value for the field type - let default_value = match field.data_type() { - DataType::Int64 => ScalarValue::Int64(Some(0)), - DataType::Utf8 => ScalarValue::Utf8(Some("a".to_string())), - _ => unimplemented!("Unsupported data type: {}", field.data_type()), - }; - output_columns - .push(default_value.to_array_of_size(batch.num_rows()).unwrap()); - } - } - let batch = RecordBatch::try_new_with_options( - Arc::clone(&self.logical_file_schema), - output_columns, - &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())), - ) - .unwrap(); - Ok(batch) - } - - fn map_column_statistics( - &self, - _file_col_statistics: &[ColumnStatistics], - ) -> Result> { - Ok(vec![ - ColumnStatistics::new_unknown(); - self.logical_file_schema.fields().len() - ]) - } -} - -// Implement a custom PhysicalExprAdapterFactory that fills in missing columns with the default value for the field type -#[derive(Debug)] -struct CustomPhysicalExprAdapterFactory; - -impl PhysicalExprAdapterFactory for CustomPhysicalExprAdapterFactory { - fn create( - &self, - logical_file_schema: SchemaRef, - physical_file_schema: SchemaRef, - ) -> Arc { - Arc::new(CustomPhysicalExprAdapter { - logical_file_schema: Arc::clone(&logical_file_schema), - physical_file_schema: Arc::clone(&physical_file_schema), - inner: Arc::new(DefaultPhysicalExprAdapter::new( - logical_file_schema, - physical_file_schema, - )), - }) - } -} - -#[derive(Debug, Clone)] -struct CustomPhysicalExprAdapter { - logical_file_schema: SchemaRef, - physical_file_schema: SchemaRef, - inner: Arc, -} - -impl PhysicalExprAdapter for CustomPhysicalExprAdapter { - fn rewrite(&self, mut expr: Arc) -> Result> { - expr = expr - .transform(|expr| { - if let Some(column) = expr.as_any().downcast_ref::() { - let field_name = column.name(); - if self - .physical_file_schema - .field_with_name(field_name) - .ok() - .is_none() - { - let field = self - .logical_file_schema - .field_with_name(field_name) - .map_err(|_| { - DataFusionError::Plan(format!( - "Field '{field_name}' not found in logical file schema", - )) - })?; - // If the field does not exist, create a default value expression - // Note that we use slightly different logic here to create a default value so that we can see different behavior in tests - let default_value = match field.data_type() { - DataType::Int64 => ScalarValue::Int64(Some(1)), - DataType::Utf8 => ScalarValue::Utf8(Some("b".to_string())), - _ => unimplemented!( - "Unsupported data type: {}", - field.data_type() - ), - }; - return Ok(Transformed::yes(Arc::new( - expressions::Literal::new(default_value), - ))); - } - } - - Ok(Transformed::no(expr)) - }) - .data()?; - self.inner.rewrite(expr) - } - - fn with_partition_values( - &self, - partition_values: Vec<(FieldRef, ScalarValue)>, - ) -> Arc { - assert!( - partition_values.is_empty(), - "Partition values are not supported in this test" - ); - Arc::new(self.clone()) - } -} - -#[tokio::test] -async fn test_custom_schema_adapter_and_custom_expression_adapter() { - let batch = - record_batch!(("extra", Int64, [1, 2, 3]), ("c1", Int32, [1, 2, 3])).unwrap(); - - let store = Arc::new(InMemory::new()) as Arc; - let store_url = ObjectStoreUrl::parse("memory://").unwrap(); - let path = "test.parquet"; - write_parquet(batch, store.clone(), path).await; - - let table_schema = Arc::new(Schema::new(vec![ - 
Field::new("c1", DataType::Int64, false), - Field::new("c2", DataType::Utf8, true), - ])); - - let mut cfg = SessionConfig::new() - // Disable statistics collection for this test otherwise early pruning makes it hard to demonstrate data adaptation - .with_collect_statistics(false) - .with_parquet_pruning(false) - .with_parquet_page_index_pruning(false); - cfg.options_mut().execution.parquet.pushdown_filters = true; - let ctx = SessionContext::new_with_config(cfg); - ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); - assert!( - !ctx.state() - .config_mut() - .options_mut() - .execution - .collect_statistics - ); - assert!(!ctx.state().config().collect_statistics()); - - let listing_table_config = - ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) - .infer_options(&ctx.state()) - .await - .unwrap() - .with_schema(table_schema.clone()) - .with_schema_adapter_factory(Arc::new(DefaultSchemaAdapterFactory)) - .with_expr_adapter_factory(Arc::new(DefaultPhysicalExprAdapterFactory)); - - let table = ListingTable::try_new(listing_table_config).unwrap(); - ctx.register_table("t", Arc::new(table)).unwrap(); - - let batches = ctx - .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 IS NULL") - .await - .unwrap() - .collect() - .await - .unwrap(); - - let expected = [ - "+----+----+", - "| c2 | c1 |", - "+----+----+", - "| | 2 |", - "+----+----+", - ]; - assert_batches_eq!(expected, &batches); - - // Test using a custom schema adapter and no explicit physical expr adapter - // This should use the custom schema adapter both for projections and predicate pushdown - let listing_table_config = - ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) - .infer_options(&ctx.state()) - .await - .unwrap() - .with_schema(table_schema.clone()) - .with_schema_adapter_factory(Arc::new(CustomSchemaAdapterFactory)); - let table = ListingTable::try_new(listing_table_config).unwrap(); - ctx.deregister_table("t").unwrap(); - ctx.register_table("t", Arc::new(table)).unwrap(); - let batches = ctx - .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'a'") - .await - .unwrap() - .collect() - .await - .unwrap(); - let expected = [ - "+----+----+", - "| c2 | c1 |", - "+----+----+", - "| a | 2 |", - "+----+----+", - ]; - assert_batches_eq!(expected, &batches); - - // Do the same test but with a custom physical expr adapter - // Now the default schema adapter will be used for projections, but the custom physical expr adapter will be used for predicate pushdown - let listing_table_config = - ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) - .infer_options(&ctx.state()) - .await - .unwrap() - .with_schema(table_schema.clone()) - .with_expr_adapter_factory(Arc::new(CustomPhysicalExprAdapterFactory)); - let table = ListingTable::try_new(listing_table_config).unwrap(); - ctx.deregister_table("t").unwrap(); - ctx.register_table("t", Arc::new(table)).unwrap(); - let batches = ctx - .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'b'") - .await - .unwrap() - .collect() - .await - .unwrap(); - let expected = [ - "+----+----+", - "| c2 | c1 |", - "+----+----+", - "| | 2 |", - "+----+----+", - ]; - assert_batches_eq!(expected, &batches); - - // If we use both then the custom physical expr adapter will be used for predicate pushdown and the custom schema adapter will be used for projections - let listing_table_config = - ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) - .infer_options(&ctx.state()) - .await - .unwrap() - 
.with_schema(table_schema.clone()) - .with_schema_adapter_factory(Arc::new(CustomSchemaAdapterFactory)) - .with_expr_adapter_factory(Arc::new(CustomPhysicalExprAdapterFactory)); - let table = ListingTable::try_new(listing_table_config).unwrap(); - ctx.deregister_table("t").unwrap(); - ctx.register_table("t", Arc::new(table)).unwrap(); - let batches = ctx - .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'b'") - .await - .unwrap() - .collect() - .await - .unwrap(); - let expected = [ - "+----+----+", - "| c2 | c1 |", - "+----+----+", - "| a | 2 |", - "+----+----+", - ]; - assert_batches_eq!(expected, &batches); -} - -/// A test schema adapter factory that adds prefix to column names -#[derive(Debug)] -struct PrefixAdapterFactory { - prefix: String, -} - -impl SchemaAdapterFactory for PrefixAdapterFactory { - fn create( - &self, - projected_table_schema: SchemaRef, - _table_schema: SchemaRef, - ) -> Box { - Box::new(PrefixAdapter { - input_schema: projected_table_schema, - prefix: self.prefix.clone(), - }) - } -} - -/// A test schema adapter that adds prefix to column names -#[derive(Debug)] -struct PrefixAdapter { - input_schema: SchemaRef, - prefix: String, -} - -impl SchemaAdapter for PrefixAdapter { - fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { - let field = self.input_schema.field(index); - file_schema.fields.find(field.name()).map(|(i, _)| i) - } - - fn map_schema( - &self, - file_schema: &Schema, - ) -> Result<(Arc, Vec)> { - let mut projection = Vec::with_capacity(file_schema.fields().len()); - for (file_idx, file_field) in file_schema.fields().iter().enumerate() { - if self.input_schema.fields().find(file_field.name()).is_some() { - projection.push(file_idx); - } - } - - // Create a schema mapper that adds a prefix to column names - #[derive(Debug)] - struct PrefixSchemaMapping { - // Keep only the prefix field which is actually used in the implementation - prefix: String, - } - - impl SchemaMapper for PrefixSchemaMapping { - fn map_batch(&self, batch: RecordBatch) -> Result { - // Create a new schema with prefixed field names - let prefixed_fields: Vec = batch - .schema() - .fields() - .iter() - .map(|field| { - Field::new( - format!("{}{}", self.prefix, field.name()), - field.data_type().clone(), - field.is_nullable(), - ) - }) - .collect(); - let prefixed_schema = Arc::new(Schema::new(prefixed_fields)); - - // Create a new batch with the prefixed schema but the same data - let options = RecordBatchOptions::default(); - RecordBatch::try_new_with_options( - prefixed_schema, - batch.columns().to_vec(), - &options, - ) - .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) - } - - fn map_column_statistics( - &self, - stats: &[ColumnStatistics], - ) -> Result> { - // For testing, just return the input statistics - Ok(stats.to_vec()) - } - } - - Ok(( - Arc::new(PrefixSchemaMapping { - prefix: self.prefix.clone(), - }), - projection, - )) - } -} - -#[test] -fn test_apply_schema_adapter_with_factory() { - // Create a schema - let schema = Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ])); - - // Create a parquet source - let source = ParquetSource::default(); - - // Create a file scan config with source that has a schema adapter factory - let factory = Arc::new(PrefixAdapterFactory { - prefix: "test_".to_string(), - }); - - let file_source = source.clone().with_schema_adapter_factory(factory).unwrap(); - - let config = FileScanConfigBuilder::new( - 
ObjectStoreUrl::local_filesystem(), - schema.clone(), - file_source, - ) - .build(); - - // Apply schema adapter to a new source - let result_source = source.apply_schema_adapter(&config).unwrap(); - - // Verify the adapter was applied - assert!(result_source.schema_adapter_factory().is_some()); - - // Create adapter and test it produces expected schema - let adapter_factory = result_source.schema_adapter_factory().unwrap(); - let adapter = adapter_factory.create(schema.clone(), schema.clone()); - - // Create a dummy batch to test the schema mapping - let dummy_batch = RecordBatch::new_empty(schema.clone()); - - // Get the file schema (which is the same as the table schema in this test) - let (mapper, _) = adapter.map_schema(&schema).unwrap(); - - // Apply the mapping to get the output schema - let mapped_batch = mapper.map_batch(dummy_batch).unwrap(); - let output_schema = mapped_batch.schema(); - - // Check the column names have the prefix - assert_eq!(output_schema.field(0).name(), "test_id"); - assert_eq!(output_schema.field(1).name(), "test_name"); -} - -#[test] -fn test_apply_schema_adapter_without_factory() { - // Create a schema - let schema = Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ])); - - // Create a parquet source - let source = ParquetSource::default(); - - // Convert to Arc - let file_source: Arc = Arc::new(source.clone()); - - // Create a file scan config without a schema adapter factory - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::local_filesystem(), - schema.clone(), - file_source, - ) - .build(); - - // Apply schema adapter function - should pass through the source unchanged - let result_source = source.apply_schema_adapter(&config).unwrap(); - - // Verify no adapter was applied - assert!(result_source.schema_adapter_factory().is_none()); -} diff --git a/datafusion/core/tests/parquet/schema_coercion.rs b/datafusion/core/tests/parquet/schema_coercion.rs index 9be391a9108e6..6f7e2e328d0c3 100644 --- a/datafusion/core/tests/parquet/schema_coercion.rs +++ b/datafusion/core/tests/parquet/schema_coercion.rs @@ -18,16 +18,16 @@ use std::sync::Arc; use arrow::array::{ - types::Int32Type, ArrayRef, DictionaryArray, Float32Array, Int64Array, RecordBatch, - StringArray, + ArrayRef, DictionaryArray, Float32Array, Int64Array, RecordBatch, StringArray, + types::Int32Type, }; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::datasource::physical_plan::ParquetSource; use datafusion::physical_plan::collect; use datafusion::prelude::SessionContext; use datafusion::test::object_store::local_unpartitioned_file; -use datafusion_common::test_util::batches_to_sort_string; use datafusion_common::Result; +use datafusion_common::test_util::batches_to_sort_string; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; @@ -62,14 +62,10 @@ async fn multi_parquet_coercion() { Field::new("c2", DataType::Int32, true), Field::new("c3", DataType::Float64, true), ])); - let source = Arc::new(ParquetSource::default()); - let conf = FileScanConfigBuilder::new( - ObjectStoreUrl::local_filesystem(), - file_schema, - source, - ) - .with_file_group(file_group) - .build(); + let source = Arc::new(ParquetSource::new(file_schema.clone())); + let conf = FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_file_group(file_group) + .build(); let parquet_exec = DataSourceExec::from_data_source(conf); @@ -122,11 +118,11 
@@ async fn multi_parquet_coercion_projection() { ])); let config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), - file_schema, - Arc::new(ParquetSource::default()), + Arc::new(ParquetSource::new(file_schema)), ) .with_file_group(file_group) .with_projection_indices(Some(vec![1, 0, 2])) + .unwrap() .build(); let parquet_exec = DataSourceExec::from_data_source(config); diff --git a/datafusion/core/tests/parquet/utils.rs b/datafusion/core/tests/parquet/utils.rs index 24b6cadc148f8..e5e0026ec1f16 100644 --- a/datafusion/core/tests/parquet/utils.rs +++ b/datafusion/core/tests/parquet/utils.rs @@ -20,7 +20,7 @@ use datafusion::datasource::physical_plan::ParquetSource; use datafusion::datasource::source::DataSourceExec; use datafusion_physical_plan::metrics::MetricsSet; -use datafusion_physical_plan::{accept, ExecutionPlan, ExecutionPlanVisitor}; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanVisitor, accept}; /// Find the metrics from the first DataSourceExec encountered in the plan #[derive(Debug)] @@ -47,13 +47,12 @@ impl MetricsFinder { impl ExecutionPlanVisitor for MetricsFinder { type Error = std::convert::Infallible; fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result { - if let Some(data_source_exec) = plan.as_any().downcast_ref::() { - if data_source_exec + if let Some(data_source_exec) = plan.as_any().downcast_ref::() + && data_source_exec .downcast_to_file_source::() .is_some() - { - self.metrics = data_source_exec.metrics(); - } + { + self.metrics = data_source_exec.metrics(); } // stop searching once we have found the metrics Ok(self.metrics.is_none()) diff --git a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs index a79d743cb253d..1fdc0ae6c7f60 100644 --- a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs @@ -24,14 +24,15 @@ use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; +use datafusion_common::Result; use datafusion_common::cast::as_int64_array; use datafusion_common::config::ConfigOptions; -use datafusion_common::Result; use datafusion_execution::TaskContext; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{self, cast}; -use datafusion_physical_optimizer::aggregate_statistics::AggregateStatistics; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::aggregate_statistics::AggregateStatistics; +use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::aggregates::AggregateExec; use datafusion_physical_plan::aggregates::AggregateMode; use datafusion_physical_plan::aggregates::PhysicalGroupBy; @@ -39,7 +40,6 @@ use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::common; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::projection::ProjectionExec; -use datafusion_physical_plan::ExecutionPlan; /// Mock data using a MemorySourceConfig which has an exact count statistic fn mock_data() -> Result> { diff --git a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs index 9c76f6ab6f58b..2fdfece2a86e7 100644 --- a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs +++ 
b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs @@ -29,18 +29,18 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; +use datafusion_physical_expr::Partitioning; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::{col, lit}; -use datafusion_physical_expr::Partitioning; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate; +use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; use datafusion_physical_plan::displayable; use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::ExecutionPlan; /// Runs the CombinePartialFinalAggregate optimizer and asserts the plan against the expected macro_rules! assert_optimized { @@ -191,7 +191,7 @@ fn aggregations_combined() -> datafusion_common::Result<()> { // should combine the Partial/Final AggregateExecs to the Single AggregateExec assert_optimized!( plan, - @ " + @ r" AggregateExec: mode=Single, gby=[], aggr=[COUNT(1)] DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet " diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index 5b7d9ac8fbe99..7cedaf86cb52f 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -26,32 +26,33 @@ use crate::physical_optimizer::test_utils::{ sort_preserving_merge_exec, union_exec, }; -use arrow::array::{RecordBatch, UInt64Array, UInt8Array}; +use arrow::array::{RecordBatch, UInt8Array, UInt64Array}; use arrow::compute::SortOptions; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion::config::ConfigOptions; +use datafusion::datasource::MemTable; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{CsvSource, ParquetSource}; use datafusion::datasource::source::DataSourceExec; -use datafusion::datasource::MemTable; use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_common::ScalarValue; +use datafusion_common::config::CsvOptions; use datafusion_common::error::Result; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::ScalarValue; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_expr::{JoinType, Operator}; -use datafusion_physical_expr::expressions::{binary, lit, BinaryExpr, Column, Literal}; +use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal, binary, lit}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ LexOrdering, OrderingRequirements, PhysicalSortExpr, }; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use 
datafusion_physical_optimizer::enforce_distribution::*; use datafusion_physical_optimizer::enforce_sorting::EnforceSorting; use datafusion_physical_optimizer::output_requirements::OutputRequirements; -use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; @@ -66,8 +67,8 @@ use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::{ - displayable, DisplayAs, DisplayFormatType, ExecutionPlanProperties, PlanProperties, - Statistics, + DisplayAs, DisplayFormatType, ExecutionPlanProperties, PlanProperties, Statistics, + displayable, }; use insta::Settings; @@ -229,8 +230,7 @@ fn parquet_exec_multiple_sorted( ) -> Arc { let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema(), - Arc::new(ParquetSource::default()), + Arc::new(ParquetSource::new(schema())), ) .with_file_groups(vec![ FileGroup::new(vec![PartitionedFile::new("x".to_string(), 100)]), @@ -247,14 +247,19 @@ fn csv_exec() -> Arc { } fn csv_exec_with_sort(output_ordering: Vec) -> Arc { - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema(), - Arc::new(CsvSource::new(false, b',', b'"')), - ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_output_ordering(output_ordering) - .build(); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test:///").unwrap(), { + let options = CsvOptions { + has_header: Some(false), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + Arc::new(CsvSource::new(schema()).with_csv_options(options)) + }) + .with_file(PartitionedFile::new("x".to_string(), 100)) + .with_output_ordering(output_ordering) + .build(); DataSourceExec::from_data_source(config) } @@ -265,17 +270,22 @@ fn csv_exec_multiple() -> Arc { // Created a sorted parquet exec with multiple files fn csv_exec_multiple_sorted(output_ordering: Vec) -> Arc { - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema(), - Arc::new(CsvSource::new(false, b',', b'"')), - ) - .with_file_groups(vec![ - FileGroup::new(vec![PartitionedFile::new("x".to_string(), 100)]), - FileGroup::new(vec![PartitionedFile::new("y".to_string(), 100)]), - ]) - .with_output_ordering(output_ordering) - .build(); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test:///").unwrap(), { + let options = CsvOptions { + has_header: Some(false), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + Arc::new(CsvSource::new(schema()).with_csv_options(options)) + }) + .with_file_groups(vec![ + FileGroup::new(vec![PartitionedFile::new("x".to_string(), 100)]), + FileGroup::new(vec![PartitionedFile::new("y".to_string(), 100)]), + ]) + .with_output_ordering(output_ordering) + .build(); DataSourceExec::from_data_source(config) } @@ -618,16 +628,13 @@ fn multi_hash_joins() -> Result<()> { assert_plan!(plan_distrib, @r" HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, c@2)] HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, b1@1)] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b1@1], 10), 
input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet "); }, // Should include 4 RepartitionExecs @@ -636,16 +643,13 @@ fn multi_hash_joins() -> Result<()> { HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, c@2)] RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, b1@1)] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet "); }, }; @@ -690,16 +694,13 @@ fn multi_hash_joins() -> Result<()> { assert_plan!(plan_distrib, @r" HashJoinExec: mode=Partitioned, join_type=..., on=[(b1@1, c@2)] HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, b1@1)] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 DataSourceExec: file_groups={1 group: 
[[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet "); } @@ -710,16 +711,13 @@ fn multi_hash_joins() -> Result<()> { HashJoinExec: mode=Partitioned, join_type=..., on=[(b1@6, c@2)] RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10 HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, b1@1)] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet "); }, @@ -780,15 +778,12 @@ fn multi_joins_after_alias() -> Result<()> { HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a1@0, c@2)] ProjectionExec: expr=[a@0 as a1, a@0 as a2] HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@1)] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet " ); let plan_sort = test_config.to_plan(top_join, &SORT_DISTRIB_DISTRIB); @@ -811,15 +806,12 @@ fn multi_joins_after_alias() -> Result<()> 
{ HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a2@1, c@2)] ProjectionExec: expr=[a@0 as a1, a@0 as a2] HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@1)] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet " ); let plan_sort = test_config.to_plan(top_join, &SORT_DISTRIB_DISTRIB); @@ -869,15 +861,12 @@ fn multi_joins_after_multi_alias() -> Result<()> { ProjectionExec: expr=[c1@0 as a] ProjectionExec: expr=[c@2 as c1] HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@1)] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet " ); let plan_sort = test_config.to_plan(top_join, &SORT_DISTRIB_DISTRIB); @@ -1098,21 +1087,17 @@ fn multi_hash_join_key_ordering() -> Result<()> { HashJoinExec: mode=Partitioned, join_type=Inner, on=[(B@2, b1@6), (C@3, c@2), (AA@1, a1@5)] ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C] HashJoinExec: mode=Partitioned, join_type=Inner, on=[(b@1, b1@1), (c@2, c1@2), (a@0, a1@0)] - RepartitionExec: partitioning=Hash([b@1, c@2, a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b1@1, c1@2, a1@0], 10), 
input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(b@1, b1@1), (c@2, c1@2), (a@0, a1@0)] - RepartitionExec: partitioning=Hash([b@1, c@2, a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=Hash([b@1, c@2, a@0], 10), input_partitions=1 DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b1@1, c1@2, a1@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=Hash([b1@1, c1@2, a1@0], 10), input_partitions=1 ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(b@1, b1@1), (c@2, c1@2), (a@0, a1@0)] + RepartitionExec: partitioning=Hash([b@1, c@2, a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1, c1@2, a1@0], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet " ); let plan_sort = test_config.to_plan(filter_top_join, &SORT_DISTRIB_DISTRIB); @@ -1236,25 +1221,21 @@ fn reorder_join_keys_to_left_input() -> Result<()> { assert_eq!(captured_join_type, join_type.to_string()); insta::allow_duplicates! {insta::assert_snapshot!(modified_plan, @r" -HashJoinExec: mode=Partitioned, join_type=..., on=[(AA@1, a1@5), (B@2, b1@6), (C@3, c@2)] - ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C] - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a1@0), (b@1, b1@1), (c@2, c1@2)] - RepartitionExec: partitioning=Hash([a@0, b@1, c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([a1@0, b1@1, c1@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c1@2), (b@1, b1@1), (a@0, a1@0)] - RepartitionExec: partitioning=Hash([c@2, b@1, a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c1@2, b1@1, a1@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -");} + HashJoinExec: mode=Partitioned, join_type=..., on=[(AA@1, a1@5), (B@2, b1@6), (C@3, c@2)] + ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a1@0), (b@1, b1@1), (c@2, c1@2)] + RepartitionExec: 
partitioning=Hash([a@0, b@1, c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([a1@0, b1@1, c1@2], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c1@2), (b@1, b1@1), (a@0, a1@0)] + RepartitionExec: partitioning=Hash([c@2, b@1, a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c1@2, b1@1, a1@0], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + ");} } Ok(()) @@ -1368,25 +1349,21 @@ fn reorder_join_keys_to_right_input() -> Result<()> { let (_, plan_str) = hide_first(reordered.as_ref(), r"join_type=(\w+)", "join_type=..."); insta::allow_duplicates! {insta::assert_snapshot!(plan_str, @r" -HashJoinExec: mode=Partitioned, join_type=..., on=[(C@3, c@2), (B@2, b1@6), (AA@1, a1@5)] - ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C] - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a1@0), (b@1, b1@1)] - RepartitionExec: partitioning=Hash([a@0, b@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([a1@0, b1@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c1@2), (b@1, b1@1), (a@0, a1@0)] - RepartitionExec: partitioning=Hash([c@2, b@1, a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c1@2, b1@1, a1@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -");} + HashJoinExec: mode=Partitioned, join_type=..., on=[(C@3, c@2), (B@2, b1@6), (AA@1, a1@5)] + ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a1@0), (b@1, b1@1)] + RepartitionExec: partitioning=Hash([a@0, b@1], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([a1@0, b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c1@2), (b@1, b1@1), (a@0, a1@0)] + RepartitionExec: partitioning=Hash([c@2, b@1, a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c1@2, b1@1, a1@0], 10), input_partitions=1 + ProjectionExec: 
expr=[a@0 as a1, b@1 as b1, c@2 as c1]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ ");}
}
Ok(())
@@ -1447,52 +1424,46 @@ fn multi_smj_joins() -> Result<()> {
// Should include 6 RepartitionExecs (3 hash, 3 round-robin), 3 SortExecs
JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => {
assert_plan!(plan_distrib, @r"
-SortMergeJoin: join_type=..., on=[(a@0, c@2)]
- SortMergeJoin: join_type=..., on=[(a@0, b1@1)]
- SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
- RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
- SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]
- RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
- SortExec: expr=[c@2 ASC], preserve_partitioning=[true]
- RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
-");
+ SortMergeJoinExec: join_type=..., on=[(a@0, c@2)]
+ SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)]
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1
+ ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ SortExec: expr=[c@2 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ ");
}
// Should include 7 RepartitionExecs (4 hash, 3 round-robin), 4 SortExecs
- // Since ordering of the left child is not preserved after SortMergeJoin
+ // Since ordering of the left child is not preserved after SortMergeJoinExec
// when mode is Right, RightSemi, RightAnti, Full
- // - We need to add one additional SortExec after SortMergeJoin in contrast the test cases
+ // - We need to add one additional SortExec after SortMergeJoinExec in contrast to the test cases
// when mode is Inner, Left, LeftSemi, LeftAnti
// Similarly, since partitioning of the left side is not preserved
// when mode is Right, RightSemi, RightAnti, Full
- // - We need to add one additional Hash Repartition after SortMergeJoin in contrast the test
+ // - We need to add one additional Hash Repartition after SortMergeJoinExec in contrast to the test
// cases when mode is Inner, Left, LeftSemi, LeftAnti
_ => {
assert_plan!(plan_distrib, @r"
-SortMergeJoin: join_type=..., on=[(a@0, c@2)]
- SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
- RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10
- SortMergeJoin: join_type=..., on=[(a@0, b1@1)]
- SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
- RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
- SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]
- RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
- SortExec: expr=[c@2 ASC], preserve_partitioning=[true]
- RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
-");
+ SortMergeJoinExec: join_type=..., on=[(a@0, c@2)]
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10
+ SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)]
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1
+ ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ SortExec: expr=[c@2 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ ");
}
}
@@ -1503,55 +1474,48 @@ SortMergeJoin: join_type=..., on=[(a@0, c@2)]
JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => {
// TODO(wiedld): show different test result if enforce distribution first.
assert_plan!(plan_sort, @r"
-SortMergeJoin: join_type=..., on=[(a@0, c@2)]
- SortMergeJoin: join_type=..., on=[(a@0, b1@1)]
- RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
- RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]
- ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
- RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- SortExec: expr=[c@2 ASC], preserve_partitioning=[false]
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
-");
+ SortMergeJoinExec: join_type=..., on=[(a@0, c@2)]
+ SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]
+ ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[c@2 ASC], preserve_partitioning=[false]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ ");
}
// Should include 8 RepartitionExecs (4 hash, 8 round-robin), 4 SortExecs
- // Since ordering of the left child is not preserved after SortMergeJoin
+ // Since ordering of the left child is not preserved after SortMergeJoinExec
// when mode is Right, RightSemi, RightAnti, Full
- // - We need to add one additional SortExec after SortMergeJoin in contrast the test cases
+ // - We need to add one additional SortExec after SortMergeJoinExec in contrast to the test cases
// when mode is Inner, Left, LeftSemi, LeftAnti
// Similarly, since partitioning of the left side is not preserved
// when mode is Right, RightSemi, RightAnti, Full
// - We need to add one additional Hash Repartition and Roundrobin repartition after
- // SortMergeJoin in contrast the test cases when mode is Inner, Left, LeftSemi, LeftAnti
+ // SortMergeJoinExec in contrast to the test cases when mode is Inner, Left, LeftSemi, LeftAnti
_ => {
// TODO(wiedld): show different test result if enforce distribution first.
assert_plan!(plan_sort, @r"
-SortMergeJoin: join_type=..., on=[(a@0, c@2)]
- RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
- CoalescePartitionsExec
- SortMergeJoin: join_type=..., on=[(a@0, b1@1)]
- RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
- RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]
- ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
- RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- SortExec: expr=[c@2 ASC], preserve_partitioning=[false]
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
-");
+ SortMergeJoinExec: join_type=..., on=[(a@0, c@2)]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+ CoalescePartitionsExec
+ SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[false]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]
+ ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1, maintains_sort_order=true
+ SortExec: expr=[c@2 ASC], preserve_partitioning=[false]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ ");
}
}
@@ -1572,45 +1536,39 @@ SortMergeJoin: join_type=..., on=[(a@0, c@2)]
JoinType::Inner | JoinType::Right => {
// TODO(wiedld): show different test result if enforce sorting first.
assert_plan!(plan_distrib, @r"
-SortMergeJoin: join_type=..., on=[(b1@6, c@2)]
- SortMergeJoin: join_type=..., on=[(a@0, b1@1)]
- SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
- RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
- SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]
- RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
- SortExec: expr=[c@2 ASC], preserve_partitioning=[true]
- RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10
- RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
-");
+ SortMergeJoinExec: join_type=..., on=[(b1@6, c@2)]
+ SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)]
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1
+ ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ SortExec: expr=[c@2 ASC], preserve_partitioning=[true]
+ RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet
+ ");
}
// Should include 7 RepartitionExecs (4 hash, 3 round-robin) and 4 SortExecs
JoinType::Left | JoinType::Full => {
// TODO(wiedld): show different test result if enforce sorting first.
assert_plan!(plan_distrib, @r" -SortMergeJoin: join_type=..., on=[(b1@6, c@2)] - SortExec: expr=[b1@6 ASC], preserve_partitioning=[true] - RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10 - SortMergeJoin: join_type=..., on=[(a@0, b1@1)] - SortExec: expr=[a@0 ASC], preserve_partitioning=[true] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - SortExec: expr=[b1@1 ASC], preserve_partitioning=[true] - RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - SortExec: expr=[c@2 ASC], preserve_partitioning=[true] - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortMergeJoinExec: join_type=..., on=[(b1@6, c@2)] + SortExec: expr=[b1@6 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10 + SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + SortExec: expr=[b1@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + SortExec: expr=[c@2 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); } // this match arm cannot be reached _ => unreachable!() @@ -1623,47 +1581,40 @@ SortMergeJoin: join_type=..., on=[(b1@6, c@2)] JoinType::Inner | JoinType::Right => { // TODO(wiedld): show different test result if enforce distribution first. 
assert_plan!(plan_sort, @r" -SortMergeJoin: join_type=..., on=[(b1@6, c@2)] - SortMergeJoin: join_type=..., on=[(a@0, b1@1)] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - SortExec: expr=[b1@1 ASC], preserve_partitioning=[false] - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortMergeJoinExec: join_type=..., on=[(b1@6, c@2)] + SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[b1@1 ASC], preserve_partitioning=[false] + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); } // Should include 8 RepartitionExecs (4 of them preserves order) and 4 SortExecs JoinType::Left | JoinType::Full => { // TODO(wiedld): show different test result if enforce distribution first. 
assert_plan!(plan_sort, @r" -SortMergeJoin: join_type=..., on=[(b1@6, c@2)] - RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@6 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - SortExec: expr=[b1@6 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - SortMergeJoin: join_type=..., on=[(a@0, b1@1)] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - SortExec: expr=[b1@1 ASC], preserve_partitioning=[false] - ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortMergeJoinExec: join_type=..., on=[(b1@6, c@2)] + RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[b1@6 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[b1@1 ASC], preserve_partitioning=[false] + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); } // this match arm cannot be reached _ => unreachable!() @@ -1731,50 +1682,48 @@ fn smj_join_key_ordering() -> Result<()> { // Only two RepartitionExecs added let plan_distrib = test_config.to_plan(join.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortMergeJoin: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)] - SortExec: expr=[b3@1 ASC, a3@0 ASC], preserve_partitioning=[true] - ProjectionExec: expr=[a1@0 as a3, b1@1 as b3] - ProjectionExec: expr=[a1@1 as a1, b1@0 as b1] - AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[] - RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10 - AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - SortExec: expr=[b2@1 ASC, a2@0 ASC], preserve_partitioning=[true] - 
ProjectionExec: expr=[a@1 as a2, b@0 as b2] - AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[] - RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10 - AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortMergeJoinExec: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)] + SortExec: expr=[b3@1 ASC, a3@0 ASC], preserve_partitioning=[true] + ProjectionExec: expr=[a1@0 as a3, b1@1 as b3] + ProjectionExec: expr=[a1@1 as a1, b1@0 as b1] + AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[] + RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + SortExec: expr=[b2@1 ASC, a2@0 ASC], preserve_partitioning=[true] + ProjectionExec: expr=[a@1 as a2, b@0 as b2] + AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[] + RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Test: result IS DIFFERENT, if EnforceSorting is run first: let plan_sort = test_config.to_plan(join, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortMergeJoin: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)] - RepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b3@1 ASC, a3@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - SortExec: expr=[b3@1 ASC, a3@0 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - ProjectionExec: expr=[a1@0 as a3, b1@1 as b3] - ProjectionExec: expr=[a1@1 as a1, b1@0 as b1] - AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[] - RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10 - AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[] + SortMergeJoinExec: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)] + RepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[b3@1 ASC, a3@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + ProjectionExec: expr=[a1@0 as a3, b1@1 as b3] + ProjectionExec: expr=[a1@1 as a1, b1@0 as b1] + AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[] + RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[b2@1 ASC, a2@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + ProjectionExec: expr=[a@1 as a2, b@0 as b2] + AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[] + RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10 + AggregateExec: mode=Partial, 
gby=[b@1 as b, a@0 as a], aggr=[] RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - RepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b2@1 ASC, a2@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - SortExec: expr=[b2@1 ASC, a2@0 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - ProjectionExec: expr=[a@1 as a2, b@0 as b2] - AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[] - RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10 - AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + "); Ok(()) } @@ -1807,10 +1756,10 @@ fn merge_does_not_need_sort() -> Result<()> { let plan_distrib = test_config.to_plan(exec.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [a@0 ASC] - CoalesceBatchesExec: target_batch_size=4096 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet -"); + SortPreservingMergeExec: [a@0 ASC] + CoalesceBatchesExec: target_batch_size=4096 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); // Test: result IS DIFFERENT, if EnforceSorting is run first: // @@ -1821,11 +1770,11 @@ SortPreservingMergeExec: [a@0 ASC] let plan_sort = test_config.to_plan(exec, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortExec: expr=[a@0 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - CoalesceBatchesExec: target_batch_size=4096 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet -"); + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + CoalesceBatchesExec: target_batch_size=4096 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); Ok(()) } @@ -2002,11 +1951,11 @@ fn repartition_sorted_limit() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -GlobalLimitExec: skip=0, fetch=100 - LocalLimitExec: fetch=100 - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // data is sorted so can't repartition here let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -2031,12 +1980,12 @@ fn repartition_sorted_limit_with_filter() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortRequiredExec: [c@2 ASC] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortRequiredExec: 
[c@2 ASC] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // We can use repartition here, ordering requirement by SortRequiredExec // is still satisfied. let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); @@ -2057,19 +2006,19 @@ fn repartition_ignores_limit() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - GlobalLimitExec: skip=0, fetch=100 - CoalescePartitionsExec - LocalLimitExec: fetch=100 - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - GlobalLimitExec: skip=0, fetch=100 - LocalLimitExec: fetch=100 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + GlobalLimitExec: skip=0, fetch=100 + CoalescePartitionsExec + LocalLimitExec: fetch=100 + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // repartition should happen prior to the filter to maximize parallelism // Expect no repartition to happen for local limit (DataSourceExec) @@ -2087,13 +2036,13 @@ fn repartition_ignores_union() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -UnionExec - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Expect no repartition of DataSourceExec let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -2116,9 +2065,9 @@ fn repartition_through_sort_preserving_merge() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - 
DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -2144,9 +2093,9 @@ fn repartition_ignores_sort_preserving_merge() -> Result<()> { // Test: run EnforceDistribution, then EnforceSort assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [c@2 ASC] - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + SortPreservingMergeExec: [c@2 ASC] + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); // should not sort (as the data was already sorted) // should not repartition, since increased parallelism is not beneficial for SortPReservingMerge @@ -2154,10 +2103,10 @@ SortPreservingMergeExec: [c@2 ASC] let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); Ok(()) } @@ -2182,11 +2131,11 @@ fn repartition_ignores_sort_preserving_merge_with_union() -> Result<()> { // Test: run EnforceDistribution, then EnforceSort. assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [c@2 ASC] - UnionExec - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + SortPreservingMergeExec: [c@2 ASC] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); // // should not repartition / sort (as the data was already sorted) @@ -2194,12 +2143,12 @@ SortPreservingMergeExec: [c@2 ASC] let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - UnionExec - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); Ok(()) } @@ -2226,11 +2175,11 @@ fn repartition_does_not_destroy_sort() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortRequiredExec: [d@3 ASC] - FilterExec: c@2 = 0 - RepartitionExec: 
partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[d@3 ASC], file_type=parquet -"); + SortRequiredExec: [d@3 ASC] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[d@3 ASC], file_type=parquet + "); // during repartitioning ordering is preserved let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -2266,13 +2215,13 @@ fn repartition_does_not_destroy_sort_more_complex() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -UnionExec - SortRequiredExec: [c@2 ASC] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + UnionExec + SortRequiredExec: [c@2 ASC] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // union input 1: no repartitioning // union input 2: should repartition // @@ -2309,23 +2258,23 @@ fn repartition_transitively_with_projection() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [sum@0 ASC] - SortExec: expr=[sum@0 ASC], preserve_partitioning=[true] - ProjectionExec: expr=[a@0 + b@1 as sum] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortPreservingMergeExec: [sum@0 ASC] + SortExec: expr=[sum@0 ASC], preserve_partitioning=[true] + ProjectionExec: expr=[a@0 + b@1 as sum] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Test: result IS DIFFERENT, if EnforceSorting is run first: let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortExec: expr=[sum@0 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - ProjectionExec: expr=[a@0 + b@1 as sum] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortExec: expr=[sum@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + ProjectionExec: expr=[a@0 + b@1 as sum] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Since this projection is not trivial, increasing parallelism is beneficial Ok(()) @@ -2357,10 +2306,10 @@ fn repartition_ignores_transitively_with_projection() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortRequiredExec: [c@2 ASC] - ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] - DataSourceExec: 
file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + SortRequiredExec: [c@2 ASC] + ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); // Since this projection is trivial, increasing parallelism is not beneficial let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); @@ -2394,10 +2343,10 @@ fn repartition_transitively_past_sort_with_projection() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Since this projection is trivial, increasing parallelism is not beneficial let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -2419,12 +2368,12 @@ fn repartition_transitively_past_sort_with_filter() -> Result<()> { let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [a@0 ASC] - SortExec: expr=[a@0 ASC], preserve_partitioning=[true] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Expect repartition on the input to the sort (as it can benefit from additional parallelism) @@ -2432,12 +2381,12 @@ SortPreservingMergeExec: [a@0 ASC] let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortExec: expr=[a@0 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Expect repartition on the input of the filter (as it can benefit from additional parallelism) Ok(()) @@ -2468,13 +2417,13 @@ fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [a@0 ASC] - SortExec: expr=[a@0 ASC], preserve_partitioning=[true] - ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortPreservingMergeExec: [a@0 ASC] + 
SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Expect repartition on the input to the sort (as it can benefit from additional parallelism) // repartition is lowest down @@ -2483,13 +2432,13 @@ SortPreservingMergeExec: [a@0 ASC] let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortExec: expr=[a@0 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); Ok(()) } @@ -2509,11 +2458,11 @@ fn parallelization_single_partition() -> Result<()> { test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_parquet_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=parquet -"); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=parquet + "); let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_parquet_distrib, plan_parquet_sort); @@ -2521,11 +2470,11 @@ AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_csv_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false -"); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_csv_distrib, plan_csv_sort); @@ -2557,10 +2506,10 @@ fn parallelization_multiple_files() -> Result<()> { test_config_concurrency_3.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_3_distrib, @r" -SortRequiredExec: [a@0 ASC] - FilterExec: c@2 = 0 - DataSourceExec: file_groups={3 groups: [[x:0..50], [y:0..100], [x:50..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet -"); + SortRequiredExec: 
[a@0 ASC] + FilterExec: c@2 = 0 + DataSourceExec: file_groups={3 groups: [[x:0..50], [y:0..100], [x:50..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); let plan_3_sort = test_config_concurrency_3.to_plan(plan.clone(), &SORT_DISTRIB_DISTRIB); assert_plan!(plan_3_distrib, plan_3_sort); @@ -2570,10 +2519,10 @@ SortRequiredExec: [a@0 ASC] test_config_concurrency_8.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_8_distrib, @r" -SortRequiredExec: [a@0 ASC] - FilterExec: c@2 = 0 - DataSourceExec: file_groups={8 groups: [[x:0..25], [y:0..25], [x:25..50], [y:25..50], [x:50..75], [y:50..75], [x:75..100], [y:75..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet -"); + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + DataSourceExec: file_groups={8 groups: [[x:0..25], [y:0..25], [x:25..50], [y:25..50], [x:50..75], [y:50..75], [x:75..100], [y:75..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); let plan_8_sort = test_config_concurrency_8.to_plan(plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_8_distrib, plan_8_sort); @@ -2597,11 +2546,15 @@ fn parallelization_compressed_csv() -> Result<()> { for compression_type in compression_types { let plan = aggregate_exec_with_alias( DataSourceExec::from_data_source( - FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema(), - Arc::new(CsvSource::new(false, b',', b'"')), - ) + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test:///").unwrap(), { + let options = CsvOptions { + has_header: Some(false), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + Arc::new(CsvSource::new(schema()).with_csv_options(options)) + }) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_file_compression_type(compression_type) .build(), @@ -2617,21 +2570,21 @@ fn parallelization_compressed_csv() -> Result<()> { // Compressed files cannot be partitioned assert_plan!(plan_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false -"); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); } else { // Uncompressed files can be partitioned assert_plan!(plan_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false -"); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); } let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); @@ -2656,11 +2609,11 @@ fn 
parallelization_two_partitions() -> Result<()> { test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_parquet_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - DataSourceExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e], file_type=parquet -"); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Plan already has two partitions let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_parquet_distrib, plan_parquet_sort); @@ -2668,11 +2621,11 @@ AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] // Test: with csv let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_csv_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - DataSourceExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false -"); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); // Plan already has two partitions let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_csv_distrib, plan_csv_sort); @@ -2696,11 +2649,11 @@ fn parallelization_two_partitions_into_four() -> Result<()> { // Multiple source files split across partitions assert_plan!(plan_parquet_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - DataSourceExec: file_groups={4 groups: [[x:0..50], [x:50..100], [y:0..50], [y:50..100]]}, projection=[a, b, c, d, e], file_type=parquet -"); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={4 groups: [[x:0..50], [x:50..100], [y:0..50], [y:50..100]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Multiple source files split across partitions let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_parquet_distrib, plan_parquet_sort); @@ -2709,11 +2662,11 @@ AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); // Multiple source files split across partitions assert_plan!(plan_csv_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - DataSourceExec: file_groups={4 groups: [[x:0..50], [x:50..100], [y:0..50], [y:50..100]]}, projection=[a, b, c, d, e], file_type=csv, 
has_header=false -"); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={4 groups: [[x:0..50], [x:50..100], [y:0..50], [y:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); // Multiple source files split across partitions let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_csv_distrib, plan_csv_sort); @@ -2738,11 +2691,11 @@ fn parallelization_sorted_limit() -> Result<()> { let plan_parquet_distrib = test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_parquet_distrib, @r" -GlobalLimitExec: skip=0, fetch=100 - LocalLimitExec: fetch=100 - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // data is sorted so can't repartition here // Doesn't parallelize for SortExec without preserve_partitioning let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); @@ -2752,11 +2705,11 @@ GlobalLimitExec: skip=0, fetch=100 let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_csv_distrib, @r" -GlobalLimitExec: skip=0, fetch=100 - LocalLimitExec: fetch=100 - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false -"); + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); // data is sorted so can't repartition here // Doesn't parallelize for SortExec without preserve_partitioning let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); @@ -2787,14 +2740,14 @@ fn parallelization_limit_with_filter() -> Result<()> { // SortExec doesn't benefit from input partitioning assert_plan!(plan_parquet_distrib, @r" -GlobalLimitExec: skip=0, fetch=100 - CoalescePartitionsExec - LocalLimitExec: fetch=100 - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + GlobalLimitExec: skip=0, fetch=100 + CoalescePartitionsExec + LocalLimitExec: fetch=100 + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_parquet_distrib, plan_parquet_sort); @@ -2805,14 +2758,14 @@ GlobalLimitExec: skip=0, fetch=100 // SortExec doesn't benefit from input partitioning assert_plan!(plan_csv_distrib, @r" -GlobalLimitExec: skip=0, fetch=100 - CoalescePartitionsExec - LocalLimitExec: fetch=100 - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), 
input_partitions=1 - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false -"); + GlobalLimitExec: skip=0, fetch=100 + CoalescePartitionsExec + LocalLimitExec: fetch=100 + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_csv_distrib, plan_csv_sort); @@ -2891,13 +2844,13 @@ fn parallelization_union_inputs() -> Result<()> { test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_parquet_distrib, @r" -UnionExec - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Union doesn't benefit from input partitioning - no parallelism let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_parquet_distrib, plan_parquet_sort); @@ -2906,13 +2859,13 @@ UnionExec let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_csv_distrib, @r" -UnionExec - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false -"); + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); // Union doesn't benefit from input partitioning - no parallelism let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_csv_distrib, plan_csv_sort); @@ -3118,9 +3071,9 @@ fn 
parallelization_ignores_transitively_with_projection_parquet() -> Result<()> // data should not be repartitioned / resorted assert_plan!(plan_parquet_distrib, @r" -ProjectionExec: expr=[a@0 as a2, c@2 as c2] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + ProjectionExec: expr=[a@0 as a2, c@2 as c2] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_parquet_distrib, plan_parquet_sort); @@ -3153,18 +3106,18 @@ fn parallelization_ignores_transitively_with_projection_csv() -> Result<()> { let plan_csv = sort_preserving_merge_exec(sort_key_after_projection, proj_csv); assert_plan!(plan_csv, @r" -SortPreservingMergeExec: [c2@1 ASC] - ProjectionExec: expr=[a@0 as a2, c@2 as c2] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false -"); + SortPreservingMergeExec: [c2@1 ASC] + ProjectionExec: expr=[a@0 as a2, c@2 as c2] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false + "); let test_config = TestConfig::default(); let plan_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -ProjectionExec: expr=[a@0 as a2, c@2 as c2] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false -"); + ProjectionExec: expr=[a@0 as a2, c@2 as c2] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false + "); // Expected Outcome: // data should not be repartitioned / resorted let plan_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); @@ -3180,21 +3133,21 @@ fn remove_redundant_roundrobins() -> Result<()> { let physical_plan = repartition_exec(filter_exec(repartition)); assert_plan!(physical_plan, @r" -RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 - FilterExec: c@2 = 0 RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let test_config = TestConfig::default(); let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -3222,11 +3175,11 @@ fn remove_unnecessary_spm_after_filter() -> Result<()> { // This is still satisfied since, after filter that column is constant. 
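Note on the snapshot re-indentation running through the hunks above: the inline snapshots passed to assert_plan! (a test-local wrapper) gain leading whitespace only. The sketch below is a minimal standalone test, assuming the insta crate's usual dedenting of @r"..." inline snapshots, showing why that indentation is cosmetic and does not change what is asserted:

// Hypothetical test; needs only the `insta` dev-dependency.
#[test]
fn inline_snapshot_indentation_is_cosmetic() {
    let plan_text = "SortExec: expr=[a@0 ASC]\n  DataSourceExec: file_groups={1 group: [[x]]}";
    // insta strips the leading newline and the common indentation of an
    // inline snapshot, so the indented literal below still matches `plan_text`.
    insta::assert_snapshot!(plan_text, @r"
        SortExec: expr=[a@0 ASC]
          DataSourceExec: file_groups={1 group: [[x]]}
    ");
}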
assert_plan!(plan_distrib, @r" -CoalescePartitionsExec - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c@2 ASC - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + CoalescePartitionsExec + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c@2 ASC + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -3251,11 +3204,11 @@ fn preserve_ordering_through_repartition() -> Result<()> { let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [d@3 ASC] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=d@3 ASC - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[d@3 ASC], file_type=parquet -"); + SortPreservingMergeExec: [d@3 ASC] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=d@3 ASC + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[d@3 ASC], file_type=parquet + "); let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -3279,23 +3232,23 @@ fn do_not_preserve_ordering_through_repartition() -> Result<()> { // Test: run EnforceDistribution, then EnforceSort. 
assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [a@0 ASC] - SortExec: expr=[a@0 ASC], preserve_partitioning=[true] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet -"); + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); // Test: result IS DIFFERENT, if EnforceSorting is run first: let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortExec: expr=[a@0 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet -"); + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); Ok(()) } @@ -3314,11 +3267,11 @@ fn no_need_for_sort_after_filter() -> Result<()> { let test_config = TestConfig::default(); let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -CoalescePartitionsExec - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + CoalescePartitionsExec + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); // After CoalescePartitionsExec c is still constant. Hence c@2 ASC ordering is already satisfied. @@ -3350,24 +3303,24 @@ fn do_not_preserve_ordering_through_repartition2() -> Result<()> { // Test: run EnforceDistribution, then EnforceSort. 
assert_plan!(plan_distrib, @r" -SortPreservingMergeExec: [a@0 ASC] - SortExec: expr=[a@0 ASC], preserve_partitioning=[true] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); // Test: result IS DIFFERENT, if EnforceSorting is run first: let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_sort, @r" -SortExec: expr=[a@0 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - SortExec: expr=[a@0 ASC], preserve_partitioning=[true] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); Ok(()) } @@ -3387,10 +3340,10 @@ fn do_not_preserve_ordering_through_repartition3() -> Result<()> { let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -3410,10 +3363,10 @@ fn do_not_put_sort_when_input_is_invalid() -> Result<()> { // Ordering requirement of sort required exec is NOT satisfied // by existing ordering at the source. assert_plan!(physical_plan, @r" -SortRequiredExec: [a@0 ASC] - FilterExec: c@2 = 0 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let mut config = ConfigOptions::new(); config.execution.target_partitions = 10; @@ -3423,11 +3376,11 @@ SortRequiredExec: [a@0 ASC] // Since at the start of the rule ordering requirement is not satisfied // EnforceDistribution rule doesn't satisfy this requirement either. 
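For context, the `dist_plan` assertions in this area apply EnforceDistribution directly with an explicit target_partitions, as the ConfigOptions snippet above shows. A minimal sketch of that pattern using the public rule (the plan passed in stands in for the test-utility builders, which are not reproduced here):

use std::sync::Arc;
use datafusion_common::Result;
use datafusion_common::config::ConfigOptions;
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_optimizer::enforce_distribution::EnforceDistribution;
use datafusion_physical_plan::ExecutionPlan;

// Sketch: run only the distribution rule, mirroring the `dist_plan` checks.
fn enforce_distribution_only(
    plan: Arc<dyn ExecutionPlan>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let mut config = ConfigOptions::new();
    config.execution.target_partitions = 10;
    // The rule inserts RepartitionExec / CoalescePartitionsExec where the
    // plan's distribution requirements are not already satisfied.
    EnforceDistribution::new().optimize(plan, &config)
}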
assert_plan!(dist_plan, @r" -SortRequiredExec: [a@0 ASC] - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); Ok(()) } @@ -3446,10 +3399,10 @@ fn put_sort_when_input_is_valid() -> Result<()> { // Ordering requirement of sort required exec is satisfied // by existing ordering at the source. assert_plan!(physical_plan, @r" -SortRequiredExec: [a@0 ASC] - FilterExec: c@2 = 0 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet -"); + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); let mut config = ConfigOptions::new(); config.execution.target_partitions = 10; @@ -3459,10 +3412,10 @@ SortRequiredExec: [a@0 ASC] // Since at the start of the rule ordering requirement is satisfied // EnforceDistribution rule satisfy this requirement also. assert_plan!(dist_plan, @r" -SortRequiredExec: [a@0 ASC] - FilterExec: c@2 = 0 - DataSourceExec: file_groups={10 groups: [[x:0..20], [y:0..20], [x:20..40], [y:20..40], [x:40..60], [y:40..60], [x:60..80], [y:60..80], [x:80..100], [y:80..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet -"); + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + DataSourceExec: file_groups={10 groups: [[x:0..20], [y:0..20], [x:20..40], [y:20..40], [x:40..60], [y:40..60], [x:60..80], [y:60..80], [x:80..100], [y:80..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); Ok(()) } @@ -3486,10 +3439,10 @@ fn do_not_add_unnecessary_hash() -> Result<()> { let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -3516,14 +3469,14 @@ fn do_not_add_unnecessary_hash2() -> Result<()> { let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] - RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2 - DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet -"); + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + 
RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); // Since hash requirements of this operator is satisfied. There shouldn't be // a hash repartition here let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); @@ -3537,17 +3490,15 @@ fn optimize_away_unnecessary_repartition() -> Result<()> { let physical_plan = coalesce_partitions_exec(repartition_exec(parquet_exec())); assert_plan!(physical_plan, @r" -CoalescePartitionsExec - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + CoalescePartitionsExec + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let test_config = TestConfig::default(); let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, - @r" -DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet"); let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -3561,23 +3512,23 @@ fn optimize_away_unnecessary_repartition2() -> Result<()> { ))); assert_plan!(physical_plan, @r" -FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - CoalescePartitionsExec - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + CoalescePartitionsExec + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let test_config = TestConfig::default(); let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); assert_plan!(plan_distrib, @r" -FilterExec: c@2 = 0 - FilterExec: c@2 = 0 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + FilterExec: c@2 = 0 + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); assert_plan!(plan_distrib, plan_sort); @@ -3601,29 +3552,29 @@ async fn test_distribute_sort_parquet() -> Result<()> { // prior to optimization, this is the starting plan assert_plan!(physical_plan, @r" -SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // what the enforce distribution run does. 
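test_distribute_sort_parquet, which starts above, exercises the two optimizer passes in sequence: in the plans shown below, the distribution pass splits the single Parquet file into ten 8,192,000-byte ranges, and the sorting pass then parallelizes the sort under a SortPreservingMergeExec. A minimal sketch of chaining the two public rules (TestConfig and Run are test-local helpers and are assumed away here):

use std::sync::Arc;
use datafusion_common::Result;
use datafusion_common::config::ConfigOptions;
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_optimizer::enforce_distribution::EnforceDistribution;
use datafusion_physical_optimizer::enforce_sorting::EnforceSorting;
use datafusion_physical_plan::ExecutionPlan;

// Sketch: mirror `&[Run::Distribution, Run::Sorting]` by applying the two
// rules back to back with the same ConfigOptions.
fn distribution_then_sorting(
    plan: Arc<dyn ExecutionPlan>,
    config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
    let plan = EnforceDistribution::new().optimize(plan, config)?;
    EnforceSorting::new().optimize(plan, config)
}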
let plan_distribution = test_config.to_plan(physical_plan.clone(), &[Run::Distribution]); assert_plan!(plan_distribution, @r" -SortExec: expr=[c@2 ASC], preserve_partitioning=[false] - CoalescePartitionsExec - DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet + "); // what the sort parallelization (in enforce sorting), does after the enforce distribution changes let plan_both = test_config.to_plan(physical_plan, &[Run::Distribution, Run::Sorting]); assert_plan!(plan_both, @r" -SortPreservingMergeExec: [c@2 ASC] - SortExec: expr=[c@2 ASC], preserve_partitioning=[true] - DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet -"); + SortPreservingMergeExec: [c@2 ASC] + SortExec: expr=[c@2 ASC], preserve_partitioning=[true] + DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet + "); Ok(()) } @@ -3650,10 +3601,10 @@ async fn test_distribute_sort_memtable() -> Result<()> { // this is the final, optimized plan assert_plan!(physical_plan, @r" -SortPreservingMergeExec: [id@0 ASC NULLS LAST] - SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] - DataSourceExec: partitions=3, partition_sizes=[34, 33, 33] -"); + SortPreservingMergeExec: [id@0 ASC NULLS LAST] + SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] + DataSourceExec: partitions=3, partition_sizes=[34, 33, 33] + "); Ok(()) } diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index e3a0eb7e1aa6f..47e3adb455117 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -19,19 +19,20 @@ use std::sync::Arc; use crate::memory_limit::DummyStreamPartition; use crate::physical_optimizer::test_utils::{ - aggregate_exec, bounded_window_exec, bounded_window_exec_with_partition, - check_integrity, coalesce_batches_exec, coalesce_partitions_exec, create_test_schema, - create_test_schema2, create_test_schema3, filter_exec, global_limit_exec, - hash_join_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_with_sort, - projection_exec, repartition_exec, sort_exec, sort_exec_with_fetch, sort_expr, - sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec, + RequirementsTestExec, aggregate_exec, bounded_window_exec, + 
bounded_window_exec_with_partition, check_integrity, coalesce_batches_exec, + coalesce_partitions_exec, create_test_schema, create_test_schema2, + create_test_schema3, filter_exec, global_limit_exec, hash_join_exec, + local_limit_exec, memory_exec, parquet_exec, parquet_exec_with_sort, projection_exec, + repartition_exec, sort_exec, sort_exec_with_fetch, sort_expr, sort_expr_options, + sort_merge_join_exec, sort_preserving_merge_exec, sort_preserving_merge_exec_with_fetch, spr_repartition_exec, stream_exec_ordered, - union_exec, RequirementsTestExec, + union_exec, }; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, SchemaRef}; -use datafusion_common::config::ConfigOptions; +use datafusion_common::config::{ConfigOptions, CsvOptions}; use datafusion_common::tree_node::{TreeNode, TransformedResult}; use datafusion_common::{Result, TableReference}; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; @@ -65,17 +66,22 @@ use datafusion_execution::TaskContext; use datafusion_catalog::streaming::StreamingTable; use futures::StreamExt; -use insta::{assert_snapshot, Settings}; +use insta::{Settings, assert_snapshot}; /// Create a sorted Csv exec fn csv_exec_sorted( schema: &SchemaRef, sort_exprs: impl IntoIterator, ) -> Arc { + let options = CsvOptions { + has_header: Some(false), + delimiter: 0, + quote: 0, + ..Default::default() + }; let mut builder = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), - Arc::new(CsvSource::new(false, 0, 0)), + Arc::new(CsvSource::new(schema.clone()).with_csv_options(options)), ) .with_file(PartitionedFile::new("x".to_string(), 100)); if let Some(ordering) = LexOrdering::new(sort_exprs) { @@ -361,8 +367,8 @@ async fn test_union_inputs_different_sorted2() -> Result<()> { #[tokio::test] // Test with `repartition_sorts` enabled to preserve pre-sorted partitions and avoid resorting -async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_true( -) -> Result<()> { +async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_true() +-> Result<()> { assert_snapshot!( union_with_mix_of_presorted_and_explicitly_resorted_inputs_impl(true).await?, @r" @@ -387,8 +393,8 @@ async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_reparti #[tokio::test] // Test with `repartition_sorts` disabled, causing a full resort of the data -async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_false( -) -> Result<()> { +async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_false() +-> Result<()> { assert_snapshot!( union_with_mix_of_presorted_and_explicitly_resorted_inputs_impl(false).await?, @r" @@ -659,21 +665,13 @@ async fn test_union_inputs_different_sorted7() -> Result<()> { // Union has unnecessarily fine ordering below it. We should be able to replace them with absolutely necessary ordering. 
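The csv_exec_sorted helper above is ported to the new CsvSource constructor that takes the schema directly, with header, delimiter, and quote supplied through CsvOptions instead of positional arguments. A sketch of that construction pattern outside the helper (the CsvSource import path is an assumption; the rest mirrors the hunk above):

use std::sync::Arc;
use arrow::datatypes::SchemaRef;
use datafusion_common::config::CsvOptions;
use datafusion_datasource::PartitionedFile;
use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
use datafusion_datasource_csv::source::CsvSource; // assumed import path
use datafusion_execution::object_store::ObjectStoreUrl;

// Sketch: build a single-file CSV scan config the way `csv_exec_sorted` now
// does, passing the schema to CsvSource and the parsing knobs via CsvOptions.
fn csv_scan_config(schema: SchemaRef) -> FileScanConfig {
    let options = CsvOptions {
        has_header: Some(false),
        delimiter: b',',
        quote: b'"',
        ..Default::default()
    };
    FileScanConfigBuilder::new(
        ObjectStoreUrl::parse("test:///").unwrap(),
        Arc::new(CsvSource::new(schema).with_csv_options(options)),
    )
    .with_file(PartitionedFile::new("x".to_string(), 100))
    .build()
}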
let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); assert_snapshot!(test.run(), @r" - Input Plan: + Input / Optimized Plan: SortPreservingMergeExec: [nullable_col@0 ASC] UnionExec SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet - - Optimized Plan: - SortPreservingMergeExec: [nullable_col@0 ASC] - UnionExec - SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet - SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet "); // Union preserves the inputs ordering, and we should not change any of the SortExecs under UnionExec @@ -773,8 +771,8 @@ async fn test_soft_hard_requirements_remove_soft_requirement() -> Result<()> { } #[tokio::test] -async fn test_soft_hard_requirements_remove_soft_requirement_without_pushdowns( -) -> Result<()> { +async fn test_soft_hard_requirements_remove_soft_requirement_without_pushdowns() +-> Result<()> { let schema = create_test_schema()?; let source = parquet_exec(schema.clone()); let ordering = [sort_expr_options( @@ -1072,8 +1070,8 @@ async fn test_soft_hard_requirements_multiple_sorts() -> Result<()> { } #[tokio::test] -async fn test_soft_hard_requirements_with_multiple_soft_requirements_and_output_requirement( -) -> Result<()> { +async fn test_soft_hard_requirements_with_multiple_soft_requirements_and_output_requirement() +-> Result<()> { let schema = create_test_schema()?; let source = parquet_exec(schema.clone()); let ordering = [sort_expr_options( @@ -1342,12 +1340,12 @@ async fn test_sort_merge_join_order_by_left() -> Result<()> { assert_snapshot!(test.run(), @r" Input Plan: SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet Optimized Plan: - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false] @@ -1359,13 +1357,13 @@ async fn test_sort_merge_join_order_by_left() -> Result<()> { assert_snapshot!(test.run(), @r" Input Plan: SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet Optimized Plan: SortExec: expr=[nullable_col@0 ASC, 
non_nullable_col@1 ASC], preserve_partitioning=[false] - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false] @@ -1432,12 +1430,12 @@ async fn test_sort_merge_join_order_by_right() -> Result<()> { assert_snapshot!(test.run(), @r" Input Plan: SortPreservingMergeExec: [col_a@2 ASC, col_b@3 ASC] - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet Optimized Plan: - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false] @@ -1449,12 +1447,12 @@ async fn test_sort_merge_join_order_by_right() -> Result<()> { assert_snapshot!(test.run(), @r" Input Plan: SortPreservingMergeExec: [col_a@0 ASC, col_b@1 ASC] - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet Optimized Plan: - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false] @@ -1466,13 +1464,13 @@ async fn test_sort_merge_join_order_by_right() -> Result<()> { assert_snapshot!(test.run(), @r" Input Plan: SortPreservingMergeExec: [col_a@2 ASC, col_b@3 ASC] - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet Optimized Plan: SortExec: expr=[col_a@2 ASC, col_b@3 ASC], preserve_partitioning=[false] - SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false] @@ -1515,13 +1513,13 @@ async fn test_sort_merge_join_complex_order_by() -> Result<()> { assert_snapshot!(test.run(), @r" Input Plan: SortPreservingMergeExec: [col_b@3 ASC, col_a@2 ASC] - SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=Inner, on=[(nullable_col@0, col_a@0)] DataSourceExec: 
file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet Optimized Plan: SortExec: expr=[col_b@3 ASC, nullable_col@0 ASC], preserve_partitioning=[false] - SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=Inner, on=[(nullable_col@0, col_a@0)] SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false] @@ -1542,12 +1540,12 @@ async fn test_sort_merge_join_complex_order_by() -> Result<()> { assert_snapshot!(test.run(), @r" Input Plan: SortPreservingMergeExec: [nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC] - SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=Inner, on=[(nullable_col@0, col_a@0)] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet Optimized Plan: - SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)] + SortMergeJoinExec: join_type=Inner, on=[(nullable_col@0, col_a@0)] SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false] @@ -1626,13 +1624,13 @@ async fn test_with_lost_ordering_unbounded() -> Result<()> { SortExec: expr=[a@0 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] Optimized Plan: SortPreservingMergeExec: [a@0 ASC] RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] "); @@ -1644,13 +1642,13 @@ async fn test_with_lost_ordering_unbounded() -> Result<()> { SortExec: expr=[a@0 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] Optimized Plan: SortPreservingMergeExec: [a@0 ASC] RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] "); @@ -1669,7 +1667,7 @@ async fn 
test_with_lost_ordering_bounded() -> Result<()> { SortExec: expr=[a@0 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false "); @@ -1681,14 +1679,14 @@ async fn test_with_lost_ordering_bounded() -> Result<()> { SortExec: expr=[a@0 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false Optimized Plan: SortPreservingMergeExec: [a@0 ASC] SortExec: expr=[a@0 ASC], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false "); @@ -1710,7 +1708,7 @@ async fn test_do_not_pushdown_through_spm() -> Result<()> { Input / Optimized Plan: SortExec: expr=[b@1 ASC], preserve_partitioning=[false] SortPreservingMergeExec: [a@0 ASC, b@1 ASC] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false "); @@ -1739,13 +1737,13 @@ async fn test_pushdown_through_spm() -> Result<()> { Input Plan: SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false] SortPreservingMergeExec: [a@0 ASC, b@1 ASC] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false Optimized Plan: SortPreservingMergeExec: [a@0 ASC, b@1 ASC] SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[true] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false "); Ok(()) @@ -1769,7 +1767,7 @@ async fn test_window_multi_layer_requirement() -> Result<()> { BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] SortPreservingMergeExec: [a@0 ASC, b@1 ASC] RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC, b@1 ASC - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, 
maintains_sort_order=true SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false] DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false @@ -1964,7 +1962,7 @@ async fn test_remove_unnecessary_sort2() -> Result<()> { assert_snapshot!(test.run(), @r" Input Plan: RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] @@ -2011,7 +2009,7 @@ async fn test_remove_unnecessary_sort3() -> Result<()> { AggregateExec: mode=Final, gby=[], aggr=[] SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true SortPreservingMergeExec: [non_nullable_col@1 ASC] SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[0] @@ -2360,7 +2358,7 @@ async fn test_commutativity() -> Result<()> { assert_snapshot!(displayable(orig_plan.as_ref()).indent(true), @r#" SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] DataSourceExec: partitions=1, partition_sizes=[0] "#); diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting_monotonicity.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting_monotonicity.rs index ef233e222912c..de7611ff211a5 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting_monotonicity.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting_monotonicity.rs @@ -31,7 +31,7 @@ use datafusion_physical_expr::expressions::col; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::windows::{ - create_window_expr, BoundedWindowAggExec, WindowAggExec, + BoundedWindowAggExec, WindowAggExec, create_window_expr, }; use datafusion_physical_plan::{ExecutionPlan, InputOrderMode}; use insta::assert_snapshot; diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs index de61149508904..f480de71d6285 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs @@ -18,7 +18,7 @@ use std::sync::{Arc, LazyLock}; use arrow::{ - array::record_batch, + array::{Float64Array, Int32Array, RecordBatch, StringArray, record_batch}, datatypes::{DataType, Field, Schema, SchemaRef}, util::pretty::pretty_format_batches, }; @@ -27,8 +27,8 @@ use datafusion::{ assert_batches_eq, logical_expr::Operator, physical_plan::{ - expressions::{BinaryExpr, Column, Literal}, PhysicalExpr, + expressions::{BinaryExpr, Column, 
Literal}, }, prelude::{ParquetReadOptions, SessionConfig, SessionContext}, scalar::ScalarValue, @@ -36,20 +36,25 @@ use datafusion::{ use datafusion_catalog::memory::DataSourceExec; use datafusion_common::config::ConfigOptions; use datafusion_datasource::{ - file_groups::FileGroup, file_scan_config::FileScanConfigBuilder, PartitionedFile, + PartitionedFile, file_groups::FileGroup, file_scan_config::FileScanConfigBuilder, }; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::ScalarUDF; use datafusion_functions::math::random::RandomFunc; -use datafusion_functions_aggregate::count::count_udaf; +use datafusion_functions_aggregate::{ + count::count_udaf, + min_max::{max_udaf, min_udaf}, +}; +use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr, expressions::col}; use datafusion_physical_expr::{ - aggregate::AggregateExprBuilder, Partitioning, ScalarFunctionExpr, + Partitioning, ScalarFunctionExpr, + aggregate::{AggregateExprBuilder, AggregateFunctionExpr}, }; -use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; use datafusion_physical_optimizer::{ - filter_pushdown::FilterPushdown, PhysicalOptimizerRule, + PhysicalOptimizerRule, filter_pushdown::FilterPushdown, }; use datafusion_physical_plan::{ + ExecutionPlan, aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}, coalesce_batches::CoalesceBatchesExec, coalesce_partitions::CoalescePartitionsExec, @@ -57,13 +62,13 @@ use datafusion_physical_plan::{ filter::FilterExec, repartition::RepartitionExec, sorts::sort::SortExec, - ExecutionPlan, }; use datafusion_physical_plan::union::UnionExec; use futures::StreamExt; -use object_store::{memory::InMemory, ObjectStore}; -use util::{format_plan_for_test, OptimizationTest, TestNode, TestScanBuilder}; +use object_store::{ObjectStore, memory::InMemory}; +use regex::Regex; +use util::{OptimizationTest, TestNode, TestScanBuilder, format_plan_for_test}; use crate::physical_optimizer::filter_pushdown::util::TestSource; @@ -177,12 +182,14 @@ async fn test_dynamic_filter_pushdown_through_hash_join_with_topk() { use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; // Create build side with limited values - let build_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8View, ["ba", "bb"]), - ("c", Float64, [1.0, 2.0]) - ) - .unwrap()]; + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8View, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + ]; let build_side_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Utf8View, false), @@ -194,12 +201,14 @@ async fn test_dynamic_filter_pushdown_through_hash_join_with_topk() { .build(); // Create probe side with more values - let probe_batches = vec![record_batch!( - ("d", Utf8, ["aa", "ab", "ac", "ad"]), - ("e", Utf8View, ["ba", "bb", "bc", "bd"]), - ("f", Float64, [1.0, 2.0, 3.0, 4.0]) - ) - .unwrap()]; + let probe_batches = vec![ + record_batch!( + ("d", Utf8, ["aa", "ab", "ac", "ad"]), + ("e", Utf8View, ["ba", "bb", "bc", "bd"]), + ("f", Float64, [1.0, 2.0, 3.0, 4.0]) + ) + .unwrap(), + ]; let probe_side_schema = Arc::new(Schema::new(vec![ Field::new("d", DataType::Utf8, false), Field::new("e", DataType::Utf8View, false), @@ -272,13 +281,14 @@ async fn test_dynamic_filter_pushdown_through_hash_join_with_topk() { stream.next().await.unwrap().unwrap(); // Test that filters are pushed down correctly to each side of the join + // NOTE: We dropped the CASE expression here 
because we now optimize that away if there's only 1 partition insta::assert_snapshot!( format_plan_for_test(&plan), @r" - SortExec: TopK(fetch=2), expr=[e@4 ASC], preserve_partitioning=[false], filter=[e@4 IS NULL OR e@4 < bb] - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ d@0 >= aa AND d@0 <= ab ] AND DynamicFilter [ e@1 IS NULL OR e@1 < bb ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ d@0 >= aa AND d@0 <= ab AND d@0 IN (SET) ([aa, ab]) ] AND DynamicFilter [ e@1 IS NULL OR e@1 < bb ] " ); } @@ -293,12 +303,14 @@ async fn test_static_filter_pushdown_through_hash_join() { use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; // Create build side with limited values - let build_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8View, ["ba", "bb"]), - ("c", Float64, [1.0, 2.0]) - ) - .unwrap()]; + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8View, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + ]; let build_side_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Utf8View, false), @@ -310,12 +322,14 @@ async fn test_static_filter_pushdown_through_hash_join() { .build(); // Create probe side with more values - let probe_batches = vec![record_batch!( - ("d", Utf8, ["aa", "ab", "ac", "ad"]), - ("e", Utf8View, ["ba", "bb", "bc", "bd"]), - ("f", Float64, [1.0, 2.0, 3.0, 4.0]) - ) - .unwrap()]; + let probe_batches = vec![ + record_batch!( + ("d", Utf8, ["aa", "ab", "ac", "ad"]), + ("e", Utf8View, ["ba", "bb", "bc", "bd"]), + ("f", Float64, [1.0, 2.0, 3.0, 4.0]) + ) + .unwrap(), + ]; let probe_side_schema = Arc::new(Schema::new(vec![ Field::new("d", DataType::Utf8, false), Field::new("e", DataType::Utf8View, false), @@ -556,15 +570,14 @@ fn test_pushdown_through_aggregates_on_grouping_columns() { FilterExec::try_new(col_lit_predicate("a", "foo", &schema()), coalesce).unwrap(), ); - let aggregate_expr = - vec![ - AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema()).unwrap()]) - .schema(schema()) - .alias("cnt") - .build() - .map(Arc::new) - .unwrap(), - ]; + let aggregate_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; let group_by = PhysicalGroupBy::new_single(vec![ (col("a", &schema()).unwrap(), "a".to_string()), (col("b", &schema()).unwrap(), "b".to_string()), @@ -859,20 +872,17 @@ async fn test_topk_filter_passes_through_coalesce_partitions() { ]; // Create a source that supports all batches - let source = Arc::new(TestSource::new(true, batches)); - - let base_config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test://").unwrap(), - Arc::clone(&schema()), - source, - ) - .with_file_groups(vec![ - // Partition 0 - FileGroup::new(vec![PartitionedFile::new("test1.parquet", 123)]), - // Partition 1 - FileGroup::new(vec![PartitionedFile::new("test2.parquet", 123)]), - ]) - .build(); + let source = Arc::new(TestSource::new(schema(), true, batches)); + + let base_config = + 
FileScanConfigBuilder::new(ObjectStoreUrl::parse("test://").unwrap(), source) + .with_file_groups(vec![ + // Partition 0 + FileGroup::new(vec![PartitionedFile::new("test1.parquet", 123)]), + // Partition 1 + FileGroup::new(vec![PartitionedFile::new("test2.parquet", 123)]), + ]) + .build(); let scan = DataSourceExec::from_data_source(base_config); @@ -972,12 +982,14 @@ async fn test_hashjoin_dynamic_filter_pushdown() { use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; // Create build side with limited values - let build_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8, ["ba", "bb"]), - ("c", Float64, [1.0, 2.0]) // Extra column not used in join - ) - .unwrap()]; + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) // Extra column not used in join + ) + .unwrap(), + ]; let build_side_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Utf8, false), @@ -989,12 +1001,14 @@ async fn test_hashjoin_dynamic_filter_pushdown() { .build(); // Create probe side with more values - let probe_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab", "ac", "ad"]), - ("b", Utf8, ["ba", "bb", "bc", "bd"]), - ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join - ) - .unwrap()]; + let probe_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]), + ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join + ) + .unwrap(), + ]; let probe_side_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Utf8, false), @@ -1077,7 +1091,7 @@ async fn test_hashjoin_dynamic_filter_pushdown() { @r" - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}, {c0:ab,c1:bb}]) ] " ); } @@ -1140,12 +1154,14 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { // +---------------+------------------------------------------------------------+ // Create build side with limited values - let build_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8, ["ba", "bb"]), - ("c", Float64, [1.0, 2.0]) // Extra column not used in join - ) - .unwrap()]; + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) // Extra column not used in join + ) + .unwrap(), + ]; let build_side_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Utf8, false), @@ -1157,12 +1173,14 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { .build(); // Create probe side with more values - let probe_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab", "ac", "ad"]), - ("b", Utf8, ["ba", "bb", "bc", "bd"]), - ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join - ) - .unwrap()]; + let probe_batches = 
vec![ + record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]), + ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join + ) + .unwrap(), + ]; let probe_side_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Utf8, false), @@ -1308,10 +1326,14 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= ab AND a@0 <= ab AND b@1 >= bb AND b@1 <= bb OR a@0 >= aa AND a@0 <= aa AND b@1 >= ba AND b@1 <= ba ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ CASE hash_repartition % 12 WHEN 2 THEN a@0 >= ab AND a@0 <= ab AND b@1 >= bb AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:ab,c1:bb}]) WHEN 4 THEN a@0 >= aa AND a@0 <= aa AND b@1 >= ba AND b@1 <= ba AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}]) ELSE false END ] " ); + // When hash collisions force all data into a single partition, we optimize away the CASE expression. + // This avoids calling create_hashes() for every row on the probe side, since hash % 1 == 0 always, + // meaning the WHEN 0 branch would always match. This optimization is also important for primary key + // joins or any scenario where all build-side data naturally lands in one partition. #[cfg(feature = "force_hash_collisions")] insta::assert_snapshot!( format!("{}", format_plan_for_test(&plan)), @@ -1325,7 +1347,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}, {c0:ab,c1:bb}]) ] " ); @@ -1356,12 +1378,14 @@ async fn test_hashjoin_dynamic_filter_pushdown_collect_left() { use datafusion_common::JoinType; use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; - let build_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8, ["ba", "bb"]), - ("c", Float64, [1.0, 2.0]) // Extra column not used in join - ) - .unwrap()]; + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) // Extra column not used in join + ) + .unwrap(), + ]; let build_side_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Utf8, false), @@ -1373,12 +1397,14 @@ async fn test_hashjoin_dynamic_filter_pushdown_collect_left() { .build(); // Create probe side with more values - let probe_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab", "ac", "ad"]), - 
("b", Utf8, ["ba", "bb", "bc", "bd"]), - ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join - ) - .unwrap()]; + let probe_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]), + ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join + ) + .unwrap(), + ]; let probe_side_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Utf8, false), @@ -1502,7 +1528,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_collect_left() { - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}, {c0:ab,c1:bb}]) ] " ); @@ -1535,10 +1561,9 @@ async fn test_nested_hashjoin_dynamic_filter_pushdown() { // Create test data for three tables: t1, t2, t3 // t1: small table with limited values (will be build side of outer join) - let t1_batches = - vec![ - record_batch!(("a", Utf8, ["aa", "ab"]), ("x", Float64, [1.0, 2.0])).unwrap(), - ]; + let t1_batches = vec![ + record_batch!(("a", Utf8, ["aa", "ab"]), ("x", Float64, [1.0, 2.0])).unwrap(), + ]; let t1_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("x", DataType::Float64, false), @@ -1549,12 +1574,14 @@ async fn test_nested_hashjoin_dynamic_filter_pushdown() { .build(); // t2: larger table (will be probe side of inner join, build side of outer join) - let t2_batches = vec![record_batch!( - ("b", Utf8, ["aa", "ab", "ac", "ad", "ae"]), - ("c", Utf8, ["ca", "cb", "cc", "cd", "ce"]), - ("y", Float64, [1.0, 2.0, 3.0, 4.0, 5.0]) - ) - .unwrap()]; + let t2_batches = vec![ + record_batch!( + ("b", Utf8, ["aa", "ab", "ac", "ad", "ae"]), + ("c", Utf8, ["ca", "cb", "cc", "cd", "ce"]), + ("y", Float64, [1.0, 2.0, 3.0, 4.0, 5.0]) + ) + .unwrap(), + ]; let t2_schema = Arc::new(Schema::new(vec![ Field::new("b", DataType::Utf8, false), Field::new("c", DataType::Utf8, false), @@ -1566,11 +1593,13 @@ async fn test_nested_hashjoin_dynamic_filter_pushdown() { .build(); // t3: largest table (will be probe side of inner join) - let t3_batches = vec![record_batch!( - ("d", Utf8, ["ca", "cb", "cc", "cd", "ce", "cf", "cg", "ch"]), - ("z", Float64, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) - ) - .unwrap()]; + let t3_batches = vec![ + record_batch!( + ("d", Utf8, ["ca", "cb", "cc", "cd", "ce", "cf", "cg", "ch"]), + ("z", Float64, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + ) + .unwrap(), + ]; let t3_schema = Arc::new(Schema::new(vec![ Field::new("d", DataType::Utf8, false), Field::new("z", DataType::Float64, false), @@ -1670,8 +1699,8 @@ async fn test_nested_hashjoin_dynamic_filter_pushdown() { - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@0)] - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, x], file_type=test, pushdown_supported=true - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, d@0)] - - DataSourceExec: file_groups={1 group: 
[[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ b@0 >= aa AND b@0 <= ab ] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ d@0 >= ca AND d@0 <= cb ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ b@0 >= aa AND b@0 <= ab AND b@0 IN (SET) ([aa, ab]) ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ d@0 >= ca AND d@0 <= cb AND d@0 IN (SET) ([ca, cb]) ] " ); } @@ -1682,12 +1711,14 @@ async fn test_hashjoin_parent_filter_pushdown() { use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; // Create build side with limited values - let build_batches = vec![record_batch!( - ("a", Utf8, ["aa", "ab"]), - ("b", Utf8, ["ba", "bb"]), - ("c", Float64, [1.0, 2.0]) - ) - .unwrap()]; + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + ]; let build_side_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Utf8, false), @@ -1699,12 +1730,14 @@ async fn test_hashjoin_parent_filter_pushdown() { .build(); // Create probe side with more values - let probe_batches = vec![record_batch!( - ("d", Utf8, ["aa", "ab", "ac", "ad"]), - ("e", Utf8, ["ba", "bb", "bc", "bd"]), - ("f", Float64, [1.0, 2.0, 3.0, 4.0]) - ) - .unwrap()]; + let probe_batches = vec![ + record_batch!( + ("d", Utf8, ["aa", "ab", "ac", "ad"]), + ("e", Utf8, ["ba", "bb", "bc", "bd"]), + ("f", Float64, [1.0, 2.0, 3.0, 4.0]) + ) + .unwrap(), + ]; let probe_side_schema = Arc::new(Schema::new(vec![ Field::new("d", DataType::Utf8, false), Field::new("e", DataType::Utf8, false), @@ -1827,7 +1860,7 @@ STORED AS PARQUET; assert!(explain.contains("output_rows=128")); // Read 1 row group assert!(explain.contains("t@0 < 1372708809")); // Dynamic filter was applied assert!( - explain.contains("pushdown_rows_matched=128, pushdown_rows_pruned=99872"), + explain.contains("pushdown_rows_matched=128, pushdown_rows_pruned=99.87 K"), "{explain}" ); // Pushdown pruned most rows @@ -1892,16 +1925,438 @@ fn col_lit_predicate( )) } +// ==== Aggregate Dynamic Filter tests ==== + +// ---- Test Utilities ---- +struct AggregateDynFilterCase<'a> { + schema: SchemaRef, + batches: Vec, + aggr_exprs: Vec, + expected_before: Option<&'a str>, + expected_after: Option<&'a str>, + scan_support: bool, +} + +async fn run_aggregate_dyn_filter_case(case: AggregateDynFilterCase<'_>) { + let AggregateDynFilterCase { + schema, + batches, + aggr_exprs, + expected_before, + expected_after, + scan_support, + } = case; + + let scan = TestScanBuilder::new(Arc::clone(&schema)) + .with_support(scan_support) + .with_batches(batches) + .build(); + + let aggr_exprs: Vec<_> = aggr_exprs + .into_iter() + .map(|expr| Arc::new(expr) as Arc) + .collect(); + let aggr_len = aggr_exprs.len(); + + let plan: Arc = Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![]), + aggr_exprs, + vec![None; aggr_len], + scan, + Arc::clone(&schema), + ) + .unwrap(), + ); + + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + + let optimized = 
FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + + let before = format_plan_for_test(&optimized); + if let Some(expected) = expected_before { + assert!( + before.contains(expected), + "expected `{expected}` before execution, got: {before}" + ); + } else { + assert!( + !before.contains("DynamicFilter ["), + "dynamic filter unexpectedly present before execution: {before}" + ); + } + + let session_ctx = SessionContext::new(); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let task_ctx = session_ctx.state().task_ctx(); + let mut stream = optimized.execute(0, Arc::clone(&task_ctx)).unwrap(); + let _ = stream.next().await.transpose().unwrap(); + + let after = format_plan_for_test(&optimized); + if let Some(expected) = expected_after { + assert!( + after.contains(expected), + "expected `{expected}` after execution, got: {after}" + ); + } else { + assert!( + !after.contains("DynamicFilter ["), + "dynamic filter unexpectedly present after execution: {after}" + ); + } +} + +// ---- Test Cases ---- +// Cases covered below: +// 1. `min(a)` and `max(a)` baseline. +// 2. Unsupported expression input (`min(a+1)`). +// 3. Multiple supported columns (same column vs different columns). +// 4. Mixed supported + unsupported aggregates. +// 5. Entirely NULL input to surface current bound behavior. +// 6. End-to-end tests on parquet files + +/// `MIN(a)`: able to pushdown dynamic filter +#[tokio::test] +async fn test_aggregate_dynamic_filter_min_simple() { + // Single min(a) showcases the base case. + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batches = vec![record_batch!(("a", Int32, [5, 1, 3, 8])).unwrap()]; + + let min_expr = + AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build() + .unwrap(); + + run_aggregate_dyn_filter_case(AggregateDynFilterCase { + schema, + batches, + aggr_exprs: vec![min_expr], + expected_before: Some("DynamicFilter [ empty ]"), + expected_after: Some("DynamicFilter [ a@0 < 1 ]"), + scan_support: true, + }) + .await; +} + +/// `MAX(a)`: able to pushdown dynamic filter +#[tokio::test] +async fn test_aggregate_dynamic_filter_max_simple() { + // Single max(a) mirrors the base case on the upper bound. 
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batches = vec![record_batch!(("a", Int32, [5, 1, 3, 8])).unwrap()]; + + let max_expr = + AggregateExprBuilder::new(max_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("max_a") + .build() + .unwrap(); + + run_aggregate_dyn_filter_case(AggregateDynFilterCase { + schema, + batches, + aggr_exprs: vec![max_expr], + expected_before: Some("DynamicFilter [ empty ]"), + expected_after: Some("DynamicFilter [ a@0 > 8 ]"), + scan_support: true, + }) + .await; +} + +/// `MIN(a+1)`: Can't pushdown dynamic filter +#[tokio::test] +async fn test_aggregate_dynamic_filter_min_expression_not_supported() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batches = vec![record_batch!(("a", Int32, [5, 1, 3, 8])).unwrap()]; + + let expr: Arc = Arc::new(BinaryExpr::new( + col("a", &schema).unwrap(), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )); + let min_expr = AggregateExprBuilder::new(min_udaf(), vec![expr]) + .schema(Arc::clone(&schema)) + .alias("min_a_plus_one") + .build() + .unwrap(); + + run_aggregate_dyn_filter_case(AggregateDynFilterCase { + schema, + batches, + aggr_exprs: vec![min_expr], + expected_before: None, + expected_after: None, + scan_support: true, + }) + .await; +} + +/// `MIN(a), MAX(a)`: Pushdown dynamic filter like `(a<1) or (a>8)` +#[tokio::test] +async fn test_aggregate_dynamic_filter_min_max_same_column() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batches = vec![record_batch!(("a", Int32, [5, 1, 3, 8])).unwrap()]; + + let min_expr = + AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build() + .unwrap(); + let max_expr = + AggregateExprBuilder::new(max_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("max_a") + .build() + .unwrap(); + + run_aggregate_dyn_filter_case(AggregateDynFilterCase { + schema, + batches, + aggr_exprs: vec![min_expr, max_expr], + expected_before: Some("DynamicFilter [ empty ]"), + expected_after: Some("DynamicFilter [ a@0 < 1 OR a@0 > 8 ]"), + scan_support: true, + }) + .await; +} + +/// `MIN(a), MAX(b)`: Pushdown dynamic filter like `(a<1) or (b>9)` +#[tokio::test] +async fn test_aggregate_dynamic_filter_min_max_different_columns() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + let batches = vec![ + record_batch!(("a", Int32, [5, 1, 3, 8]), ("b", Int32, [7, 2, 4, 9])).unwrap(), + ]; + + let min_expr = + AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build() + .unwrap(); + let max_expr = + AggregateExprBuilder::new(max_udaf(), vec![col("b", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("max_b") + .build() + .unwrap(); + + run_aggregate_dyn_filter_case(AggregateDynFilterCase { + schema, + batches, + aggr_exprs: vec![min_expr, max_expr], + expected_before: Some("DynamicFilter [ empty ]"), + expected_after: Some("DynamicFilter [ a@0 < 1 OR b@1 > 9 ]"), + scan_support: true, + }) + .await; +} + +/// Mix of supported/unsupported aggregates retains only the valid ones. 
+/// `MIN(a), MAX(a), MAX(b), MIN(c+1)`: Pushdown dynamic filter like `(a<1) or (a>8) OR (b>12)` +#[tokio::test] +async fn test_aggregate_dynamic_filter_multiple_mixed_expressions() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ])); + let batches = vec![ + record_batch!( + ("a", Int32, [5, 1, 3, 8]), + ("b", Int32, [10, 4, 6, 12]), + ("c", Int32, [100, 70, 90, 110]) + ) + .unwrap(), + ]; + + let min_a = AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build() + .unwrap(); + let max_a = AggregateExprBuilder::new(max_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("max_a") + .build() + .unwrap(); + let max_b = AggregateExprBuilder::new(max_udaf(), vec![col("b", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("max_b") + .build() + .unwrap(); + let expr_c: Arc = Arc::new(BinaryExpr::new( + col("c", &schema).unwrap(), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )); + let min_c_expr = AggregateExprBuilder::new(min_udaf(), vec![expr_c]) + .schema(Arc::clone(&schema)) + .alias("min_c_plus_one") + .build() + .unwrap(); + + run_aggregate_dyn_filter_case(AggregateDynFilterCase { + schema, + batches, + aggr_exprs: vec![min_a, max_a, max_b, min_c_expr], + expected_before: Some("DynamicFilter [ empty ]"), + expected_after: Some("DynamicFilter [ a@0 < 1 OR a@0 > 8 OR b@1 > 12 ]"), + scan_support: true, + }) + .await; +} + +/// Don't tighten the dynamic filter if all inputs are null +#[tokio::test] +async fn test_aggregate_dynamic_filter_min_all_nulls() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batches = vec![record_batch!(("a", Int32, [None, None, None, None])).unwrap()]; + + let min_expr = + AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build() + .unwrap(); + + run_aggregate_dyn_filter_case(AggregateDynFilterCase { + schema, + batches, + aggr_exprs: vec![min_expr], + expected_before: Some("DynamicFilter [ empty ]"), + // After reading the input there is no meaningful bound to update, so the + // predicate stays `true`, which filters nothing out + expected_after: Some("DynamicFilter [ true ]"), + scan_support: true, + }) + .await; +} + +/// Test that the aggregate dynamic filter works when reading Parquet files +/// +/// Runs 'select max(id) from test_table where id > 1' and ensures that some file ranges +/// are pruned by the dynamic filter. 
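All of the aggregate dynamic-filter cases above reduce to the same bookkeeping: a partial MIN/MAX accumulator tracks the tightest bound it has seen so far and republishes a predicate that keeps only rows able to improve that bound (`a < 1 OR a > 8` after seeing `[5, 1, 3, 8]`, or `true` while no non-null value has arrived). The following is a minimal standalone sketch of that bookkeeping in plain Rust; the `Bounds` type and its methods are illustrative only and are not DataFusion's internal types.

#[derive(Clone, Copy, Debug, Default)]
struct Bounds {
    min: Option<i32>,
    max: Option<i32>,
}

impl Bounds {
    // Fold one (possibly NULL) input value into the running MIN/MAX.
    fn update(&mut self, v: Option<i32>) {
        if let Some(v) = v {
            self.min = Some(self.min.map_or(v, |m| m.min(v)));
            self.max = Some(self.max.map_or(v, |m| m.max(v)));
        }
    }

    // Render the pruning predicate a scan could apply: rows inside the current
    // [min, max] range cannot change MIN(col) or MAX(col). With no non-null
    // input there is no bound yet, so fall back to `true`.
    fn to_predicate(&self, col: &str) -> String {
        match (self.min, self.max) {
            (Some(lo), Some(hi)) => format!("{col} < {lo} OR {col} > {hi}"),
            _ => "true".to_string(),
        }
    }
}

fn main() {
    let mut bounds = Bounds::default();
    for v in [Some(5), Some(1), Some(3), Some(8)] {
        bounds.update(v);
    }
    assert_eq!(bounds.to_predicate("a"), "a < 1 OR a > 8");

    // Mirrors the all-NULL case above: no bound yet, so the filter stays `true`.
    assert_eq!(Bounds::default().to_predicate("a"), "true");
}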
+#[tokio::test] +async fn test_aggregate_dynamic_filter_parquet_e2e() { + let config = SessionConfig::new() + .with_collect_statistics(true) + .with_target_partitions(2) + .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", true) + .set_bool("datafusion.execution.parquet.pushdown_filters", true); + let ctx = SessionContext::new_with_config(config); + + let data_path = format!( + "{}/tests/data/test_statistics_per_partition/", + env!("CARGO_MANIFEST_DIR") + ); + + ctx.register_parquet("test_table", &data_path, ParquetReadOptions::default()) + .await + .unwrap(); + + // partition 1: + // files: ..03-01(id=4), ..03-02(id=3) + // partition 2: + // files: ..03-03(id=2), ..03-04(id=1) + // + // In partition 1, after reading the first file, the dynamic filter will be updated + // to "id > 4", so the `..03-02` file should be pruned out + let df = ctx + .sql("explain analyze select max(id) from test_table where id > 1") + .await + .unwrap(); + + let result = df.collect().await.unwrap(); + + let formatted = pretty_format_batches(&result).unwrap(); + let explain_analyze = format!("{formatted}"); + + // Capture "2" from "files_ranges_pruned_statistics=4 total → 2 matched" + let re = Regex::new( + r"files_ranges_pruned_statistics\s*=\s*(\d+)\s*total\s*[→>\-]\s*(\d+)\s*matched", + ) + .unwrap(); + + if let Some(caps) = re.captures(&explain_analyze) { + let matched_num: i32 = caps[2].parse().unwrap(); + assert!( + matched_num < 4, + "4 files in total; if any were pruned, the matched count should be < 4" + ); + } else { + unreachable!("metrics should exist") + } +} + +/// Non-partial (Single) aggregates should skip dynamic filter initialization. +#[test] +fn test_aggregate_dynamic_filter_not_created_for_single_mode() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batches = vec![record_batch!(("a", Int32, [5, 1, 3, 8])).unwrap()]; + + let scan = TestScanBuilder::new(Arc::clone(&schema)) + .with_support(true) + .with_batches(batches) + .build(); + + let min_expr = + AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build() + .unwrap(); + + let plan: Arc = Arc::new( + AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new_single(vec![]), + vec![min_expr.into()], + vec![None], + scan, + Arc::clone(&schema), + ) + .unwrap(), + ); + + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + + let optimized = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + + let formatted = format_plan_for_test(&optimized); + assert!( + !formatted.contains("DynamicFilter ["), + "dynamic filter should not be created for AggregateMode::Single: {formatted}" + ); +} + #[tokio::test] async fn test_aggregate_filter_pushdown() { // Test that filters can pass through AggregateExec even with aggregate functions // when the filter references grouping columns // Simulates: SELECT a, COUNT(b) FROM table WHERE a = 'x' GROUP BY a - let batches = - vec![ - record_batch!(("a", Utf8, ["x", "y"]), ("b", Utf8, ["foo", "bar"])).unwrap(), - ]; + let batches = vec![ + record_batch!(("a", Utf8, ["x", "y"]), ("b", Utf8, ["foo", "bar"])).unwrap(), + ]; let scan = TestScanBuilder::new(schema()) .with_support(true) @@ -1962,10 +2417,9 @@ async fn test_no_pushdown_filter_on_aggregate_result() { // SELECT a, COUNT(b) as cnt FROM table GROUP BY a HAVING cnt > 5 // The filter on 'cnt' 
cannot be pushed down because it's an aggregate result - let batches = - vec![ - record_batch!(("a", Utf8, ["x", "y"]), ("b", Utf8, ["foo", "bar"])).unwrap(), - ]; + let batches = vec![ + record_batch!(("a", Utf8, ["x", "y"]), ("b", Utf8, ["foo", "bar"])).unwrap(), + ]; let scan = TestScanBuilder::new(schema()) .with_support(true) @@ -2034,15 +2488,14 @@ fn test_pushdown_filter_on_non_first_grouping_column() { // The filter is on 'b' (second grouping column), should push down let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let aggregate_expr = - vec![ - AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()]) - .schema(schema()) - .alias("cnt") - .build() - .map(Arc::new) - .unwrap(), - ]; + let aggregate_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; let group_by = PhysicalGroupBy::new_single(vec![ (col("a", &schema()).unwrap(), "a".to_string()), @@ -2085,15 +2538,14 @@ fn test_no_pushdown_grouping_sets_filter_on_missing_column() { // Test that filters on columns missing from some grouping sets are NOT pushed through let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let aggregate_expr = - vec![ - AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()]) - .schema(schema()) - .alias("cnt") - .build() - .map(Arc::new) - .unwrap(), - ]; + let aggregate_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; // Create GROUPING SETS with (a, b) and (b) let group_by = PhysicalGroupBy::new( @@ -2115,6 +2567,7 @@ fn test_no_pushdown_grouping_sets_filter_on_missing_column() { vec![false, false], // (a, b) - both present vec![true, false], // (b) - a is NULL, b present ], + true, ); let aggregate = Arc::new( @@ -2155,15 +2608,14 @@ fn test_pushdown_grouping_sets_filter_on_common_column() { // Test that filters on columns present in ALL grouping sets ARE pushed through let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let aggregate_expr = - vec![ - AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()]) - .schema(schema()) - .alias("cnt") - .build() - .map(Arc::new) - .unwrap(), - ]; + let aggregate_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; // Create GROUPING SETS with (a, b) and (b) let group_by = PhysicalGroupBy::new( @@ -2185,6 +2637,7 @@ fn test_pushdown_grouping_sets_filter_on_common_column() { vec![false, false], // (a, b) - both present vec![true, false], // (b) - a is NULL, b present ], + true, ); let aggregate = Arc::new( @@ -2226,15 +2679,14 @@ fn test_pushdown_with_empty_group_by() { // There are no grouping columns, so the filter should still push down let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let aggregate_expr = - vec![ - AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()]) - .schema(schema()) - .alias("cnt") - .build() - .map(Arc::new) - .unwrap(), - ]; + let aggregate_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; // Empty GROUP BY - no grouping columns let group_by = PhysicalGroupBy::new_single(vec![]); @@ -2286,15 
+2738,14 @@ fn test_pushdown_with_computed_grouping_key() { )) as Arc; let filter = Arc::new(FilterExec::try_new(predicate, scan).unwrap()); - let aggregate_expr = - vec![ - AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema()).unwrap()]) - .schema(schema()) - .alias("cnt") - .build() - .map(Arc::new) - .unwrap(), - ]; + let aggregate_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; let c_plus_one = Arc::new(BinaryExpr::new( col("c", &schema()).unwrap(), @@ -2333,3 +2784,731 @@ fn test_pushdown_with_computed_grouping_key() { " ); } + +#[tokio::test] +async fn test_hashjoin_dynamic_filter_all_partitions_empty() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Test scenario where all build-side partitions are empty + // This validates the code path that sets the filter to `false` when no rows can match + + // Create empty build side + let build_batches = vec![]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with some data + let probe_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab", "ac"]), + ("b", Utf8, ["ba", "bb", "bc"]) + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create RepartitionExec nodes for both sides + let partition_count = 4; + + let build_hash_exprs = vec![ + col("a", &build_side_schema).unwrap(), + col("b", &build_side_schema).unwrap(), + ]; + let build_repartition = Arc::new( + RepartitionExec::try_new( + build_scan, + Partitioning::Hash(build_hash_exprs, partition_count), + ) + .unwrap(), + ); + let build_coalesce = Arc::new(CoalesceBatchesExec::new(build_repartition, 8192)); + + let probe_hash_exprs = vec![ + col("a", &probe_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ]; + let probe_repartition = Arc::new( + RepartitionExec::try_new( + Arc::clone(&probe_scan), + Partitioning::Hash(probe_hash_exprs, partition_count), + ) + .unwrap(), + ); + let probe_coalesce = Arc::new(CoalesceBatchesExec::new(probe_repartition, 8192)); + + // Create HashJoinExec + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let hash_join = Arc::new( + HashJoinExec::try_new( + build_coalesce, + probe_coalesce, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + + let plan = + Arc::new(CoalesceBatchesExec::new(hash_join, 8192)) as Arc; + + // Apply the filter pushdown optimizer + let mut config = SessionConfig::new(); + config.options_mut().execution.parquet.pushdown_filters = true; + let optimizer = FilterPushdown::new_post_optimization(); + let plan = optimizer.optimize(plan, config.options()).unwrap(); + + insta::assert_snapshot!( + format_plan_for_test(&plan), + @r" + - CoalesceBatchesExec: 
target_batch_size=8192 + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - CoalesceBatchesExec: target_batch_size=8192 + - RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true + - CoalesceBatchesExec: target_batch_size=8192 + - RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + " + ); + + // Put some data through the plan to check that the filter is updated to reflect the TopK state + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + // Execute all partitions (required for partitioned hash join coordination) + let _batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // Test that filters are pushed down correctly to each side of the join + insta::assert_snapshot!( + format_plan_for_test(&plan), + @r" + - CoalesceBatchesExec: target_batch_size=8192 + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - CoalesceBatchesExec: target_batch_size=8192 + - RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true + - CoalesceBatchesExec: target_batch_size=8192 + - RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ false ] + " + ); +} + +#[tokio::test] +async fn test_hashjoin_dynamic_filter_with_nulls() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Test scenario where build side has NULL values in join keys + // This validates NULL handling in bounds computation and filter generation + + // Create build side with NULL values + let build_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, true), // nullable + Field::new("b", DataType::Int32, true), // nullable + ])), + vec![ + Arc::new(StringArray::from(vec![Some("aa"), None, Some("ab")])), + Arc::new(Int32Array::from(vec![Some(1), Some(2), None])), + ], + ) + .unwrap(); + let build_batches = vec![build_batch]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Int32, true), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with nullable fields + let probe_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Float64, false), + ])), + vec![ + Arc::new(StringArray::from(vec![ + Some("aa"), + Some("ab"), + Some("ac"), + None, + ])), + Arc::new(Int32Array::from(vec![Some(1), Some(3), Some(4), Some(5)])), + Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), + ], + ) + .unwrap(); + let probe_batches = vec![probe_batch]; + let probe_side_schema = 
Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create HashJoinExec in CollectLeft mode (simpler for this test) + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let hash_join = Arc::new( + HashJoinExec::try_new( + build_scan, + Arc::clone(&probe_scan), + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + + let plan = + Arc::new(CoalesceBatchesExec::new(hash_join, 8192)) as Arc; + + // Apply the filter pushdown optimizer + let mut config = SessionConfig::new(); + config.options_mut().execution.parquet.pushdown_filters = true; + let optimizer = FilterPushdown::new_post_optimization(); + let plan = optimizer.optimize(plan, config.options()).unwrap(); + + insta::assert_snapshot!( + format_plan_for_test(&plan), + @r" + - CoalesceBatchesExec: target_batch_size=8192 + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + " + ); + + // Put some data through the plan to check that the filter is updated to reflect the TopK state + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + // Execute all partitions (required for partitioned hash join coordination) + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // Test that filters are pushed down correctly to each side of the join + insta::assert_snapshot!( + format_plan_for_test(&plan), + @r" + - CoalesceBatchesExec: target_batch_size=8192 + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= 1 AND b@1 <= 2 AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:1}, {c0:,c1:2}, {c0:ab,c1:}]) ] + " + ); + + #[rustfmt::skip] + let expected = [ + "+----+---+----+---+-----+", + "| a | b | a | b | c |", + "+----+---+----+---+-----+", + "| aa | 1 | aa | 1 | 1.0 |", + "+----+---+----+---+-----+", + ]; + assert_batches_eq!(&expected, &batches); +} + +/// Test that when hash_join_inlist_pushdown_max_size is set to a very small value, +/// the HashTable strategy is used instead of InList strategy, even with small build sides. +/// This test is identical to test_hashjoin_dynamic_filter_pushdown_partitioned except +/// for the config setting that forces the HashTable strategy. 
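Before the two HashTable-strategy tests, here is a rough sketch of the decision the `hash_join_inlist_pushdown_max_size` option controls: if the build side's join keys fit within the configured budget they are pushed down as an explicit IN-list, otherwise the pushed filter falls back to a hash-lookup against the join's hash table. The byte accounting below is purely illustrative and is not DataFusion's exact formula; only the config key name comes from the tests themselves.

enum PushdownStrategy {
    // Small build sides: push the literal key set, rendered as `IN (SET) ([...])`.
    InList(Vec<String>),
    // Large build sides (or a tiny configured budget): reuse the join's hash
    // table via a `hash_lookup` style predicate instead of materializing keys.
    HashLookup,
}

fn choose_strategy(build_keys: &[String], max_size_bytes: usize) -> PushdownStrategy {
    // Illustrative size estimate: total bytes of the key values.
    let estimated_size: usize = build_keys.iter().map(|k| k.len()).sum();
    if estimated_size <= max_size_bytes {
        PushdownStrategy::InList(build_keys.to_vec())
    } else {
        PushdownStrategy::HashLookup
    }
}

fn main() {
    let keys = vec!["aa".to_string(), "ab".to_string()];
    // A generous budget keeps the IN-list form.
    assert!(matches!(
        choose_strategy(&keys, 1024),
        PushdownStrategy::InList(_)
    ));
    // Setting the budget to 1 byte, as the tests below do, forces hash_lookup.
    assert!(matches!(
        choose_strategy(&keys, 1),
        PushdownStrategy::HashLookup
    ));
}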
+#[tokio::test] +async fn test_hashjoin_hash_table_pushdown_partitioned() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create build side with limited values + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) // Extra column not used in join + ) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]), + ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("e", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create RepartitionExec nodes for both sides with hash partitioning on join keys + let partition_count = 12; + + // Build side: DataSource -> RepartitionExec (Hash) -> CoalesceBatchesExec + let build_hash_exprs = vec![ + col("a", &build_side_schema).unwrap(), + col("b", &build_side_schema).unwrap(), + ]; + let build_repartition = Arc::new( + RepartitionExec::try_new( + build_scan, + Partitioning::Hash(build_hash_exprs, partition_count), + ) + .unwrap(), + ); + let build_coalesce = Arc::new(CoalesceBatchesExec::new(build_repartition, 8192)); + + // Probe side: DataSource -> RepartitionExec (Hash) -> CoalesceBatchesExec + let probe_hash_exprs = vec![ + col("a", &probe_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ]; + let probe_repartition = Arc::new( + RepartitionExec::try_new( + Arc::clone(&probe_scan), + Partitioning::Hash(probe_hash_exprs, partition_count), + ) + .unwrap(), + ); + let probe_coalesce = Arc::new(CoalesceBatchesExec::new(probe_repartition, 8192)); + + // Create HashJoinExec with partitioned inputs + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let hash_join = Arc::new( + HashJoinExec::try_new( + build_coalesce, + probe_coalesce, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + + // Top-level CoalesceBatchesExec + let cb = + Arc::new(CoalesceBatchesExec::new(hash_join, 8192)) as Arc; + // Top-level CoalescePartitionsExec + let cp = Arc::new(CoalescePartitionsExec::new(cb)) as Arc; + // Add a sort for deterministic output + let plan = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("a", &probe_side_schema).unwrap(), + SortOptions::new(true, false), // descending, nulls_first + )]) + .unwrap(), + cp, + )) as Arc; + + // Apply the optimization with config setting that forces HashTable strategy + let session_config = SessionConfig::default() + .with_batch_size(10) + .set_usize("datafusion.optimizer.hash_join_inlist_pushdown_max_size", 1) + 
.set_bool("datafusion.execution.parquet.pushdown_filters", true) + .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", true); + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, session_config.options()) + .unwrap(); + let session_ctx = SessionContext::new_with_config(session_config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // Verify that hash_lookup is used instead of IN (SET) + let plan_str = format_plan_for_test(&plan).to_string(); + assert!( + plan_str.contains("hash_lookup"), + "Expected hash_lookup in plan but got: {plan_str}" + ); + assert!( + !plan_str.contains("IN (SET)"), + "Expected no IN (SET) in plan but got: {plan_str}" + ); + + let result = format!("{}", pretty_format_batches(&batches).unwrap()); + + let probe_scan_metrics = probe_scan.metrics().unwrap(); + + // The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain. + assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2); + + // Results should be identical to the InList version + insta::assert_snapshot!( + result, + @r" + +----+----+-----+----+----+-----+ + | a | b | c | a | b | e | + +----+----+-----+----+----+-----+ + | ab | bb | 2.0 | ab | bb | 2.0 | + | aa | ba | 1.0 | aa | ba | 1.0 | + +----+----+-----+----+----+-----+ + ", + ); +} + +/// Test that when hash_join_inlist_pushdown_max_size is set to a very small value, +/// the HashTable strategy is used instead of InList strategy in CollectLeft mode. +/// This test is identical to test_hashjoin_dynamic_filter_pushdown_collect_left except +/// for the config setting that forces the HashTable strategy. 
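Whichever form is pushed down (`IN (SET)` or `hash_lookup`), the effect on the probe side is the same: a row survives only if its join-key tuple occurs on the build side, which is why these tests expect exactly 2 of the 4 probe rows to reach the output. A self-contained illustration of that composite-key membership check, using plain Rust collections and the same illustrative key values as the tests:

use std::collections::HashSet;

fn main() {
    // Build-side join keys (a, b).
    let build_keys: HashSet<(&str, &str)> =
        [("aa", "ba"), ("ab", "bb")].into_iter().collect();

    // Probe-side rows keyed by the same columns.
    let probe_rows = [("aa", "ba"), ("ab", "bb"), ("ac", "bc"), ("ad", "bd")];

    // Keep only probe rows whose key tuple exists on the build side.
    let matched: Vec<_> = probe_rows
        .iter()
        .filter(|row| build_keys.contains(*row))
        .collect();

    // 2 of 4 probe rows survive, matching the `output_rows() == 2` assertions.
    assert_eq!(matched.len(), 2);
}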
+#[tokio::test] +async fn test_hashjoin_hash_table_pushdown_collect_left() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) // Extra column not used in join + ) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]), + ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("e", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create RepartitionExec nodes for both sides with hash partitioning on join keys + let partition_count = 12; + + // Probe side: DataSource -> RepartitionExec(Hash) -> CoalesceBatchesExec + let probe_hash_exprs = vec![ + col("a", &probe_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ]; + let probe_repartition = Arc::new( + RepartitionExec::try_new( + Arc::clone(&probe_scan), + Partitioning::Hash(probe_hash_exprs, partition_count), // create multi partitions on probSide + ) + .unwrap(), + ); + let probe_coalesce = Arc::new(CoalesceBatchesExec::new(probe_repartition, 8192)); + + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let hash_join = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_coalesce, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + + // Top-level CoalesceBatchesExec + let cb = + Arc::new(CoalesceBatchesExec::new(hash_join, 8192)) as Arc; + // Top-level CoalescePartitionsExec + let cp = Arc::new(CoalescePartitionsExec::new(cb)) as Arc; + // Add a sort for deterministic output + let plan = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("a", &probe_side_schema).unwrap(), + SortOptions::new(true, false), // descending, nulls_first + )]) + .unwrap(), + cp, + )) as Arc; + + // Apply the optimization with config setting that forces HashTable strategy + let session_config = SessionConfig::default() + .with_batch_size(10) + .set_usize("datafusion.optimizer.hash_join_inlist_pushdown_max_size", 1) + .set_bool("datafusion.execution.parquet.pushdown_filters", true) + .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", true); + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, session_config.options()) + .unwrap(); + let session_ctx = SessionContext::new_with_config(session_config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let 
batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // Verify that hash_lookup is used instead of IN (SET) + let plan_str = format_plan_for_test(&plan).to_string(); + assert!( + plan_str.contains("hash_lookup"), + "Expected hash_lookup in plan but got: {plan_str}" + ); + assert!( + !plan_str.contains("IN (SET)"), + "Expected no IN (SET) in plan but got: {plan_str}" + ); + + let result = format!("{}", pretty_format_batches(&batches).unwrap()); + + let probe_scan_metrics = probe_scan.metrics().unwrap(); + + // The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain. + assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2); + + // Results should be identical to the InList version + insta::assert_snapshot!( + result, + @r" + +----+----+-----+----+----+-----+ + | a | b | c | a | b | e | + +----+----+-----+----+----+-----+ + | ab | bb | 2.0 | ab | bb | 2.0 | + | aa | ba | 1.0 | aa | ba | 1.0 | + +----+----+-----+----+----+-----+ + ", + ); +} + +/// Test HashTable strategy with integer multi-column join keys. +/// Verifies that hash_lookup works correctly with integer data types. +#[tokio::test] +async fn test_hashjoin_hash_table_pushdown_integer_keys() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create build side with integer keys + let build_batches = vec![ + record_batch!( + ("id1", Int32, [1, 2]), + ("id2", Int32, [10, 20]), + ("value", Float64, [100.0, 200.0]) + ) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("id1", DataType::Int32, false), + Field::new("id2", DataType::Int32, false), + Field::new("value", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more integer rows + let probe_batches = vec![ + record_batch!( + ("id1", Int32, [1, 2, 3, 4]), + ("id2", Int32, [10, 20, 30, 40]), + ("data", Utf8, ["a", "b", "c", "d"]) + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("id1", DataType::Int32, false), + Field::new("id2", DataType::Int32, false), + Field::new("data", DataType::Utf8, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create join on multiple integer columns + let on = vec![ + ( + col("id1", &build_side_schema).unwrap(), + col("id1", &probe_side_schema).unwrap(), + ), + ( + col("id2", &build_side_schema).unwrap(), + col("id2", &probe_side_schema).unwrap(), + ), + ]; + let hash_join = Arc::new( + HashJoinExec::try_new( + build_scan, + Arc::clone(&probe_scan), + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + + let plan = + Arc::new(CoalesceBatchesExec::new(hash_join, 8192)) as Arc; + + // Apply optimization with forced HashTable strategy + let session_config = SessionConfig::default() + .with_batch_size(10) + .set_usize("datafusion.optimizer.hash_join_inlist_pushdown_max_size", 1) + .set_bool("datafusion.execution.parquet.pushdown_filters", true) + .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", true); + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, session_config.options()) + .unwrap(); + let session_ctx = SessionContext::new_with_config(session_config); + 
session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // Verify hash_lookup is used + let plan_str = format_plan_for_test(&plan).to_string(); + assert!( + plan_str.contains("hash_lookup"), + "Expected hash_lookup in plan but got: {plan_str}" + ); + assert!( + !plan_str.contains("IN (SET)"), + "Expected no IN (SET) in plan but got: {plan_str}" + ); + + let result = format!("{}", pretty_format_batches(&batches).unwrap()); + + let probe_scan_metrics = probe_scan.metrics().unwrap(); + // Only 2 rows from probe side match the build side + assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2); + + insta::assert_snapshot!( + result, + @r" + +-----+-----+-------+-----+-----+------+ + | id1 | id2 | value | id1 | id2 | data | + +-----+-----+-------+-----+-----+------+ + | 1 | 10 | 100.0 | 1 | 10 | a | + | 2 | 20 | 200.0 | 2 | 20 | b | + +-----+-----+-------+-----+-----+------+ + ", + ); +} diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs index 7d8a9c7c2125c..1afdc4823f0a4 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs @@ -18,27 +18,24 @@ use arrow::datatypes::SchemaRef; use arrow::{array::RecordBatch, compute::concat_batches}; use datafusion::{datasource::object_store::ObjectStoreUrl, physical_plan::PhysicalExpr}; -use datafusion_common::{config::ConfigOptions, internal_err, Result, Statistics}; +use datafusion_common::{Result, config::ConfigOptions, internal_err}; use datafusion_datasource::{ - file::FileSource, file_scan_config::FileScanConfig, + PartitionedFile, file::FileSource, file_scan_config::FileScanConfig, file_scan_config::FileScanConfigBuilder, file_stream::FileOpenFuture, - file_stream::FileOpener, schema_adapter::DefaultSchemaAdapterFactory, - schema_adapter::SchemaAdapterFactory, source::DataSourceExec, PartitionedFile, - TableSchema, + file_stream::FileOpener, source::DataSourceExec, }; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::filter::batch_filter; use datafusion_physical_plan::filter_pushdown::{FilterPushdownPhase, PushedDown}; use datafusion_physical_plan::{ - displayable, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, displayable, filter::FilterExec, filter_pushdown::{ ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPropagation, }, metrics::ExecutionPlanMetricsSet, - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; use futures::StreamExt; use futures::{FutureExt, Stream}; @@ -53,7 +50,6 @@ use std::{ pub struct TestOpener { batches: Vec, batch_size: Option, - schema: Option, projection: Option>, predicate: Option>, } @@ -61,6 +57,9 @@ pub struct TestOpener { impl FileOpener for TestOpener { fn open(&self, _partitioned_file: PartitionedFile) -> Result { let mut batches = self.batches.clone(); + if self.batches.is_empty() { + return Ok((async { Ok(TestStream::new(vec![]).boxed()) }).boxed()); + } if let Some(batch_size) = self.batch_size { let batch = concat_batches(&batches[0].schema(), &batches)?; let mut new_batches = Vec::new(); @@ -71,23 +70,18 @@ impl FileOpener for TestOpener { } 
batches = new_batches.into_iter().collect(); } - if let Some(schema) = &self.schema { - let factory = DefaultSchemaAdapterFactory::from_schema(Arc::clone(schema)); - let (mapper, projection) = factory.map_schema(&batches[0].schema()).unwrap(); - let mut new_batches = Vec::new(); - for batch in batches { - let batch = if let Some(predicate) = &self.predicate { - batch_filter(&batch, predicate)? - } else { - batch - }; - let batch = batch.project(&projection).unwrap(); - let batch = mapper.map_batch(batch).unwrap(); - new_batches.push(batch); - } - batches = new_batches; + let mut new_batches = Vec::new(); + for batch in batches { + let batch = if let Some(predicate) = &self.predicate { + batch_filter(&batch, predicate)? + } else { + batch + }; + new_batches.push(batch); } + batches = new_batches; + if let Some(projection) = &self.projection { batches = batches .into_iter() @@ -102,26 +96,29 @@ impl FileOpener for TestOpener { } /// A placeholder data source that accepts filter pushdown -#[derive(Clone, Default)] +#[derive(Clone)] pub struct TestSource { support: bool, predicate: Option>, - statistics: Option, batch_size: Option, batches: Vec, - schema: Option, metrics: ExecutionPlanMetricsSet, projection: Option>, - schema_adapter_factory: Option>, + table_schema: datafusion_datasource::TableSchema, } impl TestSource { - pub fn new(support: bool, batches: Vec) -> Self { + pub fn new(schema: SchemaRef, support: bool, batches: Vec) -> Self { + let table_schema = + datafusion_datasource::TableSchema::new(Arc::clone(&schema), vec![]); Self { support, metrics: ExecutionPlanMetricsSet::new(), batches, - ..Default::default() + predicate: None, + batch_size: None, + projection: None, + table_schema, } } } @@ -132,14 +129,13 @@ impl FileSource for TestSource { _object_store: Arc, _base_config: &FileScanConfig, _partition: usize, - ) -> Arc { - Arc::new(TestOpener { + ) -> Result> { + Ok(Arc::new(TestOpener { batches: self.batches.clone(), batch_size: self.batch_size, - schema: self.schema.clone(), projection: self.projection.clone(), predicate: self.predicate.clone(), - }) + })) } fn filter(&self) -> Option> { @@ -157,43 +153,10 @@ impl FileSource for TestSource { }) } - fn with_schema(&self, schema: TableSchema) -> Arc { - assert!( - schema.table_partition_cols().is_empty(), - "TestSource does not support partition columns" - ); - Arc::new(TestSource { - schema: Some(schema.file_schema().clone()), - ..self.clone() - }) - } - - fn with_projection(&self, config: &FileScanConfig) -> Arc { - Arc::new(TestSource { - projection: config.projection_exprs.as_ref().map(|p| p.column_indices()), - ..self.clone() - }) - } - - fn with_statistics(&self, statistics: Statistics) -> Arc { - Arc::new(TestSource { - statistics: Some(statistics), - ..self.clone() - }) - } - fn metrics(&self) -> &ExecutionPlanMetricsSet { &self.metrics } - fn statistics(&self) -> Result { - Ok(self - .statistics - .as_ref() - .expect("statistics not set") - .clone()) - } - fn file_type(&self) -> &str { "test" } @@ -247,18 +210,8 @@ impl FileSource for TestSource { } } - fn with_schema_adapter_factory( - &self, - schema_adapter_factory: Arc, - ) -> Result> { - Ok(Arc::new(Self { - schema_adapter_factory: Some(schema_adapter_factory), - ..self.clone() - })) - } - - fn schema_adapter_factory(&self) -> Option> { - self.schema_adapter_factory.clone() + fn table_schema(&self) -> &datafusion_datasource::TableSchema { + &self.table_schema } } @@ -289,14 +242,15 @@ impl TestScanBuilder { } pub fn build(self) -> Arc { - let source = 
Arc::new(TestSource::new(self.support, self.batches)); - let base_config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test://").unwrap(), + let source = Arc::new(TestSource::new( Arc::clone(&self.schema), - source, - ) - .with_file(PartitionedFile::new("test.parquet", 123)) - .build(); + self.support, + self.batches, + )); + let base_config = + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test://").unwrap(), source) + .with_file(PartitionedFile::new("test.parquet", 123)) + .build(); DataSourceExec::from_data_source(base_config) } } @@ -335,11 +289,12 @@ impl TestStream { /// least one entry in data (for the schema) pub fn new(data: Vec) -> Self { // check that there is at least one entry in data and that all batches have the same schema - assert!(!data.is_empty(), "data must not be empty"); - assert!( - data.iter().all(|batch| batch.schema() == data[0].schema()), - "all batches must have the same schema" - ); + if let Some(first) = data.first() { + assert!( + data.iter().all(|batch| batch.schema() == first.schema()), + "all batches must have the same schema" + ); + } Self { data, ..Default::default() diff --git a/datafusion/core/tests/physical_optimizer/join_selection.rs b/datafusion/core/tests/physical_optimizer/join_selection.rs index f9d3a045469e1..37bcefd418bdb 100644 --- a/datafusion/core/tests/physical_optimizer/join_selection.rs +++ b/datafusion/core/tests/physical_optimizer/join_selection.rs @@ -26,27 +26,27 @@ use std::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::config::ConfigOptions; -use datafusion_common::{stats::Precision, ColumnStatistics, JoinType, ScalarValue}; +use datafusion_common::{ColumnStatistics, JoinType, ScalarValue, stats::Precision}; use datafusion_common::{JoinSide, NullEquality}; use datafusion_common::{Result, Statistics}; use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}; use datafusion_expr::Operator; +use datafusion_physical_expr::PhysicalExprRef; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::expressions::{BinaryExpr, Column, NegativeExpr}; use datafusion_physical_expr::intervals::utils::check_support; -use datafusion_physical_expr::PhysicalExprRef; use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalExpr}; -use datafusion_physical_optimizer::join_selection::JoinSelection; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::join_selection::JoinSelection; +use datafusion_physical_plan::ExecutionPlanProperties; use datafusion_physical_plan::displayable; use datafusion_physical_plan::joins::utils::ColumnIndex; use datafusion_physical_plan::joins::utils::JoinFilter; use datafusion_physical_plan::joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode}; use datafusion_physical_plan::projection::ProjectionExec; -use datafusion_physical_plan::ExecutionPlanProperties; use datafusion_physical_plan::{ - execution_plan::{Boundedness, EmissionType}, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, + execution_plan::{Boundedness, EmissionType}, }; use futures::Stream; @@ -949,10 +949,10 @@ impl Stream for UnboundedStream { mut self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll> { - if let Some(val) = self.batch_produce { - if val <= self.count { - return Poll::Ready(None); - } + if let Some(val) = self.batch_produce + && val <= self.count + { + return Poll::Ready(None); } self.count += 1; Poll::Ready(Some(Ok(self.batch.clone()))) 
@@ -1088,9 +1088,10 @@ pub struct StatisticsExec { impl StatisticsExec { pub fn new(stats: Statistics, schema: Schema) -> Self { assert_eq!( - stats.column_statistics.len(), schema.fields().len(), - "if defined, the column statistics vector length should be the number of fields" - ); + stats.column_statistics.len(), + schema.fields().len(), + "if defined, the column statistics vector length should be the number of fields" + ); let cache = Self::compute_properties(Arc::new(schema.clone())); Self { stats, diff --git a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs index 56d48901f284d..b32a9bbd25432 100644 --- a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs @@ -27,16 +27,16 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::{col, lit, BinaryExpr}; use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr::expressions::{BinaryExpr, col, lit}; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_optimizer::limit_pushdown::LimitPushdown; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::limit_pushdown::LimitPushdown; use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::{get_plan_string, ExecutionPlan}; +use datafusion_physical_plan::{ExecutionPlan, get_plan_string}; fn create_schema() -> SchemaRef { Arc::new(Schema::new(vec![ @@ -96,51 +96,51 @@ fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero() -> let initial = get_plan_string(&global_limit); let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; + "GlobalLimitExec: skip=0, fetch=5", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", + ]; assert_eq!(initial, expected_initial); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; let expected = [ - "StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=5" - ]; + "StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=5", + ]; assert_eq!(get_plan_string(&after_optimize), expected); Ok(()) } #[test] -fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_limit_when_skip_is_nonzero( -) -> Result<()> { +fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_limit_when_skip_is_nonzero() +-> Result<()> { let schema = create_schema(); let streaming_table = stream_exec(&schema); let global_limit = global_limit_exec(streaming_table, 2, Some(5)); let initial = get_plan_string(&global_limit); let expected_initial = [ - "GlobalLimitExec: skip=2, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; + "GlobalLimitExec: skip=2, fetch=5", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", + ]; assert_eq!(initial, expected_initial); let after_optimize = 
LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; let expected = [ - "GlobalLimitExec: skip=2, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=7" - ]; + "GlobalLimitExec: skip=2, fetch=5", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=7", + ]; assert_eq!(get_plan_string(&after_optimize), expected); Ok(()) } #[test] -fn transforms_coalesce_batches_exec_into_fetching_version_and_removes_local_limit( -) -> Result<()> { +fn transforms_coalesce_batches_exec_into_fetching_version_and_removes_local_limit() +-> Result<()> { let schema = create_schema(); let streaming_table = stream_exec(&schema); let repartition = repartition_exec(streaming_table)?; @@ -152,14 +152,14 @@ fn transforms_coalesce_batches_exec_into_fetching_version_and_removes_local_limi let initial = get_plan_string(&global_limit); let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " CoalescePartitionsExec", - " LocalLimitExec: fetch=5", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c3@2 > 0", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; + "GlobalLimitExec: skip=0, fetch=5", + " CoalescePartitionsExec", + " LocalLimitExec: fetch=5", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c3@2 > 0", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", + ]; assert_eq!(initial, expected_initial); let after_optimize = @@ -170,8 +170,8 @@ fn transforms_coalesce_batches_exec_into_fetching_version_and_removes_local_limi " CoalesceBatchesExec: target_batch_size=8192, fetch=5", " FilterExec: c3@2 > 0", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", + ]; assert_eq!(get_plan_string(&after_optimize), expected); Ok(()) @@ -187,30 +187,29 @@ fn pushes_global_limit_exec_through_projection_exec() -> Result<()> { let initial = get_plan_string(&global_limit); let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " FilterExec: c3@2 > 0", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; + "GlobalLimitExec: skip=0, fetch=5", + " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " FilterExec: c3@2 > 0", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", + ]; assert_eq!(initial, expected_initial); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; let expected = [ - "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " GlobalLimitExec: skip=0, fetch=5", - " FilterExec: c3@2 > 0", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; + "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " FilterExec: c3@2 > 0, fetch=5", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", + ]; assert_eq!(get_plan_string(&after_optimize), expected); Ok(()) } #[test] -fn 
pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batches_exec_into_fetching_version( -) -> Result<()> { +fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batches_exec_into_fetching_version() +-> Result<()> { let schema = create_schema(); let streaming_table = stream_exec(&schema); let coalesce_batches = coalesce_batches_exec(streaming_table, 8192); @@ -219,11 +218,11 @@ fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batc let initial = get_plan_string(&global_limit); let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; + "GlobalLimitExec: skip=0, fetch=5", + " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " CoalesceBatchesExec: target_batch_size=8192", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", + ]; assert_eq!(initial, expected_initial); @@ -231,10 +230,10 @@ fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batc LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; let expected = [ - "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; + "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " CoalesceBatchesExec: target_batch_size=8192, fetch=5", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", + ]; assert_eq!(get_plan_string(&after_optimize), expected); Ok(()) @@ -258,14 +257,14 @@ fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> { let initial = get_plan_string(&global_limit); let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " SortPreservingMergeExec: [c1@0 ASC]", - " SortExec: expr=[c1@0 ASC], preserve_partitioning=[false]", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; + "GlobalLimitExec: skip=0, fetch=5", + " SortPreservingMergeExec: [c1@0 ASC]", + " SortExec: expr=[c1@0 ASC], preserve_partitioning=[false]", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " CoalesceBatchesExec: target_batch_size=8192", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", + ]; assert_eq!(initial, expected_initial); @@ -273,13 +272,13 @@ fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> { LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; let expected = [ - "SortPreservingMergeExec: [c1@0 ASC], fetch=5", - " SortExec: TopK(fetch=5), expr=[c1@0 ASC], preserve_partitioning=[false]", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; + "SortPreservingMergeExec: [c1@0 ASC], fetch=5", + " SortExec: TopK(fetch=5), expr=[c1@0 ASC], 
preserve_partitioning=[false]", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " CoalesceBatchesExec: target_batch_size=8192", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", + ]; assert_eq!(get_plan_string(&after_optimize), expected); Ok(()) @@ -297,23 +296,23 @@ fn keeps_pushed_local_limit_exec_when_there_are_multiple_input_partitions() -> R let initial = get_plan_string(&global_limit); let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " CoalescePartitionsExec", - " FilterExec: c3@2 > 0", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; + "GlobalLimitExec: skip=0, fetch=5", + " CoalescePartitionsExec", + " FilterExec: c3@2 > 0", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", + ]; assert_eq!(initial, expected_initial); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; let expected = [ - "CoalescePartitionsExec: fetch=5", - " FilterExec: c3@2 > 0", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; + "CoalescePartitionsExec: fetch=5", + " FilterExec: c3@2 > 0, fetch=5", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true", + ]; assert_eq!(get_plan_string(&after_optimize), expected); Ok(()) diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs index ad15d6803413b..c523b4a752a82 100644 --- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs @@ -21,8 +21,8 @@ use insta::assert_snapshot; use std::sync::Arc; use crate::physical_optimizer::test_utils::{ - build_group_by, get_optimized_plan, mock_data, parquet_exec_with_sort, schema, - TestAggregate, + TestAggregate, build_group_by, get_optimized_plan, mock_data, parquet_exec_with_sort, + schema, }; use arrow::datatypes::DataType; @@ -34,10 +34,10 @@ use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{self, cast, col}; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::{ + ExecutionPlan, aggregates::{AggregateExec, AggregateMode}, collect, limit::{GlobalLimitExec, LocalLimitExec}, - ExecutionPlan, }; async fn run_plan_and_format(plan: Arc) -> Result { diff --git a/datafusion/core/tests/physical_optimizer/mod.rs b/datafusion/core/tests/physical_optimizer/mod.rs index 936c02eb2a02d..d11322cd26be9 100644 --- a/datafusion/core/tests/physical_optimizer/mod.rs +++ b/datafusion/core/tests/physical_optimizer/mod.rs @@ -17,18 +17,24 @@ //! 
Physical Optimizer integration tests +#[expect(clippy::needless_pass_by_value)] mod aggregate_statistics; mod combine_partial_final_agg; +#[expect(clippy::needless_pass_by_value)] mod enforce_distribution; mod enforce_sorting; mod enforce_sorting_monotonicity; +#[expect(clippy::needless_pass_by_value)] mod filter_pushdown; mod join_selection; +#[expect(clippy::needless_pass_by_value)] mod limit_pushdown; mod limited_distinct_aggregation; mod partition_statistics; mod projection_pushdown; +mod pushdown_sort; mod replace_with_order_preserving_variants; mod sanity_checker; +#[expect(clippy::needless_pass_by_value)] mod test_utils; mod window_optimize; diff --git a/datafusion/core/tests/physical_optimizer/partition_statistics.rs b/datafusion/core/tests/physical_optimizer/partition_statistics.rs index 49dc5b845605d..468d25e0e57d0 100644 --- a/datafusion/core/tests/physical_optimizer/partition_statistics.rs +++ b/datafusion/core/tests/physical_optimizer/partition_statistics.rs @@ -25,16 +25,16 @@ mod test { use datafusion::datasource::listing::ListingTable; use datafusion::prelude::SessionContext; use datafusion_catalog::TableProvider; - use datafusion_common::stats::Precision; use datafusion_common::Result; + use datafusion_common::stats::Precision; use datafusion_common::{ColumnStatistics, ScalarValue, Statistics}; - use datafusion_execution::config::SessionConfig; use datafusion_execution::TaskContext; + use datafusion_execution::config::SessionConfig; use datafusion_expr_common::operator::Operator; use datafusion_functions_aggregate::count::count_udaf; - use datafusion_physical_expr::aggregate::AggregateExprBuilder; - use datafusion_physical_expr::expressions::{binary, col, lit, Column}; use datafusion_physical_expr::Partitioning; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::expressions::{Column, binary, col, lit}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::aggregates::{ @@ -53,8 +53,8 @@ mod test { use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::union::{InterleaveExec, UnionExec}; use datafusion_physical_plan::{ - execute_stream_partitioned, get_plan_string, ExecutionPlan, - ExecutionPlanProperties, + ExecutionPlan, ExecutionPlanProperties, execute_stream_partitioned, + get_plan_string, }; use futures::TryStreamExt; @@ -67,7 +67,7 @@ mod test { /// - Each partition has an "id" column (INT) with the following values: /// - First partition: [3, 4] /// - Second partition: [1, 2] - /// - Each row is 110 bytes in size + /// - Each partition has 16 bytes total (Int32 id: 4 bytes × 2 rows + Date32 date: 4 bytes × 2 rows) /// /// @param create_table_sql Optional parameter to set the create table SQL /// @param target_partition Optional parameter to set the target partitions @@ -112,29 +112,51 @@ mod test { .unwrap() } + // Date32 values for test data (days since 1970-01-01): + // 2025-03-01 = 20148 + // 2025-03-02 = 20149 + // 2025-03-03 = 20150 + // 2025-03-04 = 20151 + const DATE_2025_03_01: i32 = 20148; + const DATE_2025_03_02: i32 = 20149; + const DATE_2025_03_03: i32 = 20150; + const DATE_2025_03_04: i32 = 20151; + /// Helper function to create expected statistics for a partition with Int32 column + /// + /// If `date_range` is provided, includes exact statistics for the partition date column. + /// Partition column statistics are exact because all rows in a partition share the same value. 
fn create_partition_statistics( num_rows: usize, total_byte_size: usize, min_value: i32, max_value: i32, - include_date_column: bool, + date_range: Option<(i32, i32)>, ) -> Statistics { + // Int32 is 4 bytes per row + let int32_byte_size = num_rows * 4; let mut column_stats = vec![ColumnStatistics { null_count: Precision::Exact(0), max_value: Precision::Exact(ScalarValue::Int32(Some(max_value))), min_value: Precision::Exact(ScalarValue::Int32(Some(min_value))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Exact(int32_byte_size), }]; - if include_date_column { + if let Some((min_date, max_date)) = date_range { + // Partition column stats are computed from partition values: + // - null_count = 0 (partition values from paths are never null) + // - min/max are the merged partition values across files in the group + // - byte_size = num_rows * 4 (Date32 is 4 bytes per row) + let date32_byte_size = num_rows * 4; column_stats.push(ColumnStatistics { - null_count: Precision::Absent, - max_value: Precision::Absent, - min_value: Precision::Absent, + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some(max_date))), + min_value: Precision::Exact(ScalarValue::Date32(Some(min_date))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Exact(date32_byte_size), }); } @@ -214,10 +236,22 @@ mod test { let statistics = (0..scan.output_partitioning().partition_count()) .map(|idx| scan.partition_statistics(Some(idx))) .collect::>>()?; - let expected_statistic_partition_1 = - create_partition_statistics(2, 110, 3, 4, true); - let expected_statistic_partition_2 = - create_partition_statistics(2, 110, 1, 2, true); + // Partition 1: ids [3,4], dates [2025-03-01, 2025-03-02] + let expected_statistic_partition_1 = create_partition_statistics( + 2, + 16, + 3, + 4, + Some((DATE_2025_03_01, DATE_2025_03_02)), + ); + // Partition 2: ids [1,2], dates [2025-03-03, 2025-03-04] + let expected_statistic_partition_2 = create_partition_statistics( + 2, + 16, + 1, + 2, + Some((DATE_2025_03_03, DATE_2025_03_04)), + ); // Check the statistics of each partition assert_eq!(statistics.len(), 2); assert_eq!(statistics[0], expected_statistic_partition_1); @@ -246,10 +280,11 @@ mod test { let statistics = (0..projection.output_partitioning().partition_count()) .map(|idx| projection.partition_statistics(Some(idx))) .collect::>>()?; + // Projection only includes id column, not the date partition column let expected_statistic_partition_1 = - create_partition_statistics(2, 8, 3, 4, false); + create_partition_statistics(2, 8, 3, 4, None); let expected_statistic_partition_2 = - create_partition_statistics(2, 8, 1, 2, false); + create_partition_statistics(2, 8, 1, 2, None); // Check the statistics of each partition assert_eq!(statistics.len(), 2); assert_eq!(statistics[0], expected_statistic_partition_1); @@ -277,8 +312,14 @@ mod test { let statistics = (0..sort_exec.output_partitioning().partition_count()) .map(|idx| sort_exec.partition_statistics(Some(idx))) .collect::>>()?; - let expected_statistic_partition = - create_partition_statistics(4, 220, 1, 4, true); + // All 4 files merged: ids [1-4], dates [2025-03-01, 2025-03-04] + let expected_statistic_partition = create_partition_statistics( + 4, + 32, + 1, + 4, + Some((DATE_2025_03_01, DATE_2025_03_04)), + ); assert_eq!(statistics.len(), 1); assert_eq!(statistics[0], expected_statistic_partition); // Check the statistics_by_partition with real results @@ -291,10 +332,22 @@ 
mod test { let sort_exec: Arc = Arc::new( SortExec::new(ordering.into(), scan_2).with_preserve_partitioning(true), ); - let expected_statistic_partition_1 = - create_partition_statistics(2, 110, 3, 4, true); - let expected_statistic_partition_2 = - create_partition_statistics(2, 110, 1, 2, true); + // Partition 1: ids [3,4], dates [2025-03-01, 2025-03-02] + let expected_statistic_partition_1 = create_partition_statistics( + 2, + 16, + 3, + 4, + Some((DATE_2025_03_01, DATE_2025_03_02)), + ); + // Partition 2: ids [1,2], dates [2025-03-03, 2025-03-04] + let expected_statistic_partition_2 = create_partition_statistics( + 2, + 16, + 1, + 2, + Some((DATE_2025_03_03, DATE_2025_03_04)), + ); let statistics = (0..sort_exec.output_partitioning().partition_count()) .map(|idx| sort_exec.partition_statistics(Some(idx))) .collect::>>()?; @@ -324,6 +377,8 @@ mod test { let filter: Arc = Arc::new(FilterExec::try_new(predicate, scan)?); let full_statistics = filter.partition_statistics(None)?; + // Filter preserves original total_rows and byte_size from input + // (4 total rows = 2 partitions * 2 rows each, byte_size = 4 * 4 = 16 bytes for int32) let expected_full_statistic = Statistics { num_rows: Precision::Inexact(0), total_byte_size: Precision::Inexact(0), @@ -334,6 +389,7 @@ mod test { min_value: Precision::Exact(ScalarValue::Null), sum_value: Precision::Exact(ScalarValue::Null), distinct_count: Precision::Exact(0), + byte_size: Precision::Exact(16), }, ColumnStatistics { null_count: Precision::Exact(0), @@ -341,6 +397,7 @@ mod test { min_value: Precision::Exact(ScalarValue::Null), sum_value: Precision::Exact(ScalarValue::Null), distinct_count: Precision::Exact(0), + byte_size: Precision::Exact(16), // 4 rows * 4 bytes (Date32) }, ], }; @@ -350,8 +407,31 @@ mod test { .map(|idx| filter.partition_statistics(Some(idx))) .collect::>>()?; assert_eq!(statistics.len(), 2); - assert_eq!(statistics[0], expected_full_statistic); - assert_eq!(statistics[1], expected_full_statistic); + // Per-partition stats: each partition has 2 rows, byte_size = 2 * 4 = 8 + let expected_partition_statistic = Statistics { + num_rows: Precision::Inexact(0), + total_byte_size: Precision::Inexact(0), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Null), + min_value: Precision::Exact(ScalarValue::Null), + sum_value: Precision::Exact(ScalarValue::Null), + distinct_count: Precision::Exact(0), + byte_size: Precision::Exact(8), + }, + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Null), + min_value: Precision::Exact(ScalarValue::Null), + sum_value: Precision::Exact(ScalarValue::Null), + distinct_count: Precision::Exact(0), + byte_size: Precision::Exact(8), // 2 rows * 4 bytes (Date32) + }, + ], + }; + assert_eq!(statistics[0], expected_partition_statistic); + assert_eq!(statistics[1], expected_partition_statistic); Ok(()) } @@ -365,10 +445,22 @@ mod test { .collect::>>()?; // Check that we have 4 partitions (2 from each scan) assert_eq!(statistics.len(), 4); - let expected_statistic_partition_1 = - create_partition_statistics(2, 110, 3, 4, true); - let expected_statistic_partition_2 = - create_partition_statistics(2, 110, 1, 2, true); + // Partition 1: ids [3,4], dates [2025-03-01, 2025-03-02] + let expected_statistic_partition_1 = create_partition_statistics( + 2, + 16, + 3, + 4, + Some((DATE_2025_03_01, DATE_2025_03_02)), + ); + // Partition 2: ids [1,2], dates [2025-03-03, 2025-03-04] + let 
expected_statistic_partition_2 = create_partition_statistics( + 2, + 16, + 1, + 2, + Some((DATE_2025_03_03, DATE_2025_03_04)), + ); // Verify first partition (from first scan) assert_eq!(statistics[0], expected_statistic_partition_1); // Verify second partition (from first scan) @@ -416,9 +508,10 @@ mod test { .collect::>>()?; assert_eq!(stats.len(), 2); + // Each partition gets half of combined input, total_rows per partition = 4 let expected_stats = Statistics { num_rows: Precision::Inexact(4), - total_byte_size: Precision::Inexact(220), + total_byte_size: Precision::Inexact(32), column_statistics: vec![ ColumnStatistics::new_unknown(), ColumnStatistics::new_unknown(), @@ -461,28 +554,76 @@ mod test { .collect::>>()?; // Check that we have 2 partitions assert_eq!(statistics.len(), 2); - let mut expected_statistic_partition_1 = - create_partition_statistics(8, 48400, 1, 4, true); - expected_statistic_partition_1 - .column_statistics - .push(ColumnStatistics { - null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Int32(Some(4))), - min_value: Precision::Exact(ScalarValue::Int32(Some(3))), - sum_value: Precision::Absent, - distinct_count: Precision::Absent, - }); - let mut expected_statistic_partition_2 = - create_partition_statistics(8, 48400, 1, 4, true); - expected_statistic_partition_2 - .column_statistics - .push(ColumnStatistics { - null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Int32(Some(2))), - min_value: Precision::Exact(ScalarValue::Int32(Some(1))), - sum_value: Precision::Absent, - distinct_count: Precision::Absent, - }); + // Cross join output schema: [left.id, left.date, right.id] + // Cross join doesn't propagate Column's byte_size + let expected_statistic_partition_1 = Statistics { + num_rows: Precision::Exact(8), + total_byte_size: Precision::Exact(512), + column_statistics: vec![ + // column 0: left.id (Int32, file column from t1) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }, + // column 1: left.date (Date32, partition column from t1) + // Partition column statistics are exact because all rows in a partition share the same value. 
+ ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some(20151))), + min_value: Precision::Exact(ScalarValue::Date32(Some(20148))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }, + // column 2: right.id (Int32, file column from t2) - right partition 0: ids [3,4] + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }, + ], + }; + let expected_statistic_partition_2 = Statistics { + num_rows: Precision::Exact(8), + total_byte_size: Precision::Exact(512), + column_statistics: vec![ + // column 0: left.id (Int32, file column from t1) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }, + // column 1: left.date (Date32, partition column from t1) + // Partition column statistics are exact because all rows in a partition share the same value. + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some(20151))), + min_value: Precision::Exact(ScalarValue::Date32(Some(20148))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }, + // column 2: right.id (Int32, file column from t2) - right partition 1: ids [1,2] + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(2))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }, + ], + }; assert_eq!(statistics[0], expected_statistic_partition_1); assert_eq!(statistics[1], expected_statistic_partition_2); @@ -500,10 +641,22 @@ mod test { let scan = create_scan_exec_with_statistics(None, Some(2)).await; let coalesce_batches: Arc = Arc::new(CoalesceBatchesExec::new(scan, 2)); - let expected_statistic_partition_1 = - create_partition_statistics(2, 110, 3, 4, true); - let expected_statistic_partition_2 = - create_partition_statistics(2, 110, 1, 2, true); + // Partition 1: ids [3,4], dates [2025-03-01, 2025-03-02] + let expected_statistic_partition_1 = create_partition_statistics( + 2, + 16, + 3, + 4, + Some((DATE_2025_03_01, DATE_2025_03_02)), + ); + // Partition 2: ids [1,2], dates [2025-03-03, 2025-03-04] + let expected_statistic_partition_2 = create_partition_statistics( + 2, + 16, + 1, + 2, + Some((DATE_2025_03_03, DATE_2025_03_04)), + ); let statistics = (0..coalesce_batches.output_partitioning().partition_count()) .map(|idx| coalesce_batches.partition_statistics(Some(idx))) .collect::>>()?; @@ -525,8 +678,14 @@ mod test { let scan = create_scan_exec_with_statistics(None, Some(2)).await; let coalesce_partitions: Arc = Arc::new(CoalescePartitionsExec::new(scan)); - let expected_statistic_partition = - create_partition_statistics(4, 220, 1, 4, true); + // All files merged: ids [1-4], dates [2025-03-01, 2025-03-04] + let expected_statistic_partition = create_partition_statistics( + 4, + 32, + 1, + 4, + Some((DATE_2025_03_01, DATE_2025_03_04)), + ); let statistics = (0..coalesce_partitions.output_partitioning().partition_count()) 
.map(|idx| coalesce_partitions.partition_statistics(Some(idx))) .collect::>>()?; @@ -575,8 +734,14 @@ mod test { .map(|idx| global_limit.partition_statistics(Some(idx))) .collect::>>()?; assert_eq!(statistics.len(), 1); - let expected_statistic_partition = - create_partition_statistics(2, 110, 3, 4, true); + // GlobalLimit takes from first partition: ids [3,4], dates [2025-03-01, 2025-03-02] + let expected_statistic_partition = create_partition_statistics( + 2, + 16, + 3, + 4, + Some((DATE_2025_03_01, DATE_2025_03_02)), + ); assert_eq!(statistics[0], expected_statistic_partition); Ok(()) } @@ -601,11 +766,13 @@ mod test { ), ]); - let aggr_expr = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) - .schema(Arc::clone(&scan_schema)) - .alias(String::from("COUNT(c)")) - .build() - .map(Arc::new)?]; + let aggr_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) + .schema(Arc::clone(&scan_schema)) + .alias(String::from("COUNT(c)")) + .build() + .map(Arc::new)?, + ]; let aggregate_exec_partial: Arc = Arc::new(AggregateExec::try_new( @@ -620,14 +787,15 @@ mod test { let plan_string = get_plan_string(&aggregate_exec_partial).swap_remove(0); assert_snapshot!( plan_string, - @"AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]" + @"AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)], ordering_mode=Sorted" ); let p0_statistics = aggregate_exec_partial.partition_statistics(Some(0))?; + // Aggregate doesn't propagate num_rows and ColumnStatistics byte_size from input let expected_p0_statistics = Statistics { num_rows: Precision::Inexact(2), - total_byte_size: Precision::Absent, + total_byte_size: Precision::Inexact(16), column_statistics: vec![ ColumnStatistics { null_count: Precision::Absent, @@ -635,6 +803,7 @@ mod test { min_value: Precision::Exact(ScalarValue::Int32(Some(3))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, }, ColumnStatistics::new_unknown(), ColumnStatistics::new_unknown(), @@ -645,7 +814,7 @@ mod test { let expected_p1_statistics = Statistics { num_rows: Precision::Inexact(2), - total_byte_size: Precision::Absent, + total_byte_size: Precision::Inexact(16), column_statistics: vec![ ColumnStatistics { null_count: Precision::Absent, @@ -653,6 +822,7 @@ mod test { min_value: Precision::Exact(ScalarValue::Int32(Some(1))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, }, ColumnStatistics::new_unknown(), ColumnStatistics::new_unknown(), @@ -849,9 +1019,10 @@ mod test { .collect::>>()?; assert_eq!(statistics.len(), 3); + // Repartition preserves original total_rows from input (4 rows total) let expected_stats = Statistics { num_rows: Precision::Inexact(1), - total_byte_size: Precision::Inexact(73), + total_byte_size: Precision::Inexact(10), column_statistics: vec![ ColumnStatistics::new_unknown(), ColumnStatistics::new_unknown(), @@ -878,9 +1049,9 @@ mod test { partition_row_counts.push(total_rows); } assert_eq!(partition_row_counts.len(), 3); - assert_eq!(partition_row_counts[0], 2); + assert_eq!(partition_row_counts[0], 1); assert_eq!(partition_row_counts[1], 2); - assert_eq!(partition_row_counts[2], 0); + assert_eq!(partition_row_counts[2], 1); Ok(()) } @@ -898,9 +1069,11 @@ mod test { let result = repartition.partition_statistics(Some(2)); assert!(result.is_err()); let error = result.unwrap_err(); - assert!(error - .to_string() - .contains("RepartitionExec invalid partition 2 (expected less than 2)")); 
+ assert!( + error + .to_string() + .contains("RepartitionExec invalid partition 2 (expected less than 2)") + ); let partitions = execute_stream_partitioned( repartition.clone(), @@ -953,9 +1126,10 @@ mod test { .collect::>>()?; assert_eq!(stats.len(), 2); + // Repartition preserves original total_rows from input (4 rows total) let expected_stats = Statistics { num_rows: Precision::Inexact(2), - total_byte_size: Precision::Inexact(110), + total_byte_size: Precision::Inexact(16), column_statistics: vec![ ColumnStatistics::new_unknown(), ColumnStatistics::new_unknown(), diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index 8631613c3925e..480f5c8cc97b1 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -24,8 +24,9 @@ use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::physical_plan::CsvSource; use datafusion::datasource::source::DataSourceExec; -use datafusion_common::config::ConfigOptions; +use datafusion_common::config::{ConfigOptions, CsvOptions}; use datafusion_common::{JoinSide, JoinType, NullEquality, Result, ScalarValue}; +use datafusion_datasource::TableSchema; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; @@ -34,30 +35,32 @@ use datafusion_expr::{ }; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr::expressions::{ - binary, cast, col, BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, + BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, binary, cast, col, }; use datafusion_physical_expr::{Distribution, Partitioning, ScalarFunctionExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ OrderingRequirements, PhysicalSortExpr, PhysicalSortRequirement, }; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; use datafusion_physical_optimizer::projection_pushdown::ProjectionPushdown; -use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_plan::coop::CooperativeExec; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use datafusion_physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, }; -use datafusion_physical_plan::projection::{update_expr, ProjectionExec, ProjectionExpr}; +use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr, update_expr}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::union::UnionExec; -use datafusion_physical_plan::{displayable, ExecutionPlan}; +use datafusion_physical_plan::{ExecutionPlan, displayable}; use insta::assert_snapshot; use itertools::Itertools; 
@@ -229,9 +232,11 @@ fn test_update_matching_exprs() -> Result<()> { .map(|(expr, alias)| ProjectionExpr::new(expr.clone(), alias.clone())) .collect(); for (expr, expected_expr) in exprs.into_iter().zip(expected_exprs.into_iter()) { - assert!(update_expr(&expr, &child_exprs, true)? - .unwrap() - .eq(&expected_expr)); + assert!( + update_expr(&expr, &child_exprs, true)? + .unwrap() + .eq(&expected_expr) + ); } Ok(()) @@ -368,9 +373,11 @@ fn test_update_projected_exprs() -> Result<()> { .map(|(expr, alias)| ProjectionExpr::new(expr.clone(), alias.clone())) .collect(); for (expr, expected_expr) in exprs.into_iter().zip(expected_exprs.into_iter()) { - assert!(update_expr(&expr, &proj_exprs, false)? - .unwrap() - .eq(&expected_expr)); + assert!( + update_expr(&expr, &proj_exprs, false)? + .unwrap() + .eq(&expected_expr) + ); } Ok(()) @@ -384,14 +391,20 @@ fn create_simple_csv_exec() -> Arc { Field::new("d", DataType::Int32, true), Field::new("e", DataType::Int32, true), ])); - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema, - Arc::new(CsvSource::new(false, 0, 0)), - ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_projection_indices(Some(vec![0, 1, 2, 3, 4])) - .build(); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test:///").unwrap(), { + let options = CsvOptions { + has_header: Some(false), + delimiter: 0, + quote: 0, + ..Default::default() + }; + Arc::new(CsvSource::new(schema.clone()).with_csv_options(options)) + }) + .with_file(PartitionedFile::new("x", 100)) + .with_projection_indices(Some(vec![0, 1, 2, 3, 4])) + .unwrap() + .build(); DataSourceExec::from_data_source(config) } @@ -403,14 +416,20 @@ fn create_projecting_csv_exec() -> Arc { Field::new("c", DataType::Int32, true), Field::new("d", DataType::Int32, true), ])); - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema, - Arc::new(CsvSource::new(false, 0, 0)), - ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_projection_indices(Some(vec![3, 2, 1])) - .build(); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test:///").unwrap(), { + let options = CsvOptions { + has_header: Some(false), + delimiter: 0, + quote: 0, + ..Default::default() + }; + Arc::new(CsvSource::new(schema.clone()).with_csv_options(options)) + }) + .with_file(PartitionedFile::new("x", 100)) + .with_projection_indices(Some(vec![3, 2, 1])) + .unwrap() + .build(); DataSourceExec::from_data_source(config) } @@ -432,8 +451,8 @@ fn test_csv_after_projection() -> Result<()> { let csv = create_projecting_csv_exec(); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("b", 2)), "b".to_string()), - ProjectionExpr::new(Arc::new(Column::new("d", 0)), "d".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 2)), "b"), + ProjectionExpr::new(Arc::new(Column::new("d", 0)), "d"), ], csv.clone(), )?); @@ -469,9 +488,9 @@ fn test_memory_after_projection() -> Result<()> { let memory = create_projecting_memory_exec(); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("d", 2)), "d".to_string()), - ProjectionExpr::new(Arc::new(Column::new("e", 3)), "e".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 1)), "a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("d", 2)), "d"), + ProjectionExpr::new(Arc::new(Column::new("e", 3)), "e"), + ProjectionExpr::new(Arc::new(Column::new("a", 1)), 
"a"), ], memory.clone(), )?); @@ -575,9 +594,9 @@ fn test_streaming_table_after_projection() -> Result<()> { )?; let projection = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d".to_string()), - ProjectionExpr::new(Arc::new(Column::new("e", 2)), "e".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d"), + ProjectionExpr::new(Arc::new(Column::new("e", 2)), "e"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), ], Arc::new(streaming_table) as _, )?) as _; @@ -642,28 +661,25 @@ fn test_projection_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); let child_projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), - ProjectionExpr::new(Arc::new(Column::new("e", 4)), "new_e".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "new_b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("e", 4)), "new_e"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "new_b"), ], csv.clone(), )?); let top_projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("new_b", 3)), "new_b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("new_b", 3)), "new_b"), ProjectionExpr::new( Arc::new(BinaryExpr::new( Arc::new(Column::new("c", 0)), Operator::Plus, Arc::new(Column::new("new_e", 1)), )), - "binary".to_string(), - ), - ProjectionExpr::new( - Arc::new(Column::new("new_b", 3)), - "newest_b".to_string(), + "binary", ), + ProjectionExpr::new(Arc::new(Column::new("new_b", 3)), "newest_b"), ], child_projection.clone(), )?); @@ -692,10 +708,7 @@ fn test_projection_after_projection() -> Result<()> { assert_snapshot!( actual, - @r" - ProjectionExec: expr=[b@1 as new_b, c@2 + e@4 as binary, b@1 as newest_b] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - " + @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[b@1 as new_b, c@2 + e@4 as binary, b@1 as newest_b], file_type=csv, has_header=false" ); Ok(()) @@ -731,9 +744,9 @@ fn test_output_req_after_projection() -> Result<()> { )); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), ], sort_req.clone(), )?); @@ -762,8 +775,7 @@ fn test_output_req_after_projection() -> Result<()> { actual, @r" OutputRequirementExec: order_by=[(b@2, asc), (c@0 + new_a@1, asc)], dist_by=HashPartitioned[[new_a@1, b@2]]) - ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false " ); @@ -805,10 +817,11 @@ fn test_output_req_after_projection() -> Result<()> { .required_input_distribution()[0] .clone() { - assert!(vec - 
.iter() - .zip(expected_distribution) - .all(|(actual, expected)| actual.eq(&expected))); + assert!( + vec.iter() + .zip(expected_distribution) + .all(|(actual, expected)| actual.eq(&expected)) + ); } else { panic!("Expected HashPartitioned distribution!"); }; @@ -823,9 +836,9 @@ fn test_coalesce_partitions_after_projection() -> Result<()> { Arc::new(CoalescePartitionsExec::new(csv)); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_new".to_string()), - ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_new"), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d"), ], coalesce_partitions, )?); @@ -853,8 +866,7 @@ fn test_coalesce_partitions_after_projection() -> Result<()> { actual, @r" CoalescePartitionsExec - ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[b, a@0 as a_new, d], file_type=csv, has_header=false " ); @@ -880,9 +892,9 @@ fn test_filter_after_projection() -> Result<()> { let filter = Arc::new(FilterExec::try_new(predicate, csv)?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_new".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), - ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_new"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d"), ], filter.clone(), )?) as _; @@ -911,8 +923,7 @@ fn test_filter_after_projection() -> Result<()> { actual, @r" FilterExec: b@1 - a_new@0 > d@2 - a_new@0 - ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a@0 as a_new, b, d], file_type=csv, has_header=false " ); @@ -975,17 +986,11 @@ fn test_join_after_projection() -> Result<()> { )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c_from_left".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_from_left".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_from_left".to_string()), - ProjectionExpr::new( - Arc::new(Column::new("a", 5)), - "a_from_right".to_string(), - ), - ProjectionExpr::new( - Arc::new(Column::new("c", 7)), - "c_from_right".to_string(), - ), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c_from_left"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_from_left"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_from_left"), + ProjectionExpr::new(Arc::new(Column::new("a", 5)), "a_from_right"), + ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c_from_right"), ], join, )?) 
as _; @@ -1014,10 +1019,8 @@ fn test_join_after_projection() -> Result<()> { actual, @r" SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b_from_left@1, c_from_right@1)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2 - ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - ProjectionExec: expr=[a@0 as a_from_right, c@2 as c_from_right] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a@0 as a_from_right, c@2 as c_from_right], file_type=csv, has_header=false " ); @@ -1106,16 +1109,16 @@ fn test_join_after_required_projection() -> Result<()> { )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("a", 5)), "a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 6)), "b".to_string()), - ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c".to_string()), - ProjectionExpr::new(Arc::new(Column::new("d", 8)), "d".to_string()), - ProjectionExpr::new(Arc::new(Column::new("e", 9)), "e".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), - ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d".to_string()), - ProjectionExpr::new(Arc::new(Column::new("e", 4)), "e".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 5)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 6)), "b"), + ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c"), + ProjectionExpr::new(Arc::new(Column::new("d", 8)), "d"), + ProjectionExpr::new(Arc::new(Column::new("e", 9)), "e"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d"), + ProjectionExpr::new(Arc::new(Column::new("e", 4)), "e"), ], join, )?) as _; @@ -1195,7 +1198,7 @@ fn test_nested_loop_join_after_projection() -> Result<()> { )?) as _; let projection: Arc = Arc::new(ProjectionExec::try_new( - vec![ProjectionExpr::new(col_left_c, "c".to_string())], + vec![ProjectionExpr::new(col_left_c, "c")], Arc::clone(&join), )?) as _; let initial = displayable(projection.as_ref()).indent(true).to_string(); @@ -1285,13 +1288,10 @@ fn test_hash_join_after_projection() -> Result<()> { )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c_from_left".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_from_left".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_from_left".to_string()), - ProjectionExpr::new( - Arc::new(Column::new("c", 7)), - "c_from_right".to_string(), - ), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c_from_left"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_from_left"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_from_left"), + ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c_from_right"), ], join.clone(), )?) 
as _; @@ -1327,10 +1327,10 @@ fn test_hash_join_after_projection() -> Result<()> { let projection = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), - ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c"), ], join.clone(), )?); @@ -1371,9 +1371,9 @@ fn test_repartition_after_projection() -> Result<()> { )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_new".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d_new".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_new"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d_new"), ], repartition, )?) as _; @@ -1399,8 +1399,7 @@ fn test_repartition_after_projection() -> Result<()> { actual, @r" RepartitionExec: partitioning=Hash([a@1, b_new@0, d_new@2], 6), input_partitions=1 - ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[b@1 as b_new, a, d@3 as d_new], file_type=csv, has_header=false " ); @@ -1441,9 +1440,9 @@ fn test_sort_after_projection() -> Result<()> { ); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), ], Arc::new(sort_exec), )?) as _; @@ -1470,8 +1469,7 @@ fn test_sort_after_projection() -> Result<()> { actual, @r" SortExec: expr=[b@2 ASC, c@0 + new_a@1 ASC], preserve_partitioning=[false] - ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false " ); @@ -1495,9 +1493,9 @@ fn test_sort_preserving_after_projection() -> Result<()> { ); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), ], Arc::new(sort_exec), )?) 
as _; @@ -1524,8 +1522,7 @@ fn test_sort_preserving_after_projection() -> Result<()> { actual, @r" SortPreservingMergeExec: [b@2 ASC, c@0 + new_a@1 ASC] - ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false " ); @@ -1538,9 +1535,9 @@ fn test_union_after_projection() -> Result<()> { let union = UnionExec::try_new(vec![csv.clone(), csv.clone(), csv])?; let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), - ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a".to_string()), - ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), ], union.clone(), )?) as _; @@ -1569,12 +1566,9 @@ fn test_union_after_projection() -> Result<()> { actual, @r" UnionExec - ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false - ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false " ); @@ -1589,14 +1583,23 @@ fn partitioned_data_source() -> Arc { Field::new("string_col", DataType::Utf8, true), ])); + let options = CsvOptions { + has_header: Some(false), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::new( + Arc::clone(&file_schema), + vec![Arc::new(Field::new("partition_col", DataType::Utf8, true))], + ); let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - file_schema.clone(), - Arc::new(CsvSource::default()), + Arc::new(CsvSource::new(table_schema).with_csv_options(options)), ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_table_partition_cols(vec![Field::new("partition_col", DataType::Utf8, true)]) + .with_file(PartitionedFile::new("x", 100)) .with_projection_indices(Some(vec![0, 1, 2])) + .unwrap() .build(); DataSourceExec::from_data_source(config) @@ -1611,16 +1614,13 @@ fn test_partition_col_projection_pushdown() -> Result<()> { vec![ ProjectionExpr::new( col("string_col", partitioned_schema.as_ref())?, - "string_col".to_string(), + "string_col", ), ProjectionExpr::new( col("partition_col", partitioned_schema.as_ref())?, - "partition_col".to_string(), - ), - ProjectionExpr::new( - col("int_col", partitioned_schema.as_ref())?, - "int_col".to_string(), + "partition_col", ), + ProjectionExpr::new(col("int_col", partitioned_schema.as_ref())?, "int_col"), ], source, )?); @@ -1634,10 +1634,7 @@ fn test_partition_col_projection_pushdown() -> Result<()> { let actual = after_optimize_string.trim(); 
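+    // The projection here only reorders columns (including the partition column), so
+    // ProjectionPushdown is expected to remove the ProjectionExec entirely and push the
+    // requested column order into the DataSourceExec, as the snapshot below verifies.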
assert_snapshot!( actual, - @r" - ProjectionExec: expr=[string_col@1 as string_col, partition_col@2 as partition_col, int_col@0 as int_col] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[int_col, string_col, partition_col], file_type=csv, has_header=false - " + @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[string_col, partition_col, int_col], file_type=csv, has_header=false" ); Ok(()) @@ -1652,7 +1649,7 @@ fn test_partition_col_projection_pushdown_expr() -> Result<()> { vec![ ProjectionExpr::new( col("string_col", partitioned_schema.as_ref())?, - "string_col".to_string(), + "string_col", ), ProjectionExpr::new( // CAST(partition_col, Utf8View) @@ -1661,12 +1658,9 @@ fn test_partition_col_projection_pushdown_expr() -> Result<()> { partitioned_schema.as_ref(), DataType::Utf8View, )?, - "partition_col".to_string(), - ), - ProjectionExpr::new( - col("int_col", partitioned_schema.as_ref())?, - "int_col".to_string(), + "partition_col", ), + ProjectionExpr::new(col("int_col", partitioned_schema.as_ref())?, "int_col"), ], source, )?); @@ -1678,11 +1672,107 @@ fn test_partition_col_projection_pushdown_expr() -> Result<()> { .indent(true) .to_string(); let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[string_col, CAST(partition_col@2 AS Utf8View) as partition_col, int_col], file_type=csv, has_header=false" + ); + + Ok(()) +} + +#[test] +fn test_coalesce_batches_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let filter = Arc::new(FilterExec::try_new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), + )), + csv, + )?); + let coalesce_batches: Arc = + Arc::new(CoalesceBatchesExec::new(filter, 8192)); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ], + coalesce_batches, + )?); + + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[a@0 as a, b@1 as b] + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@2 > 0 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + // Projection should be pushed down through CoalesceBatchesExec + assert_snapshot!( + actual, + @r" + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@2 > 0, projection=[a@0, b@1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); + + Ok(()) +} + +#[test] +fn test_cooperative_exec_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let cooperative: Arc = Arc::new(CooperativeExec::new(csv)); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ], + cooperative, + )?); + + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: 
expr=[a@0 as a, b@1 as b] + CooperativeExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + // Projection should be pushed down through CooperativeExec assert_snapshot!( actual, @r" - ProjectionExec: expr=[string_col@1 as string_col, CAST(partition_col@2 AS Utf8View) as partition_col, int_col@0 as int_col] - DataSourceExec: file_groups={1 group: [[x]]}, projection=[int_col, string_col, partition_col], file_type=csv, has_header=false + CooperativeExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b], file_type=csv, has_header=false " ); diff --git a/datafusion/core/tests/physical_optimizer/pushdown_sort.rs b/datafusion/core/tests/physical_optimizer/pushdown_sort.rs new file mode 100644 index 0000000000000..caef0fba052cb --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/pushdown_sort.rs @@ -0,0 +1,1040 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for sort pushdown optimizer rule (Phase 1) +//! +//! Phase 1 tests verify that: +//! 1. Reverse scan is enabled (reverse_row_groups=true) +//! 2. SortExec is kept (because ordering is inexact) +//! 3. output_ordering remains unchanged +//! 4. Early termination is enabled for TopK queries +//! 5. 
Prefix matching works correctly + +use datafusion_physical_expr::expressions; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::pushdown_sort::PushdownSort; +use std::sync::Arc; + +use crate::physical_optimizer::test_utils::{ + OptimizationTest, coalesce_batches_exec, coalesce_partitions_exec, parquet_exec, + parquet_exec_with_sort, projection_exec, projection_exec_with_alias, + repartition_exec, schema, simple_projection_exec, sort_exec, sort_exec_with_fetch, + sort_expr, sort_expr_named, test_scan_with_ordering, +}; + +#[test] +fn test_sort_pushdown_disabled() { + // When pushdown is disabled, plan should remain unchanged + let schema = schema(); + let source = parquet_exec(schema.clone()); + let sort_exprs = LexOrdering::new(vec![sort_expr("a", &schema)]).unwrap(); + let plan = sort_exec(sort_exprs, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), false), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + " + ); +} + +#[test] +fn test_sort_pushdown_basic_phase1() { + // Phase 1: Reverse scan enabled, Sort kept, output_ordering unchanged + let schema = schema(); + + // Source has ASC NULLS LAST ordering (default) + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request DESC NULLS LAST ordering (exact reverse) + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_with_limit_phase1() { + // Phase 1: Sort with fetch enables early termination but keeps Sort + let schema = schema(); + + // Source has ASC ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request DESC ordering with limit + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec_with_fetch(desc_ordering, Some(10), source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=10), expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: TopK(fetch=10), expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 
group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_multiple_columns_phase1() { + // Phase 1: Sort on multiple columns - reverse multi-column ordering + let schema = schema(); + + // Source has [a DESC NULLS LAST, b ASC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone().reverse(), b.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request [a ASC NULLS FIRST, b DESC] ordering (exact reverse) + let reverse_ordering = + LexOrdering::new(vec![a.clone().asc().nulls_first(), b.reverse()]).unwrap(); + let plan = sort_exec(reverse_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +// ============================================================================ +// PREFIX MATCHING TESTS +// ============================================================================ + +#[test] +fn test_prefix_match_single_column() { + // Test prefix matching: source has [a DESC, b ASC], query needs [a ASC] + // After reverse: [a ASC, b DESC] which satisfies [a ASC] prefix + let schema = schema(); + + // Source has [a DESC NULLS LAST, b ASC NULLS LAST] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone().reverse(), b]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request only [a ASC NULLS FIRST] - a prefix of the reversed ordering + let prefix_ordering = LexOrdering::new(vec![a.clone().asc().nulls_first()]).unwrap(); + let plan = sort_exec(prefix_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_prefix_match_with_limit() { + // Test prefix matching with LIMIT - important for TopK optimization + let schema = schema(); + + // Source has [a ASC, b DESC, c ASC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let c = sort_expr("c", &schema); + let source_ordering = + LexOrdering::new(vec![a.clone(), b.clone().reverse(), c]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request [a DESC NULLS LAST, b ASC NULLS FIRST] with LIMIT 100 + // This is a prefix (2 columns) of the reversed 3-column ordering + let prefix_ordering = + LexOrdering::new(vec![a.reverse(), b.clone().asc().nulls_first()]).unwrap(); + let plan = sort_exec_with_fetch(prefix_ordering, 
Some(100), source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=100), expr=[a@0 DESC NULLS LAST, b@1 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 DESC NULLS LAST, c@2 ASC], file_type=parquet + output: + Ok: + - SortExec: TopK(fetch=100), expr=[a@0 DESC NULLS LAST, b@1 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_prefix_match_through_transparent_nodes() { + // Test prefix matching works through transparent nodes + let schema = schema(); + + // Source has [a DESC NULLS LAST, b ASC, c DESC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let c = sort_expr("c", &schema); + let source_ordering = + LexOrdering::new(vec![a.clone().reverse(), b, c.reverse()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + let coalesce = coalesce_batches_exec(source, 1024); + let repartition = repartition_exec(coalesce); + + // Request only [a ASC NULLS FIRST] - prefix of reversed ordering + let prefix_ordering = LexOrdering::new(vec![a.clone().asc().nulls_first()]).unwrap(); + let plan = sort_exec(prefix_ordering, repartition); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + - CoalesceBatchesExec: target_batch_size=1024 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC, c@2 DESC NULLS LAST], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + - CoalesceBatchesExec: target_batch_size=1024 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_no_prefix_match_wrong_direction() { + // Test that prefix matching does NOT work if the direction is wrong + let schema = schema(); + + // Source has [a DESC, b ASC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone().reverse(), b]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request [a DESC] - same direction as source, NOT a reverse prefix + let same_direction = LexOrdering::new(vec![a.clone().reverse()]).unwrap(); + let plan = sort_exec(same_direction, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC], file_type=parquet + " + ); +} + +#[test] +fn test_no_prefix_match_longer_than_source() { + // Test that prefix 
matching does NOT work if requested is longer than source + let schema = schema(); + + // Source has [a DESC] ordering (single column) + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone().reverse()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request [a ASC, b DESC] - longer than source, can't be a prefix + let longer_ordering = + LexOrdering::new(vec![a.clone().asc().nulls_first(), b.reverse()]).unwrap(); + let plan = sort_exec(longer_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST], file_type=parquet + " + ); +} + +// ============================================================================ +// ORIGINAL TESTS +// ============================================================================ + +#[test] +fn test_sort_through_coalesce_batches() { + // Sort pushes through CoalesceBatchesExec + let schema = schema(); + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + let coalesce = coalesce_batches_exec(source, 1024); + + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, coalesce); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalesceBatchesExec: target_batch_size=1024 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalesceBatchesExec: target_batch_size=1024 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_through_repartition() { + // Sort should push through RepartitionExec + let schema = schema(); + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + let repartition = repartition_exec(source); + + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, repartition); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + - DataSourceExec: file_groups={1 group: 
[[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_nested_sorts() { + // Nested sort operations - only innermost can be optimized + let schema = schema(); + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let inner_sort = sort_exec(desc_ordering, source); + + let sort_exprs2 = LexOrdering::new(vec![b]).unwrap(); + let plan = sort_exec(sort_exprs2, inner_sort); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[b@1 ASC], preserve_partitioning=[false] + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[b@1 ASC], preserve_partitioning=[false] + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_non_sort_plans_unchanged() { + // Plans without SortExec should pass through unchanged + let schema = schema(); + let source = parquet_exec(schema.clone()); + let plan = coalesce_batches_exec(source, 1024); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - CoalesceBatchesExec: target_batch_size=1024 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + output: + Ok: + - CoalesceBatchesExec: target_batch_size=1024 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + " + ); +} + +#[test] +fn test_optimizer_properties() { + // Test optimizer metadata + let optimizer = PushdownSort::new(); + + assert_eq!(optimizer.name(), "PushdownSort"); + assert!(optimizer.schema_check()); +} + +#[test] +fn test_sort_through_coalesce_partitions() { + // Sort should push through CoalescePartitionsExec + let schema = schema(); + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + let repartition = repartition_exec(source); + let coalesce_parts = coalesce_partitions_exec(repartition); + + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, coalesce_parts); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn 
test_complex_plan_with_multiple_operators() { + // Test a complex plan with multiple operators between sort and source + let schema = schema(); + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + let coalesce_batches = coalesce_batches_exec(source, 1024); + let repartition = repartition_exec(coalesce_batches); + let coalesce_parts = coalesce_partitions_exec(repartition); + + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, coalesce_parts); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + - CoalesceBatchesExec: target_batch_size=1024 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + - CoalesceBatchesExec: target_batch_size=1024 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_multiple_sorts_different_columns() { + // Test nested sorts on different columns - only innermost can optimize + let schema = schema(); + let a = sort_expr("a", &schema); + let c = sort_expr("c", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // First sort by column 'a' DESC (reverse of source) + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let sort1 = sort_exec(desc_ordering, source); + + // Then sort by column 'c' (different column, can't optimize) + let sort_exprs2 = LexOrdering::new(vec![c]).unwrap(); + let plan = sort_exec(sort_exprs2, sort1); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_no_pushdown_for_unordered_source() { + // Verify pushdown does NOT happen for sources without ordering + let schema = schema(); + let source = parquet_exec(schema.clone()); // No output_ordering + let sort_exprs = LexOrdering::new(vec![sort_expr("a", &schema)]).unwrap(); + let plan = sort_exec(sort_exprs, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + 
- DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + " + ); +} + +#[test] +fn test_no_pushdown_for_non_reverse_sort() { + // Verify pushdown does NOT happen when sort doesn't reverse source ordering + let schema = schema(); + + // Source sorted by 'a' ASC + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request sort by 'b' (different column) + let sort_exprs = LexOrdering::new(vec![b]).unwrap(); + let plan = sort_exec(sort_exprs, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[b@1 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[b@1 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + " + ); +} + +#[test] +fn test_pushdown_through_blocking_node() { + // Test that pushdown works for inner sort even when outer sort is blocked + // Structure: Sort -> Aggregate (blocks pushdown) -> Sort -> Scan + // The outer sort can't push through aggregate, but the inner sort should still optimize + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, + }; + use std::sync::Arc; + + let schema = schema(); + + // Bottom: DataSource with [a ASC NULLS LAST] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Inner Sort: [a DESC NULLS FIRST] - exact reverse, CAN push down to source + let inner_sort_ordering = LexOrdering::new(vec![a.clone().reverse()]).unwrap(); + let inner_sort = sort_exec(inner_sort_ordering, source); + + // Middle: Aggregate (blocks pushdown from outer sort) + // GROUP BY a, COUNT(b) + let group_by = PhysicalGroupBy::new_single(vec![( + Arc::new(expressions::Column::new("a", 0)) as _, + "a".to_string(), + )]); + + let count_expr = Arc::new( + AggregateExprBuilder::new( + count_udaf(), + vec![Arc::new(expressions::Column::new("b", 1)) as _], + ) + .schema(Arc::clone(&schema)) + .alias("COUNT(b)") + .build() + .unwrap(), + ); + + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + vec![count_expr], + vec![None], + inner_sort, + Arc::clone(&schema), + ) + .unwrap(), + ); + + // Outer Sort: [a ASC] - this CANNOT push down through aggregate + let outer_sort_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let plan = sort_exec(outer_sort_ordering, aggregate); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - AggregateExec: mode=Final, gby=[a@0 as a], aggr=[COUNT(b)], ordering_mode=Sorted + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - 
AggregateExec: mode=Final, gby=[a@0 as a], aggr=[COUNT(b)], ordering_mode=Sorted + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +// ============================================================================ +// PROJECTION TESTS +// ============================================================================ + +#[test] +fn test_sort_pushdown_through_simple_projection() { + // Sort pushes through projection with simple column references + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT a, b (simple column references) + let projection = simple_projection_exec(source, vec![0, 1]); // columns a, b + + // Request [a DESC] - should push through projection to source + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_pushdown_through_projection_with_alias() { + // Sort pushes through projection with column aliases + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT a AS id, b AS value + let projection = projection_exec_with_alias(source, vec![(0, "id"), (1, "value")]); + + // Request [id DESC] - should map to [a DESC] and push down + let id_expr = sort_expr_named("id", 0); + let desc_ordering = LexOrdering::new(vec![id_expr.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[id@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as id, b@1 as value] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[id@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as id, b@1 as value] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_no_sort_pushdown_through_computed_projection() { + use datafusion_expr::Operator; + + // Sort should NOT push through projection with computed columns + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), 
vec![source_ordering]); + + // Projection: SELECT a+b as sum, c + let projection = projection_exec( + vec![ + ( + Arc::new(expressions::BinaryExpr::new( + Arc::new(expressions::Column::new("a", 0)), + Operator::Plus, + Arc::new(expressions::Column::new("b", 1)), + )) as Arc, + "sum".to_string(), + ), + ( + Arc::new(expressions::Column::new("c", 2)) as Arc, + "c".to_string(), + ), + ], + source, + ) + .unwrap(); + + // Request [sum DESC] - should NOT push down (sum is computed) + let sum_expr = sort_expr_named("sum", 0); + let desc_ordering = LexOrdering::new(vec![sum_expr.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[sum@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 + b@1 as sum, c@2 as c] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[sum@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 + b@1 as sum, c@2 as c] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + " + ); +} + +#[test] +fn test_sort_pushdown_projection_reordered_columns() { + // Sort pushes through projection that reorders columns + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT c, b, a (columns reordered) + let projection = simple_projection_exec(source, vec![2, 1, 0]); // c, b, a + + // Request [a DESC] where a is now at index 2 in projection output + let a_expr_at_2 = sort_expr_named("a", 2); + let desc_ordering = LexOrdering::new(vec![a_expr_at_2.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@2 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[c@2 as c, b@1 as b, a@0 as a] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@2 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[c@2 as c, b@1 as b, a@0 as a] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_pushdown_projection_with_limit() { + // Sort with LIMIT pushes through simple projection + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT a, b + let projection = simple_projection_exec(source, vec![0, 1]); + + // Request [a DESC] with LIMIT 10 + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec_with_fetch(desc_ordering, Some(10), projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=10), expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: 
expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: TopK(fetch=10), expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_pushdown_through_projection_and_coalesce() { + // Sort pushes through both projection and coalesce batches + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + let coalesce = coalesce_batches_exec(source, 1024); + + // Projection: SELECT a, b + let projection = simple_projection_exec(coalesce, vec![0, 1]); + + // Request [a DESC] + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - CoalesceBatchesExec: target_batch_size=1024 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - CoalesceBatchesExec: target_batch_size=1024 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_pushdown_projection_subset_of_columns() { + // Sort pushes through projection that selects subset of columns + let schema = schema(); + + // Source has [a ASC, b ASC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone(), b.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT a (subset of columns) + let projection = simple_projection_exec(source, vec![0]); + + // Request [a DESC] + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +// ============================================================================ +// TESTSCAN DEMONSTRATION TESTS +// ============================================================================ +// These tests use TestScan to demonstrate how sort pushdown works more clearly +// than ParquetExec. TestScan can accept ANY ordering (not just reverse) and +// displays the requested ordering explicitly in the output. 
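+// A minimal sketch of the shape shared by the tests below, reusing only helpers already
+// imported in this file (schema, sort_expr, test_scan_with_ordering, sort_exec,
+// OptimizationTest); each test varies only the source ordering and the ordering requested
+// by the SortExec:
+//
+//     let schema = schema();
+//     let a = sort_expr("a", &schema);
+//     let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap();
+//     let source = test_scan_with_ordering(schema.clone(), source_ordering);
+//     let plan = sort_exec(LexOrdering::new(vec![a.reverse()]).unwrap(), source);
+//     insta::assert_snapshot!(OptimizationTest::new(plan, PushdownSort::new(), true), @"...");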
+ +#[test] +fn test_sort_pushdown_with_test_scan_basic() { + // Demonstrates TestScan showing requested ordering clearly + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = test_scan_with_ordering(schema.clone(), source_ordering); + + // Request [a DESC] ordering + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC] + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC], requested_ordering=[a@0 DESC NULLS LAST] + " + ); +} + +#[test] +fn test_sort_pushdown_with_test_scan_multi_column() { + // Demonstrates TestScan with multi-column ordering + let schema = schema(); + + // Source has [a ASC, b DESC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone(), b.clone().reverse()]).unwrap(); + let source = test_scan_with_ordering(schema.clone(), source_ordering); + + // Request [a DESC, b ASC] ordering (reverse of source) + let reverse_ordering = LexOrdering::new(vec![a.reverse(), b]).unwrap(); + let plan = sort_exec(reverse_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST, b@1 ASC], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC, b@1 DESC NULLS LAST] + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST, b@1 ASC], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC, b@1 DESC NULLS LAST], requested_ordering=[a@0 DESC NULLS LAST, b@1 ASC] + " + ); +} + +#[test] +fn test_sort_pushdown_with_test_scan_arbitrary_ordering() { + // Demonstrates that TestScan can accept ANY ordering (not just reverse) + // This is different from ParquetExec which only supports reverse scans + let schema = schema(); + + // Source has [a ASC, b ASC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone(), b.clone()]).unwrap(); + let source = test_scan_with_ordering(schema.clone(), source_ordering); + + // Request [a ASC, b DESC] - NOT a simple reverse, but TestScan accepts it + let mixed_ordering = LexOrdering::new(vec![a, b.reverse()]).unwrap(); + let plan = sort_exec(mixed_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC, b@1 ASC] + output: + Ok: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC, b@1 ASC], requested_ordering=[a@0 ASC, b@1 DESC NULLS LAST] + " + ); +} diff --git a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs index 066e52614a12e..d93081f5ceb80 100644 --- a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs +++ 
b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs @@ -50,8 +50,8 @@ use datafusion_physical_plan::{ collect, displayable, ExecutionPlan, Partitioning, }; -use object_store::memory::InMemory; use object_store::ObjectStore; +use object_store::memory::InMemory; use rstest::rstest; use url::Url; @@ -138,7 +138,8 @@ impl ReplaceTest { assert!( res.is_ok(), "Some errors occurred while executing the optimized physical plan: {:?}\nPlan: {}", - res.unwrap_err(), optimized_plan_string + res.unwrap_err(), + optimized_plan_string ); } @@ -192,7 +193,7 @@ async fn test_replace_multiple_input_repartition_1( SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, @@ -202,13 +203,13 @@ async fn test_replace_multiple_input_repartition_1( SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, @@ -218,13 +219,13 @@ async fn test_replace_multiple_input_repartition_1( SortPreservingMergeExec: [a@0 ASC NULLS LAST] SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } @@ -275,21 +276,21 @@ async fn test_with_inter_children_change_only( SortExec: expr=[a@0 ASC], preserve_partitioning=[true] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortExec: expr=[a@0 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - 
RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC] Optimized: SortPreservingMergeExec: [a@0 ASC] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortPreservingMergeExec: [a@0 ASC] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC] "); }, @@ -300,11 +301,11 @@ async fn test_with_inter_children_change_only( SortExec: expr=[a@0 ASC], preserve_partitioning=[true] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortExec: expr=[a@0 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC "); }, @@ -315,21 +316,21 @@ async fn test_with_inter_children_change_only( SortExec: expr=[a@0 ASC], preserve_partitioning=[true] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortExec: expr=[a@0 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC Optimized: SortPreservingMergeExec: [a@0 ASC] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortPreservingMergeExec: [a@0 ASC] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC "); } @@ -375,14 +376,14 @@ async fn test_replace_multiple_input_repartition_2( SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 FilterExec: c@1 > 3 - 
RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST FilterExec: c@1 > 3 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, @@ -393,7 +394,7 @@ async fn test_replace_multiple_input_repartition_2( SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 FilterExec: c@1 > 3 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, @@ -404,14 +405,14 @@ async fn test_replace_multiple_input_repartition_2( SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 FilterExec: c@1 > 3 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST FilterExec: c@1 > 3 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } @@ -460,7 +461,7 @@ async fn test_replace_multiple_input_repartition_with_extra_steps( CoalesceBatchesExec: target_batch_size=8192 FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] Optimized: @@ -468,7 +469,7 @@ async fn test_replace_multiple_input_repartition_with_extra_steps( CoalesceBatchesExec: target_batch_size=8192 FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, @@ -480,7 +481,7 @@ async fn test_replace_multiple_input_repartition_with_extra_steps( CoalesceBatchesExec: target_batch_size=8192 FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), 
input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, @@ -492,7 +493,7 @@ async fn test_replace_multiple_input_repartition_with_extra_steps( CoalesceBatchesExec: target_batch_size=8192 FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST Optimized: @@ -500,7 +501,7 @@ async fn test_replace_multiple_input_repartition_with_extra_steps( CoalesceBatchesExec: target_batch_size=8192 FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } @@ -551,7 +552,7 @@ async fn test_replace_multiple_input_repartition_with_extra_steps_2( FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] Optimized: @@ -560,7 +561,7 @@ async fn test_replace_multiple_input_repartition_with_extra_steps_2( FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, @@ -573,7 +574,7 @@ async fn test_replace_multiple_input_repartition_with_extra_steps_2( FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, @@ -586,7 +587,7 @@ async fn test_replace_multiple_input_repartition_with_extra_steps_2( FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST Optimized: @@ -595,7 +596,7 @@ async fn test_replace_multiple_input_repartition_with_extra_steps_2( FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST CoalesceBatchesExec: target_batch_size=8192 - 
RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } @@ -639,7 +640,7 @@ async fn test_not_replacing_when_no_need_to_preserve_sorting( CoalesceBatchesExec: target_batch_size=8192 FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, @@ -650,7 +651,7 @@ async fn test_not_replacing_when_no_need_to_preserve_sorting( CoalesceBatchesExec: target_batch_size=8192 FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); // Expected bounded results same with and without flag, because there is no executor with ordering requirement @@ -662,7 +663,7 @@ async fn test_not_replacing_when_no_need_to_preserve_sorting( CoalesceBatchesExec: target_batch_size=8192 FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } @@ -712,7 +713,7 @@ async fn test_with_multiple_replaceable_repartitions( CoalesceBatchesExec: target_batch_size=8192 FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] Optimized: @@ -721,7 +722,7 @@ async fn test_with_multiple_replaceable_repartitions( CoalesceBatchesExec: target_batch_size=8192 FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, @@ -734,7 +735,7 @@ async fn test_with_multiple_replaceable_repartitions( CoalesceBatchesExec: target_batch_size=8192 FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, @@ -747,7 +748,7 @@ async fn test_with_multiple_replaceable_repartitions( CoalesceBatchesExec: target_batch_size=8192 FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - 
RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST Optimized: @@ -756,7 +757,7 @@ async fn test_with_multiple_replaceable_repartitions( CoalesceBatchesExec: target_batch_size=8192 FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } @@ -804,7 +805,7 @@ async fn test_not_replace_with_different_orderings( SortPreservingMergeExec: [c@1 ASC] SortExec: expr=[c@1 ASC], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, @@ -814,7 +815,7 @@ async fn test_not_replace_with_different_orderings( SortPreservingMergeExec: [c@1 ASC] SortExec: expr=[c@1 ASC], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); // Expected bounded results same with and without flag, because ordering requirement of the executor is @@ -826,7 +827,7 @@ async fn test_not_replace_with_different_orderings( SortPreservingMergeExec: [c@1 ASC] SortExec: expr=[c@1 ASC], preserve_partitioning=[true] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } @@ -870,13 +871,13 @@ async fn test_with_lost_ordering( SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, @@ -886,7 +887,7 @@ async fn test_with_lost_ordering( SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), 
input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, @@ -896,13 +897,13 @@ async fn test_with_lost_ordering( SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST Optimized: SortPreservingMergeExec: [a@0 ASC NULLS LAST] RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } @@ -956,22 +957,22 @@ async fn test_with_lost_and_kept_ordering( SortExec: expr=[c@1 ASC], preserve_partitioning=[true] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortExec: expr=[c@1 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] Optimized: SortPreservingMergeExec: [c@1 ASC] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortExec: expr=[c@1 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, @@ -982,11 +983,11 @@ async fn test_with_lost_and_kept_ordering( SortExec: expr=[c@1 ASC], preserve_partitioning=[true] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortExec: expr=[c@1 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); }, @@ 
-997,22 +998,22 @@ async fn test_with_lost_and_kept_ordering( SortExec: expr=[c@1 ASC], preserve_partitioning=[true] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortExec: expr=[c@1 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST Optimized: SortPreservingMergeExec: [c@1 ASC] FilterExec: c@1 > 3 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true SortExec: expr=[c@1 ASC], preserve_partitioning=[false] CoalescePartitionsExec RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); } @@ -1077,11 +1078,11 @@ async fn test_with_multiple_child_trees( HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)] CoalesceBatchesExec: target_batch_size=4096 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] CoalesceBatchesExec: target_batch_size=4096 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] "); }, @@ -1093,11 +1094,11 @@ async fn test_with_multiple_child_trees( HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)] CoalesceBatchesExec: target_batch_size=4096 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST CoalesceBatchesExec: target_batch_size=4096 RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST "); // Expected bounded results same with and without flag, because ordering get lost during intermediate executor anyway. 
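The recurring change in the snapshots above is that RepartitionExec now reports maintains_sort_order=true in its one-line plan display whenever its single-partition input ordering is preserved. As a rough illustration of how an operator's display line can expose such a flag, here is a minimal, self-contained Rust sketch; the struct and field names are hypothetical and do not come from DataFusion's actual RepartitionExec implementation.

use std::fmt;

// Hypothetical stand-in for an execution plan node's display state; this is
// not the real DataFusion RepartitionExec, only an illustration of the
// display format the updated snapshots expect.
struct RepartitionDisplay {
    partitioning: String,
    input_partitions: usize,
    maintains_sort_order: bool,
}

impl fmt::Display for RepartitionDisplay {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "RepartitionExec: partitioning={}, input_partitions={}",
            self.partitioning, self.input_partitions
        )?;
        // Append the flag only when it is set, mirroring how the updated
        // snapshots show it on order-preserving single-partition inputs.
        if self.maintains_sort_order {
            write!(f, ", maintains_sort_order=true")?;
        }
        Ok(())
    }
}

fn main() {
    let node = RepartitionDisplay {
        partitioning: "RoundRobinBatch(8)".to_string(),
        input_partitions: 1,
        maintains_sort_order: true,
    };
    assert_eq!(
        node.to_string(),
        "RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true"
    );
}

In this sketch the flag is printed conditionally, so display strings for nodes that do not maintain their input ordering stay unchanged, which matches the snapshots above where only the single-partition repartitions gained the new attribute.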
@@ -1248,7 +1249,10 @@ fn test_plan_with_order_preserving_variants_preserves_fetch() -> Result<()> { )], ); let res = plan_with_order_preserving_variants(requirements, false, true, Some(15)); - assert_contains!(res.unwrap_err().to_string(), "CoalescePartitionsExec fetch [10] should be greater than or equal to SortExec fetch [15]"); + assert_contains!( + res.unwrap_err().to_string(), + "CoalescePartitionsExec fetch [10] should be greater than or equal to SortExec fetch [15]" + ); // Test sort is without fetch, expected to get the fetch value from the coalesced let requirements = OrderPreservationContext::new( diff --git a/datafusion/core/tests/physical_optimizer/sanity_checker.rs b/datafusion/core/tests/physical_optimizer/sanity_checker.rs index 9867ed1733413..217570846d56e 100644 --- a/datafusion/core/tests/physical_optimizer/sanity_checker.rs +++ b/datafusion/core/tests/physical_optimizer/sanity_checker.rs @@ -30,13 +30,13 @@ use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTab use datafusion::prelude::{CsvReadOptions, SessionContext}; use datafusion_common::config::ConfigOptions; use datafusion_common::{JoinType, Result, ScalarValue}; -use datafusion_physical_expr::expressions::{col, Literal}; use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr::expressions::{Literal, col}; use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan; use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::{displayable, ExecutionPlan}; +use datafusion_physical_plan::{ExecutionPlan, displayable}; use async_trait::async_trait; @@ -555,11 +555,11 @@ async fn test_sort_merge_join_satisfied() -> Result<()> { assert_snapshot!( actual, @r" - SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)] - RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1 + SortMergeJoinExec: join_type=Inner, on=[(c9@0, a@0)] + RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1, maintains_sort_order=true SortExec: expr=[c9@0 ASC], preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[0] - RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true SortExec: expr=[a@0 ASC], preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[0] " @@ -605,8 +605,8 @@ async fn test_sort_merge_join_order_missing() -> Result<()> { assert_snapshot!( actual, @r" - SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)] - RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1 + SortMergeJoinExec: join_type=Inner, on=[(c9@0, a@0)] + RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1, maintains_sort_order=true SortExec: expr=[c9@0 ASC], preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[0] RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 @@ -653,11 +653,11 @@ async fn test_sort_merge_join_dist_missing() -> Result<()> { assert_snapshot!( actual, @r" - SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)] - RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1 + SortMergeJoinExec: join_type=Inner, on=[(c9@0, a@0)] + RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1, maintains_sort_order=true SortExec: expr=[c9@0 ASC], 
preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[0] - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true SortExec: expr=[a@0 ASC], preserve_partitioning=[false] DataSourceExec: partitions=1, partition_sizes=[0] " diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index 8ca33f3d4abb9..5b50181d7fd3e 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -18,7 +18,7 @@ //! Test utilities for physical optimizer tests use std::any::Any; -use std::fmt::Formatter; +use std::fmt::{Display, Formatter}; use std::sync::{Arc, LazyLock}; use arrow::array::Int32Array; @@ -33,25 +33,29 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::stats::Precision; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; -use datafusion_common::{ColumnStatistics, JoinType, NullEquality, Result, Statistics}; +use datafusion_common::{ + ColumnStatistics, JoinType, NullEquality, Result, Statistics, internal_err, +}; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{WindowFrame, WindowFunctionDefinition}; use datafusion_functions_aggregate::count::count_udaf; +use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::{self, col}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ LexOrdering, OrderingRequirements, PhysicalSortExpr, }; -use datafusion_physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::{JoinFilter, JoinOn}; use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; @@ -63,18 +67,17 @@ use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeE use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::tree_node::PlanContext; use datafusion_physical_plan::union::UnionExec; -use datafusion_physical_plan::windows::{create_window_expr, BoundedWindowAggExec}; +use datafusion_physical_plan::windows::{BoundedWindowAggExec, create_window_expr}; use datafusion_physical_plan::{ - displayable, DisplayAs, DisplayFormatType, ExecutionPlan, InputOrderMode, - Partitioning, PlanProperties, + DisplayAs, DisplayFormatType, ExecutionPlan, InputOrderMode, Partitioning, + PlanProperties, SortOrderPushdownResult, displayable, }; /// Create a non sorted parquet exec pub fn parquet_exec(schema: SchemaRef) -> Arc { let 
config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema, - Arc::new(ParquetSource::default()), + Arc::new(ParquetSource::new(schema)), ) .with_file(PartitionedFile::new("x".to_string(), 100)) .build(); @@ -89,8 +92,7 @@ pub(crate) fn parquet_exec_with_sort( ) -> Arc { let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema, - Arc::new(ParquetSource::default()), + Arc::new(ParquetSource::new(schema)), ) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_output_ordering(output_ordering) @@ -106,6 +108,7 @@ fn int64_stats() -> ColumnStatistics { max_value: Precision::Exact(1_000_000.into()), min_value: Precision::Exact(0.into()), distinct_count: Precision::Absent, + byte_size: Precision::Absent, } } @@ -127,17 +130,13 @@ pub(crate) fn parquet_exec_with_stats(file_size: u64) -> Arc { let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema(), - Arc::new(ParquetSource::new(Default::default())), + Arc::new(ParquetSource::new(schema())), ) .with_file(PartitionedFile::new("x".to_string(), file_size)) .with_statistics(statistics) .build(); - assert_eq!( - config.file_source.statistics().unwrap().num_rows, - Precision::Inexact(10000) - ); + assert_eq!(config.statistics().num_rows, Precision::Inexact(10000)); DataSourceExec::from_data_source(config) } @@ -467,10 +466,11 @@ impl ExecutionPlan for RequirementsTestExec { } fn required_input_ordering(&self) -> Vec> { - vec![self - .required_input_ordering - .as_ref() - .map(|ordering| OrderingRequirements::from(ordering.clone()))] + vec![ + self.required_input_ordering + .as_ref() + .map(|ordering| OrderingRequirements::from(ordering.clone())), + ] } fn maintains_input_order(&self) -> Vec { @@ -704,3 +704,278 @@ impl TestAggregate { } } } + +/// A harness for testing physical optimizers. 
+#[derive(Debug)] +pub struct OptimizationTest { + input: Vec, + output: Result, String>, +} + +impl OptimizationTest { + pub fn new( + input_plan: Arc, + opt: O, + enable_sort_pushdown: bool, + ) -> Self + where + O: PhysicalOptimizerRule, + { + let input = format_execution_plan(&input_plan); + let input_schema = input_plan.schema(); + + let mut config = ConfigOptions::new(); + config.optimizer.enable_sort_pushdown = enable_sort_pushdown; + let output_result = opt.optimize(input_plan, &config); + let output = output_result + .and_then(|plan| { + if opt.schema_check() && (plan.schema() != input_schema) { + internal_err!( + "Schema mismatch:\n\nBefore:\n{:?}\n\nAfter:\n{:?}", + input_schema, + plan.schema() + ) + } else { + Ok(plan) + } + }) + .map(|plan| format_execution_plan(&plan)) + .map_err(|e| e.to_string()); + + Self { input, output } + } +} + +impl Display for OptimizationTest { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + writeln!(f, "OptimizationTest:")?; + writeln!(f, " input:")?; + for line in &self.input { + writeln!(f, " - {line}")?; + } + writeln!(f, " output:")?; + match &self.output { + Ok(output) => { + writeln!(f, " Ok:")?; + for line in output { + writeln!(f, " - {line}")?; + } + } + Err(err) => { + writeln!(f, " Err: {err}")?; + } + } + Ok(()) + } +} + +pub fn format_execution_plan(plan: &Arc) -> Vec { + format_lines(&displayable(plan.as_ref()).indent(false).to_string()) +} + +fn format_lines(s: &str) -> Vec { + s.trim().split('\n').map(|s| s.to_string()).collect() +} + +/// Create a simple ProjectionExec with column indices (simplified version) +pub fn simple_projection_exec( + input: Arc, + columns: Vec, +) -> Arc { + let schema = input.schema(); + let exprs: Vec<(Arc, String)> = columns + .iter() + .map(|&i| { + let field = schema.field(i); + ( + Arc::new(expressions::Column::new(field.name(), i)) + as Arc, + field.name().to_string(), + ) + }) + .collect(); + + projection_exec(exprs, input).unwrap() +} + +/// Create a ProjectionExec with column aliases +pub fn projection_exec_with_alias( + input: Arc, + columns: Vec<(usize, &str)>, +) -> Arc { + let schema = input.schema(); + let exprs: Vec<(Arc, String)> = columns + .iter() + .map(|&(i, alias)| { + ( + Arc::new(expressions::Column::new(schema.field(i).name(), i)) + as Arc, + alias.to_string(), + ) + }) + .collect(); + + projection_exec(exprs, input).unwrap() +} + +/// Create a sort expression with custom name and index +pub fn sort_expr_named(name: &str, index: usize) -> PhysicalSortExpr { + PhysicalSortExpr { + expr: Arc::new(expressions::Column::new(name, index)), + options: SortOptions::default(), + } +} + +/// A test data source that can display any requested ordering +/// This is useful for testing sort pushdown behavior +#[derive(Debug, Clone)] +pub struct TestScan { + schema: SchemaRef, + output_ordering: Vec, + plan_properties: PlanProperties, + // Store the requested ordering for display + requested_ordering: Option, +} + +impl TestScan { + /// Create a new TestScan with the given schema and output ordering + pub fn new(schema: SchemaRef, output_ordering: Vec) -> Self { + let eq_properties = if !output_ordering.is_empty() { + // Convert Vec to the format expected by new_with_orderings + // We need to extract the inner Vec from each LexOrdering + let orderings: Vec> = output_ordering + .iter() + .map(|lex_ordering| { + // LexOrdering implements IntoIterator, so we can collect it + lex_ordering.iter().cloned().collect() + }) + .collect(); + + 
EquivalenceProperties::new_with_orderings(Arc::clone(&schema), orderings) + } else { + EquivalenceProperties::new(Arc::clone(&schema)) + }; + + let plan_properties = PlanProperties::new( + eq_properties, + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ); + + Self { + schema, + output_ordering, + plan_properties, + requested_ordering: None, + } + } + + /// Create a TestScan with a single output ordering + pub fn with_ordering(schema: SchemaRef, ordering: LexOrdering) -> Self { + Self::new(schema, vec![ordering]) + } +} + +impl DisplayAs for TestScan { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "TestScan")?; + if !self.output_ordering.is_empty() { + write!(f, ": output_ordering=[")?; + // Format the ordering in a readable way + for (i, sort_expr) in self.output_ordering[0].iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{sort_expr}")?; + } + write!(f, "]")?; + } + // This is the key part - show what ordering was requested + if let Some(ref req) = self.requested_ordering { + write!(f, ", requested_ordering=[")?; + for (i, sort_expr) in req.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{sort_expr}")?; + } + write!(f, "]")?; + } + Ok(()) + } + DisplayFormatType::TreeRender => { + write!(f, "TestScan") + } + } + } +} + +impl ExecutionPlan for TestScan { + fn name(&self) -> &str { + "TestScan" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.plan_properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.is_empty() { + Ok(self) + } else { + internal_err!("TestScan should have no children") + } + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + internal_err!("TestScan is for testing optimizer only, not for execution") + } + + fn partition_statistics(&self, _partition: Option) -> Result { + Ok(Statistics::new_unknown(&self.schema)) + } + + // This is the key method - implement sort pushdown + fn try_pushdown_sort( + &self, + order: &[PhysicalSortExpr], + ) -> Result>> { + // For testing purposes, accept ANY ordering request + // and create a new TestScan that shows what was requested + let requested_ordering = LexOrdering::new(order.to_vec()); + + let mut new_scan = self.clone(); + new_scan.requested_ordering = requested_ordering; + + // Always return Inexact to keep the Sort node (like Phase 1 behavior) + Ok(SortOrderPushdownResult::Inexact { + inner: Arc::new(new_scan), + }) + } +} + +/// Helper function to create a TestScan with ordering +pub fn test_scan_with_ordering( + schema: SchemaRef, + ordering: LexOrdering, +) -> Arc { + Arc::new(TestScan::with_ordering(schema, ordering)) +} diff --git a/datafusion/core/tests/physical_optimizer/window_optimize.rs b/datafusion/core/tests/physical_optimizer/window_optimize.rs index fc1e6444d756e..796f6b6259716 100644 --- a/datafusion/core/tests/physical_optimizer/window_optimize.rs +++ b/datafusion/core/tests/physical_optimizer/window_optimize.rs @@ -26,10 +26,10 @@ mod test { use datafusion_expr::WindowFrame; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::aggregate::AggregateExprBuilder; - use datafusion_physical_expr::expressions::{col, Column}; + use datafusion_physical_expr::expressions::{Column, col}; use 
datafusion_physical_expr::window::PlainAggregateWindowExpr; use datafusion_physical_plan::windows::BoundedWindowAggExec; - use datafusion_physical_plan::{common, ExecutionPlan, InputOrderMode}; + use datafusion_physical_plan::{ExecutionPlan, InputOrderMode, common}; use std::sync::Arc; /// Test case for diff --git a/datafusion/core/tests/schema_adapter/schema_adapter_integration_tests.rs b/datafusion/core/tests/schema_adapter/schema_adapter_integration_tests.rs deleted file mode 100644 index c3c92a9028d67..0000000000000 --- a/datafusion/core/tests/schema_adapter/schema_adapter_integration_tests.rs +++ /dev/null @@ -1,363 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::sync::Arc; - -use arrow::array::RecordBatch; -use arrow_schema::{DataType, Field, Schema, SchemaRef}; -use bytes::{BufMut, BytesMut}; -use datafusion::common::Result; -use datafusion::datasource::listing::PartitionedFile; -use datafusion::datasource::physical_plan::{ - ArrowSource, CsvSource, FileSource, JsonSource, ParquetSource, -}; -use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::SessionContext; -use datafusion_common::ColumnStatistics; -use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use datafusion_datasource::schema_adapter::{ - SchemaAdapter, SchemaAdapterFactory, SchemaMapper, -}; -use datafusion_datasource::source::DataSourceExec; -use datafusion_execution::object_store::ObjectStoreUrl; -use object_store::{memory::InMemory, path::Path, ObjectStore}; -use parquet::arrow::ArrowWriter; - -async fn write_parquet(batch: RecordBatch, store: Arc, path: &str) { - let mut out = BytesMut::new().writer(); - { - let mut writer = ArrowWriter::try_new(&mut out, batch.schema(), None).unwrap(); - writer.write(&batch).unwrap(); - writer.finish().unwrap(); - } - let data = out.into_inner().freeze(); - store.put(&Path::from(path), data.into()).await.unwrap(); -} - -/// A schema adapter factory that transforms column names to uppercase -#[derive(Debug, PartialEq)] -struct UppercaseAdapterFactory {} - -impl SchemaAdapterFactory for UppercaseAdapterFactory { - fn create( - &self, - projected_table_schema: SchemaRef, - _table_schema: SchemaRef, - ) -> Box { - Box::new(UppercaseAdapter { - table_schema: projected_table_schema, - }) - } -} - -/// Schema adapter that transforms column names to uppercase -#[derive(Debug)] -struct UppercaseAdapter { - table_schema: SchemaRef, -} - -impl SchemaAdapter for UppercaseAdapter { - fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { - let field = self.table_schema.field(index); - let uppercase_name = field.name().to_uppercase(); - file_schema - .fields() - .iter() - .position(|f| f.name().to_uppercase() == uppercase_name) - } - - fn map_schema( - &self, - file_schema: 
&Schema, - ) -> Result<(Arc, Vec)> { - let mut projection = Vec::new(); - - // Map each field in the table schema to the corresponding field in the file schema - for table_field in self.table_schema.fields() { - let uppercase_name = table_field.name().to_uppercase(); - if let Some(pos) = file_schema - .fields() - .iter() - .position(|f| f.name().to_uppercase() == uppercase_name) - { - projection.push(pos); - } - } - - let mapper = UppercaseSchemaMapper { - output_schema: self.output_schema(), - projection: projection.clone(), - }; - - Ok((Arc::new(mapper), projection)) - } -} - -impl UppercaseAdapter { - fn output_schema(&self) -> SchemaRef { - let fields: Vec = self - .table_schema - .fields() - .iter() - .map(|f| { - Field::new( - f.name().to_uppercase().as_str(), - f.data_type().clone(), - f.is_nullable(), - ) - }) - .collect(); - - Arc::new(Schema::new(fields)) - } -} - -#[derive(Debug)] -struct UppercaseSchemaMapper { - output_schema: SchemaRef, - projection: Vec, -} - -impl SchemaMapper for UppercaseSchemaMapper { - fn map_batch(&self, batch: RecordBatch) -> Result { - let columns = self - .projection - .iter() - .map(|&i| batch.column(i).clone()) - .collect::>(); - Ok(RecordBatch::try_new(self.output_schema.clone(), columns)?) - } - - fn map_column_statistics( - &self, - stats: &[ColumnStatistics], - ) -> Result> { - Ok(self - .projection - .iter() - .map(|&i| stats.get(i).cloned().unwrap_or_default()) - .collect()) - } -} - -#[cfg(feature = "parquet")] -#[tokio::test] -async fn test_parquet_integration_with_schema_adapter() -> Result<()> { - // Create test data - let batch = RecordBatch::try_new( - Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ])), - vec![ - Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3])), - Arc::new(arrow::array::StringArray::from(vec!["a", "b", "c"])), - ], - )?; - - let store = Arc::new(InMemory::new()) as Arc; - let store_url = ObjectStoreUrl::parse("memory://").unwrap(); - let path = "test.parquet"; - write_parquet(batch.clone(), store.clone(), path).await; - - // Get the actual file size from the object store - let object_meta = store.head(&Path::from(path)).await?; - let file_size = object_meta.size; - - // Create a session context and register the object store - let ctx = SessionContext::new(); - ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); - - // Create a ParquetSource with the adapter factory - let file_source = ParquetSource::default() - .with_schema_adapter_factory(Arc::new(UppercaseAdapterFactory {}))?; - - // Create a table schema with uppercase column names - let table_schema = Arc::new(Schema::new(vec![ - Field::new("ID", DataType::Int32, false), - Field::new("NAME", DataType::Utf8, true), - ])); - - let config = FileScanConfigBuilder::new(store_url, table_schema.clone(), file_source) - .with_file(PartitionedFile::new(path, file_size)) - .build(); - - // Create a data source executor - let exec = DataSourceExec::from_data_source(config); - - // Collect results - let task_ctx = ctx.task_ctx(); - let stream = exec.execute(0, task_ctx)?; - let batches = datafusion::physical_plan::common::collect(stream).await?; - - // There should be one batch - assert_eq!(batches.len(), 1); - - // Verify the schema has the uppercase column names - let result_schema = batches[0].schema(); - assert_eq!(result_schema.field(0).name(), "ID"); - assert_eq!(result_schema.field(1).name(), "NAME"); - - Ok(()) -} - -#[cfg(feature = "parquet")] -#[tokio::test] -async fn 
test_parquet_integration_with_schema_adapter_and_expression_rewriter( -) -> Result<()> { - // Create test data - let batch = RecordBatch::try_new( - Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ])), - vec![ - Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3])), - Arc::new(arrow::array::StringArray::from(vec!["a", "b", "c"])), - ], - )?; - - let store = Arc::new(InMemory::new()) as Arc; - let store_url = ObjectStoreUrl::parse("memory://").unwrap(); - let path = "test.parquet"; - write_parquet(batch.clone(), store.clone(), path).await; - - // Get the actual file size from the object store - let object_meta = store.head(&Path::from(path)).await?; - let file_size = object_meta.size; - - // Create a session context and register the object store - let ctx = SessionContext::new(); - ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); - - // Create a ParquetSource with the adapter factory - let file_source = ParquetSource::default() - .with_schema_adapter_factory(Arc::new(UppercaseAdapterFactory {}))?; - - let config = FileScanConfigBuilder::new(store_url, batch.schema(), file_source) - .with_file(PartitionedFile::new(path, file_size)) - .build(); - - // Create a data source executor - let exec = DataSourceExec::from_data_source(config); - - // Collect results - let task_ctx = ctx.task_ctx(); - let stream = exec.execute(0, task_ctx)?; - let batches = datafusion::physical_plan::common::collect(stream).await?; - - // There should be one batch - assert_eq!(batches.len(), 1); - - // Verify the schema has the original column names (schema adapter not applied in DataSourceExec) - let result_schema = batches[0].schema(); - assert_eq!(result_schema.field(0).name(), "id"); - assert_eq!(result_schema.field(1).name(), "name"); - - Ok(()) -} - -#[tokio::test] -async fn test_multi_source_schema_adapter_reuse() -> Result<()> { - // This test verifies that the same schema adapter factory can be reused - // across different file source types. This is important for ensuring that: - // 1. The schema adapter factory interface works uniformly across all source types - // 2. The factory can be shared and cloned efficiently using Arc - // 3. 
Various data source implementations correctly implement the schema adapter factory pattern - - // Create a test factory - let factory = Arc::new(UppercaseAdapterFactory {}); - - // Test ArrowSource - { - let source = ArrowSource::default(); - let source_with_adapter = source - .clone() - .with_schema_adapter_factory(factory.clone()) - .unwrap(); - - let base_source: Arc = source.into(); - assert!(base_source.schema_adapter_factory().is_none()); - assert!(source_with_adapter.schema_adapter_factory().is_some()); - - let retrieved_factory = source_with_adapter.schema_adapter_factory().unwrap(); - assert_eq!( - format!("{:?}", retrieved_factory.as_ref()), - format!("{:?}", factory.as_ref()) - ); - } - - // Test ParquetSource - #[cfg(feature = "parquet")] - { - let source = ParquetSource::default(); - let source_with_adapter = source - .clone() - .with_schema_adapter_factory(factory.clone()) - .unwrap(); - - let base_source: Arc = source.into(); - assert!(base_source.schema_adapter_factory().is_none()); - assert!(source_with_adapter.schema_adapter_factory().is_some()); - - let retrieved_factory = source_with_adapter.schema_adapter_factory().unwrap(); - assert_eq!( - format!("{:?}", retrieved_factory.as_ref()), - format!("{:?}", factory.as_ref()) - ); - } - - // Test CsvSource - { - let source = CsvSource::default(); - let source_with_adapter = source - .clone() - .with_schema_adapter_factory(factory.clone()) - .unwrap(); - - let base_source: Arc = source.into(); - assert!(base_source.schema_adapter_factory().is_none()); - assert!(source_with_adapter.schema_adapter_factory().is_some()); - - let retrieved_factory = source_with_adapter.schema_adapter_factory().unwrap(); - assert_eq!( - format!("{:?}", retrieved_factory.as_ref()), - format!("{:?}", factory.as_ref()) - ); - } - - // Test JsonSource - { - let source = JsonSource::default(); - let source_with_adapter = source - .clone() - .with_schema_adapter_factory(factory.clone()) - .unwrap(); - - let base_source: Arc = source.into(); - assert!(base_source.schema_adapter_factory().is_none()); - assert!(source_with_adapter.schema_adapter_factory().is_some()); - - let retrieved_factory = source_with_adapter.schema_adapter_factory().unwrap(); - assert_eq!( - format!("{:?}", retrieved_factory.as_ref()), - format!("{:?}", factory.as_ref()) - ); - } - - Ok(()) -} diff --git a/datafusion/core/tests/sql/aggregates/basic.rs b/datafusion/core/tests/sql/aggregates/basic.rs index 4b421b5294e01..d1b376b735ab9 100644 --- a/datafusion/core/tests/sql/aggregates/basic.rs +++ b/datafusion/core/tests/sql/aggregates/basic.rs @@ -365,7 +365,7 @@ async fn count_distinct_dictionary_all_null_values() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----+---------------+ | cnt | count(t.num2) | +-----+---------------+ @@ -375,7 +375,7 @@ async fn count_distinct_dictionary_all_null_values() -> Result<()> { | 0 | 1 | | 0 | 1 | +-----+---------------+ - "### + " ); // Test with multiple partitions @@ -430,13 +430,13 @@ async fn count_distinct_dictionary_mixed_values() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------------------------+ | count(DISTINCT t.dict) | +------------------------+ | 2 | +------------------------+ - "### + " ); Ok(()) diff --git a/datafusion/core/tests/sql/aggregates/dict_nulls.rs b/datafusion/core/tests/sql/aggregates/dict_nulls.rs index da4b2c8d25c9d..f9e15a71a20f8 100644 --- a/datafusion/core/tests/sql/aggregates/dict_nulls.rs +++ b/datafusion/core/tests/sql/aggregates/dict_nulls.rs 
@@ -34,7 +34,7 @@ async fn test_aggregates_null_handling_comprehensive() -> Result<()> { assert_snapshot!( batches_to_string(&results_count), - @r###" + @r" +----------------+-----+ | dict_null_keys | cnt | +----------------+-----+ @@ -42,7 +42,7 @@ async fn test_aggregates_null_handling_comprehensive() -> Result<()> { | group_a | 2 | | group_b | 1 | +----------------+-----+ - "### + " ); // Test SUM null handling with extended data @@ -69,7 +69,7 @@ async fn test_aggregates_null_handling_comprehensive() -> Result<()> { assert_snapshot!( batches_to_string(&results_min), - @r###" + @r" +----------------+---------+ | dict_null_keys | minimum | +----------------+---------+ @@ -78,7 +78,7 @@ async fn test_aggregates_null_handling_comprehensive() -> Result<()> { | group_b | 1 | | group_c | 7 | +----------------+---------+ - "### + " ); // Test MEDIAN null handling with median data @@ -168,7 +168,7 @@ async fn test_first_last_value_order_by_null_handling() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------------+-------+--------------------+---------------------+-------------------+--------------------+ | dict_group | value | first_ignore_nulls | first_respect_nulls | last_ignore_nulls | last_respect_nulls | +------------+-------+--------------------+---------------------+-------------------+--------------------+ @@ -178,7 +178,7 @@ async fn test_first_last_value_order_by_null_handling() -> Result<()> { | group_a | | 5 | | 20 | | | group_b | | 5 | | 20 | | +------------+-------+--------------------+---------------------+-------------------+--------------------+ - "### + " ); Ok(()) @@ -249,7 +249,7 @@ async fn test_first_last_value_group_by_dict_nulls() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +----------------+-----------+----------+-----+ | dict_null_keys | first_val | last_val | cnt | +----------------+-----------+----------+-----+ @@ -257,7 +257,7 @@ async fn test_first_last_value_group_by_dict_nulls() -> Result<()> { | group_a | 10 | 50 | 2 | | group_b | 30 | 30 | 1 | +----------------+-----------+----------+-----+ - "### + " ); // Test GROUP BY with null values in dictionary @@ -275,7 +275,7 @@ async fn test_first_last_value_group_by_dict_nulls() -> Result<()> { assert_snapshot!( batches_to_string(&results2), - @r###" + @r" +----------------+-----------+----------+-----+ | dict_null_vals | first_val | last_val | cnt | +----------------+-----------+----------+-----+ @@ -283,7 +283,7 @@ async fn test_first_last_value_group_by_dict_nulls() -> Result<()> { | val_x | 10 | 50 | 2 | | val_y | 30 | 30 | 1 | +----------------+-----------+----------+-----+ - "### + " ); Ok(()) @@ -394,7 +394,7 @@ async fn test_count_distinct_with_fuzz_table_dict_nulls() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +--------+----------+---------------------+------+------+ | u8_low | utf8_low | dictionary_utf8_low | col1 | col2 | +--------+----------+---------------------+------+------+ @@ -405,7 +405,7 @@ async fn test_count_distinct_with_fuzz_table_dict_nulls() -> Result<()> { | 20 | text_e | | 0 | 1 | | 25 | text_f | group_gamma | 1 | 1 | +--------+----------+---------------------+------+------+ - "### + " ); Ok(()) diff --git a/datafusion/core/tests/sql/aggregates/mod.rs b/datafusion/core/tests/sql/aggregates/mod.rs index 321c158628e43..ede40d5c4ceca 100644 --- a/datafusion/core/tests/sql/aggregates/mod.rs +++ b/datafusion/core/tests/sql/aggregates/mod.rs @@ -20,15 +20,15 @@ use super::*; use arrow::{ array::{ 
- types::UInt32Type, Decimal128Array, DictionaryArray, DurationNanosecondArray, - Int32Array, LargeBinaryArray, StringArray, TimestampMicrosecondArray, - UInt16Array, UInt32Array, UInt64Array, UInt8Array, + Decimal128Array, DictionaryArray, DurationNanosecondArray, Int32Array, + LargeBinaryArray, StringArray, TimestampMicrosecondArray, UInt8Array, + UInt16Array, UInt32Array, UInt64Array, types::UInt32Type, }, datatypes::{DataType, Field, Schema, TimeUnit}, record_batch::RecordBatch, }; use datafusion::{ - common::{test_util::batches_to_string, Result}, + common::{Result, test_util::batches_to_string}, execution::{config::SessionConfig, context::SessionContext}, }; use datafusion_catalog::MemTable; @@ -959,8 +959,8 @@ impl FuzzTimestampTestData { } /// Sets up test contexts for fuzz table with timestamps and both single and multiple partitions -pub async fn setup_fuzz_timestamp_test_contexts( -) -> Result<(SessionContext, SessionContext)> { +pub async fn setup_fuzz_timestamp_test_contexts() +-> Result<(SessionContext, SessionContext)> { let test_data = FuzzTimestampTestData::new(); // Single partition context diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 26b71b5496f29..75cd78e47aff5 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -61,12 +61,9 @@ async fn explain_analyze_baseline_metrics() { assert_metrics!( &formatted, "AggregateExec: mode=Partial, gby=[]", - "metrics=[output_rows=3, elapsed_compute=" - ); - assert_metrics!( - &formatted, - "AggregateExec: mode=Partial, gby=[]", - "output_bytes=" + "metrics=[output_rows=3, elapsed_compute=", + "output_bytes=", + "output_batches=3" ); assert_metrics!( @@ -75,59 +72,68 @@ async fn explain_analyze_baseline_metrics() { "reduction_factor=5.1% (5/99)" ); - assert_metrics!( - &formatted, - "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]", - "metrics=[output_rows=5, elapsed_compute=" - ); - assert_metrics!( - &formatted, - "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]", - "output_bytes=" - ); - assert_metrics!( - &formatted, - "FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", - "metrics=[output_rows=99, elapsed_compute=" - ); + { + let expected_batch_count_after_repartition = + if cfg!(not(feature = "force_hash_collisions")) { + "output_batches=3" + } else { + "output_batches=1" + }; + + assert_metrics!( + &formatted, + "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]", + "metrics=[output_rows=5, elapsed_compute=", + "output_bytes=", + expected_batch_count_after_repartition + ); + + assert_metrics!( + &formatted, + "RepartitionExec: partitioning=Hash([c1@0], 3), input_partitions=3", + "metrics=[output_rows=5, elapsed_compute=", + "output_bytes=", + expected_batch_count_after_repartition + ); + + assert_metrics!( + &formatted, + "ProjectionExec: expr=[]", + "metrics=[output_rows=5, elapsed_compute=", + "output_bytes=", + expected_batch_count_after_repartition + ); + } + assert_metrics!( &formatted, "FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", - "output_bytes=" + "metrics=[output_rows=99, elapsed_compute=", + "output_bytes=", + "output_batches=1" ); + assert_metrics!( &formatted, "FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", "selectivity=99% (99/100)" ); - assert_metrics!( - &formatted, - "ProjectionExec: expr=[]", - "metrics=[output_rows=5, elapsed_compute=" - ); - assert_metrics!(&formatted, "ProjectionExec: expr=[]", "output_bytes="); - assert_metrics!( - &formatted, - 
"CoalesceBatchesExec: target_batch_size=4096", - "metrics=[output_rows=5, elapsed_compute" - ); - assert_metrics!( - &formatted, - "CoalesceBatchesExec: target_batch_size=4096", - "output_bytes=" - ); + assert_metrics!( &formatted, "UnionExec", - "metrics=[output_rows=3, elapsed_compute=" + "metrics=[output_rows=3, elapsed_compute=", + "output_bytes=", + "output_batches=3" ); - assert_metrics!(&formatted, "UnionExec", "output_bytes="); + assert_metrics!( &formatted, "WindowAggExec", - "metrics=[output_rows=1, elapsed_compute=" + "metrics=[output_rows=1, elapsed_compute=", + "output_bytes=", + "output_batches=1" ); - assert_metrics!(&formatted, "WindowAggExec", "output_bytes="); fn expected_to_have_metrics(plan: &dyn ExecutionPlan) -> bool { use datafusion::physical_plan; @@ -228,9 +234,13 @@ async fn explain_analyze_level() { for (level, needle, should_contain) in [ (ExplainAnalyzeLevel::Summary, "spill_count", false), + (ExplainAnalyzeLevel::Summary, "output_batches", false), (ExplainAnalyzeLevel::Summary, "output_rows", true), + (ExplainAnalyzeLevel::Summary, "output_bytes", true), (ExplainAnalyzeLevel::Dev, "spill_count", true), (ExplainAnalyzeLevel::Dev, "output_rows", true), + (ExplainAnalyzeLevel::Dev, "output_bytes", true), + (ExplainAnalyzeLevel::Dev, "output_batches", true), ] { let plan = collect_plan(sql, level).await; assert_eq!( @@ -336,12 +346,12 @@ async fn csv_explain_plans() { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" Explain Projection: aggregate_test_100.c1 Filter: aggregate_test_100.c2 > Int64(10) TableScan: aggregate_test_100 - "### + " ); // // verify the grahviz format of the plan @@ -407,13 +417,12 @@ async fn csv_explain_plans() { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" Explain Projection: aggregate_test_100.c1 Filter: aggregate_test_100.c2 > Int8(10) TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] - - "### + " ); // // verify the grahviz format of the plan @@ -553,12 +562,12 @@ async fn csv_explain_verbose_plans() { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" Explain Projection: aggregate_test_100.c1 Filter: aggregate_test_100.c2 > Int64(10) TableScan: aggregate_test_100 - "### + " ); // // verify the grahviz format of the plan @@ -624,12 +633,12 @@ async fn csv_explain_verbose_plans() { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" Explain Projection: aggregate_test_100.c1 Filter: aggregate_test_100.c2 > Int8(10) TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] - "### + " ); // // verify the grahviz format of the plan @@ -748,19 +757,17 @@ async fn test_physical_plan_display_indent() { assert_snapshot!( actual, - @r###" + @r" SortPreservingMergeExec: [the_min@2 DESC], fetch=10 SortExec: TopK(fetch=10), expr=[the_min@2 DESC], preserve_partitioning=[true] ProjectionExec: expr=[c1@0 as c1, max(aggregate_test_100.c12)@1 as max(aggregate_test_100.c12), min(aggregate_test_100.c12)@2 as the_min] AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)] - CoalesceBatchesExec: target_batch_size=4096 - RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000 - AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)] - CoalesceBatchesExec: target_batch_size=4096 - FilterExec: c12@1 < 10 - RepartitionExec: 
partitioning=RoundRobinBatch(9000), input_partitions=1 - DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1, c12], file_type=csv, has_header=true - "### + RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000 + AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)] + FilterExec: c12@1 < 10 + RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1 + DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1, c12], file_type=csv, has_header=true + " ); } @@ -794,19 +801,13 @@ async fn test_physical_plan_display_indent_multi_children() { assert_snapshot!( actual, - @r###" - CoalesceBatchesExec: target_batch_size=4096 - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c2@0)], projection=[c1@0] - CoalesceBatchesExec: target_batch_size=4096 - RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000 - RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1 - DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true - CoalesceBatchesExec: target_batch_size=4096 - RepartitionExec: partitioning=Hash([c2@0], 9000), input_partitions=9000 - RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1 - ProjectionExec: expr=[c1@0 as c2] - DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true - "### + @r" + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c2@0)], projection=[c1@0] + RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=1 + DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true + RepartitionExec: partitioning=Hash([c2@0], 9000), input_partitions=1 + DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1@0 as c2], file_type=csv, has_header=true + " ); } @@ -845,8 +846,7 @@ async fn csv_explain_analyze_order_by() { // Ensure that the ordering is not optimized away from the plan // https://github.com/apache/datafusion/issues/6379 - let needle = - "SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[false], metrics=[output_rows=100, elapsed_compute"; + let needle = "SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[false], metrics=[output_rows=100, elapsed_compute"; assert_contains!(&formatted, needle); } @@ -872,6 +872,7 @@ async fn parquet_explain_analyze() { &formatted, "row_groups_pruned_statistics=1 total \u{2192} 1 matched" ); + assert_contains!(&formatted, "scan_efficiency_ratio=14%"); // The order of metrics is expected to be the same as the actual pruning order // (file-> row-group -> page) @@ -885,7 +886,7 @@ async fn parquet_explain_analyze() { (i_file < i_rowgroup_stat) && (i_rowgroup_stat < i_rowgroup_bloomfilter) && (i_rowgroup_bloomfilter < i_page), - "The parquet pruning metrics should be displayed in an order of: file range -> row group statistics -> row group bloom filter -> page index." + "The parquet pruning metrics should be displayed in an order of: file range -> row group statistics -> row group bloom filter -> page index." 
); } @@ -997,16 +998,14 @@ async fn parquet_recursive_projection_pushdown() -> Result<()> { RecursiveQueryExec: name=number_series, is_distinct=false CoalescePartitionsExec ProjectionExec: expr=[id@0 as id, 1 as level] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: id@0 = 1 - RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 - DataSourceExec: file_groups={1 group: [[TMP_DIR/hierarchy.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 = 1, pruning_predicate=id_null_count@2 != row_count@3 AND id_min@0 <= 1 AND 1 <= id_max@1, required_guarantees=[id in (1)] + FilterExec: id@0 = 1 + RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 + DataSourceExec: file_groups={1 group: [[TMP_DIR/hierarchy.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 = 1, pruning_predicate=id_null_count@2 != row_count@3 AND id_min@0 <= 1 AND 1 <= id_max@1, required_guarantees=[id in (1)] CoalescePartitionsExec ProjectionExec: expr=[id@0 + 1 as ns.id + Int64(1), level@1 + 1 as ns.level + Int64(1)] - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: id@0 < 10 - RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 - WorkTableExec: name=number_series + FilterExec: id@0 < 10 + RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 + WorkTableExec: name=number_series " ); @@ -1082,11 +1081,11 @@ async fn explain_physical_plan_only() { assert_snapshot!( actual, - @r###" + @r" physical_plan ProjectionExec: expr=[2 as count(*)] PlaceholderRowExec - "### + " ); } @@ -1140,3 +1139,24 @@ async fn nested_loop_join_selectivity() { ); } } + +#[tokio::test] +async fn explain_analyze_hash_join() { + let sql = "EXPLAIN ANALYZE \ + SELECT * \ + FROM generate_series(10) as t1(a) \ + JOIN generate_series(20) as t2(b) \ + ON t1.a=t2.b"; + + for (level, needle, should_contain) in [ + (ExplainAnalyzeLevel::Summary, "probe_hit_rate", true), + (ExplainAnalyzeLevel::Summary, "avg_fanout", true), + ] { + let plan = collect_plan(sql, level).await; + assert_eq!( + plan.contains(needle), + should_contain, + "plan for level {level:?} unexpected content: {plan}" + ); + } +} diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 7a59834475920..7c0e89ee96418 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -38,14 +38,16 @@ async fn join_change_in_planner() -> Result<()> { Field::new("a2", DataType::UInt32, false), ])); // Specify the ordering: - let file_sort_order = vec![[col("a1")] - .into_iter() - .map(|e| { - let ascending = true; - let nulls_first = false; - e.sort(ascending, nulls_first) - }) - .collect::<Vec<_>>()]; + let file_sort_order = vec![ + [col("a1")] + .into_iter() + .map(|e| { + let ascending = true; + let nulls_first = false; + e.sort(ascending, nulls_first) + }) + .collect::<Vec<_>>(), + ]; register_unbounded_file_with_ordering( &ctx, schema.clone(), @@ -72,14 +74,10 @@ async fn join_change_in_planner() -> Result<()> { actual, @r" SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10 - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a1@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true, 
output_ordering=[a1@0 ASC NULLS LAST] - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a1@0 ASC NULLS LAST - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST] " ); Ok(()) @@ -99,14 +97,16 @@ async fn join_no_order_on_filter() -> Result<()> { Field::new("a3", DataType::UInt32, false), ])); // Specify the ordering: - let file_sort_order = vec![[col("a1")] - .into_iter() - .map(|e| { - let ascending = true; - let nulls_first = false; - e.sort(ascending, nulls_first) - }) - .collect::<Vec<_>>()]; + let file_sort_order = vec![ + [col("a1")] + .into_iter() + .map(|e| { + let ascending = true; + let nulls_first = false; + e.sort(ascending, nulls_first) + }) + .collect::<Vec<_>>(), + ]; register_unbounded_file_with_ordering( &ctx, schema.clone(), @@ -133,14 +133,10 @@ async fn join_no_order_on_filter() -> Result<()> { actual, @r" SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a3@0 AS Int64) > CAST(a3@1 AS Int64) + 3 AND CAST(a3@0 AS Int64) < CAST(a3@1 AS Int64) + 10 - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - StreamingTableExec: partition_sizes=1, projection=[a1, a2, a3], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST] - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - StreamingTableExec: partition_sizes=1, projection=[a1, a2, a3], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a1, a2, a3], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a1, a2, a3], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST] " ); Ok(()) @@ -176,14 +172,10 @@ async fn join_change_in_planner_without_sort() -> Result<()> { actual, @r" SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10 - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8 - RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 - StreamingTableExec: partition_sizes=1, projection=[a1, a2], 
infinite_source=true + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true " ); Ok(()) @@ -214,7 +206,10 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { match df.create_physical_plan().await { Ok(_) => panic!("Expecting error."), Err(e) => { - assert_eq!(e.strip_backtrace(), "SanityCheckPlan\ncaused by\nError during planning: Join operation cannot operate on a non-prunable stream without enabling the 'allow_symmetric_joins_without_pruning' configuration flag") + assert_eq!( + e.strip_backtrace(), + "SanityCheckPlan\ncaused by\nError during planning: Join operation cannot operate on a non-prunable stream without enabling the 'allow_symmetric_joins_without_pruning' configuration flag" + ) } } Ok(()) @@ -295,16 +290,12 @@ async fn unparse_cross_join() -> Result<()> { .await?; let unopt_sql = plan_to_sql(df.logical_plan())?; - assert_snapshot!(unopt_sql, @r#" - SELECT j1.j1_id, j2.j2_string FROM j1 CROSS JOIN j2 WHERE (j2.j2_id = 0) - "#); + assert_snapshot!(unopt_sql, @"SELECT j1.j1_id, j2.j2_string FROM j1 CROSS JOIN j2 WHERE (j2.j2_id = 0)"); let optimized_plan = df.into_optimized_plan()?; let opt_sql = plan_to_sql(&optimized_plan)?; - assert_snapshot!(opt_sql, @r#" - SELECT j1.j1_id, j2.j2_string FROM j1 CROSS JOIN j2 WHERE (j2.j2_id = 0) - "#); + assert_snapshot!(opt_sql, @"SELECT j1.j1_id, j2.j2_string FROM j1 CROSS JOIN j2 WHERE (j2.j2_id = 0)"); Ok(()) } diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 743c8750b5215..9a1dc5502ee60 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -24,10 +24,10 @@ use arrow::{ use datafusion::error::Result; use datafusion::logical_expr::{Aggregate, LogicalPlan, TableScan}; -use datafusion::physical_plan::collect; -use datafusion::physical_plan::metrics::MetricValue; use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::ExecutionPlanVisitor; +use datafusion::physical_plan::collect; +use datafusion::physical_plan::metrics::MetricValue; use datafusion::prelude::*; use datafusion::test_util; use datafusion::{execution::context::SessionContext, physical_plan::displayable}; @@ -40,18 +40,24 @@ use std::io::Write; use std::path::PathBuf; use tempfile::TempDir; -/// A macro to assert that some particular line contains two substrings +/// A macro to assert that some particular line contains the given substrings /// -/// Usage: `assert_metrics!(actual, operator_name, metrics)` +/// Usage: `assert_metrics!(actual, operator_name, metrics_1, metrics_2, ...)` macro_rules! 
assert_metrics { - ($ACTUAL: expr, $OPERATOR_NAME: expr, $METRICS: expr) => { + ($ACTUAL: expr, $OPERATOR_NAME: expr, $($METRICS: expr),+) => { let found = $ACTUAL .lines() - .any(|line| line.contains($OPERATOR_NAME) && line.contains($METRICS)); + .any(|line| line.contains($OPERATOR_NAME) $( && line.contains($METRICS))+); + + let mut metrics = String::new(); + $(metrics.push_str(format!(" '{}',", $METRICS).as_str());)+ + // remove the last `,` from the string + metrics.pop(); + assert!( found, - "Can not find a line with both '{}' and '{}' in\n\n{}", - $OPERATOR_NAME, $METRICS, $ACTUAL + "Cannot find a line with operator name '{}' and metrics containing values {} in :\n\n{}", + $OPERATOR_NAME, metrics, $ACTUAL ); }; } @@ -64,6 +70,7 @@ mod path_partition; mod runtime_config; pub mod select; mod sql_api; +mod unparser; async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { let testdata = test_util::arrow_test_data(); @@ -329,8 +336,7 @@ async fn nyc() -> Result<()> { match &optimized_plan { LogicalPlan::Aggregate(Aggregate { input, .. }) => match input.as_ref() { LogicalPlan::TableScan(TableScan { - ref projected_schema, - .. + projected_schema, .. }) => { assert_eq!(2, projected_schema.fields().len()); assert_eq!(projected_schema.field(0).name(), "passenger_count"); diff --git a/datafusion/core/tests/sql/path_partition.rs b/datafusion/core/tests/sql/path_partition.rs index 05cc723ef05fb..c6f920584dc2b 100644 --- a/datafusion/core/tests/sql/path_partition.rs +++ b/datafusion/core/tests/sql/path_partition.rs @@ -31,14 +31,13 @@ use datafusion::{ listing::{ListingOptions, ListingTable, ListingTableConfig}, }, error::Result, - physical_plan::ColumnStatistics, prelude::SessionContext, test_util::{self, arrow_test_data, parquet_test_data}, }; use datafusion_catalog::TableProvider; +use datafusion_common::ScalarValue; use datafusion_common::stats::Precision; use datafusion_common::test_util::batches_to_sort_string; -use datafusion_common::ScalarValue; use datafusion_execution::config::SessionConfig; use async_trait::async_trait; @@ -46,11 +45,11 @@ use bytes::Bytes; use chrono::{TimeZone, Utc}; use futures::stream::{self, BoxStream}; use insta::assert_snapshot; +use object_store::{Attributes, MultipartUpload, PutMultipartOptions, PutPayload}; use object_store::{ - path::Path, GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta, - ObjectStore, PutOptions, PutResult, + GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, + PutOptions, PutResult, path::Path, }; -use object_store::{Attributes, MultipartUpload, PutMultipartOptions, PutPayload}; use url::Url; #[tokio::test] @@ -464,10 +463,19 @@ async fn parquet_statistics() -> Result<()> { assert_eq!(stat_cols.len(), 4); // stats for the first col are read from the parquet file assert_eq!(stat_cols[0].null_count, Precision::Exact(3)); - // TODO assert partition column (1,2,3) stats once implemented (#1186) - assert_eq!(stat_cols[1], ColumnStatistics::new_unknown(),); - assert_eq!(stat_cols[2], ColumnStatistics::new_unknown(),); - assert_eq!(stat_cols[3], ColumnStatistics::new_unknown(),); + // Partition column statistics (year=2021 for all 3 rows) + assert_eq!(stat_cols[1].null_count, Precision::Exact(0)); + assert_eq!( + stat_cols[1].min_value, + Precision::Exact(ScalarValue::Int32(Some(2021))) + ); + assert_eq!( + stat_cols[1].max_value, + Precision::Exact(ScalarValue::Int32(Some(2021))) + ); + // month and day are Utf8 partition columns with statistics + assert_eq!(stat_cols[2].null_count, 
Precision::Exact(0)); + assert_eq!(stat_cols[3].null_count, Precision::Exact(0)); //// WITH PROJECTION //// let dataframe = ctx.sql("SELECT mycol, day FROM t WHERE day='28'").await?; @@ -479,8 +487,16 @@ async fn parquet_statistics() -> Result<()> { assert_eq!(stat_cols.len(), 2); // stats for the first col are read from the parquet file assert_eq!(stat_cols[0].null_count, Precision::Exact(1)); - // TODO assert partition column stats once implemented (#1186) - assert_eq!(stat_cols[1], ColumnStatistics::new_unknown()); + // Partition column statistics for day='28' (1 row) + assert_eq!(stat_cols[1].null_count, Precision::Exact(0)); + assert_eq!( + stat_cols[1].min_value, + Precision::Exact(ScalarValue::Utf8(Some("28".to_string()))) + ); + assert_eq!( + stat_cols[1].max_value, + Precision::Exact(ScalarValue::Utf8(Some("28".to_string()))) + ); Ok(()) } diff --git a/datafusion/core/tests/sql/runtime_config.rs b/datafusion/core/tests/sql/runtime_config.rs index 9627d7bccdb04..d85892c254570 100644 --- a/datafusion/core/tests/sql/runtime_config.rs +++ b/datafusion/core/tests/sql/runtime_config.rs @@ -18,9 +18,14 @@ //! Tests for runtime configuration SQL interface use std::sync::Arc; +use std::time::Duration; use datafusion::execution::context::SessionContext; use datafusion::execution::context::TaskContext; +use datafusion::prelude::SessionConfig; +use datafusion_execution::cache::DefaultListFilesCache; +use datafusion_execution::cache::cache_manager::CacheManagerConfig; +use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_physical_plan::common::collect; #[tokio::test] @@ -233,6 +238,93 @@ async fn test_test_metadata_cache_limit() { assert_eq!(get_limit(&ctx), 123 * 1024); } +#[tokio::test] +async fn test_list_files_cache_limit() { + let list_files_cache = Arc::new(DefaultListFilesCache::default()); + + let rt = RuntimeEnvBuilder::new() + .with_cache_manager( + CacheManagerConfig::default().with_list_files_cache(Some(list_files_cache)), + ) + .build_arc() + .unwrap(); + + let ctx = SessionContext::new_with_config_rt(SessionConfig::default(), rt); + + let update_limit = async |ctx: &SessionContext, limit: &str| { + ctx.sql( + format!("SET datafusion.runtime.list_files_cache_limit = '{limit}'").as_str(), + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + }; + + let get_limit = |ctx: &SessionContext| -> usize { + ctx.task_ctx() + .runtime_env() + .cache_manager + .get_list_files_cache() + .unwrap() + .cache_limit() + }; + + update_limit(&ctx, "100M").await; + assert_eq!(get_limit(&ctx), 100 * 1024 * 1024); + + update_limit(&ctx, "2G").await; + assert_eq!(get_limit(&ctx), 2 * 1024 * 1024 * 1024); + + update_limit(&ctx, "123K").await; + assert_eq!(get_limit(&ctx), 123 * 1024); +} + +#[tokio::test] +async fn test_list_files_cache_ttl() { + let list_files_cache = Arc::new(DefaultListFilesCache::default()); + + let rt = RuntimeEnvBuilder::new() + .with_cache_manager( + CacheManagerConfig::default().with_list_files_cache(Some(list_files_cache)), + ) + .build_arc() + .unwrap(); + + let ctx = SessionContext::new_with_config_rt(SessionConfig::default(), rt); + + let update_limit = async |ctx: &SessionContext, limit: &str| { + ctx.sql( + format!("SET datafusion.runtime.list_files_cache_ttl = '{limit}'").as_str(), + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + }; + + let get_limit = |ctx: &SessionContext| -> Duration { + ctx.task_ctx() + .runtime_env() + .cache_manager + .get_list_files_cache() + .unwrap() + .cache_ttl() + .unwrap() + }; + + update_limit(&ctx, 
"1m").await; + assert_eq!(get_limit(&ctx), Duration::from_secs(60)); + + update_limit(&ctx, "30s").await; + assert_eq!(get_limit(&ctx), Duration::from_secs(30)); + + update_limit(&ctx, "1m30s").await; + assert_eq!(get_limit(&ctx), Duration::from_secs(90)); +} + #[tokio::test] async fn test_unknown_runtime_config() { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 8a0f620627384..6126793145efd 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -18,8 +18,7 @@ use std::collections::HashMap; use super::*; -use datafusion::assert_batches_eq; -use datafusion_common::{metadata::ScalarAndMetadata, ParamValues, ScalarValue}; +use datafusion_common::{ParamValues, ScalarValue, metadata::ScalarAndMetadata}; use insta::assert_snapshot; #[tokio::test] @@ -223,10 +222,10 @@ async fn test_parameter_invalid_types() -> Result<()> { .await; assert_snapshot!(results.unwrap_err().strip_backtrace(), @r" - type_coercion - caused by - Error during planning: Cannot infer common argument type for comparison operation List(nullable Int32) = Int32 - "); + type_coercion + caused by + Error during planning: Cannot infer common argument type for comparison operation List(Int32) = Int32 + "); Ok(()) } @@ -343,26 +342,20 @@ async fn test_query_parameters_with_metadata() -> Result<()> { ])) .unwrap(); - // df_with_params_replaced.schema() is not correct here - // https://github.com/apache/datafusion/issues/18102 - let batches = df_with_params_replaced.clone().collect().await.unwrap(); - let schema = batches[0].schema(); - + let schema = df_with_params_replaced.schema(); assert_eq!(schema.field(0).data_type(), &DataType::UInt32); assert_eq!(schema.field(0).metadata(), &metadata1); assert_eq!(schema.field(1).data_type(), &DataType::Utf8); assert_eq!(schema.field(1).metadata(), &metadata2); - assert_batches_eq!( - [ - "+----+-----+", - "| $1 | $2 |", - "+----+-----+", - "| 1 | two |", - "+----+-----+", - ], - &batches - ); + let batches = df_with_params_replaced.collect().await.unwrap(); + assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+-----+ + | $1 | $2 | + +----+-----+ + | 1 | two | + +----+-----+ + "); Ok(()) } @@ -421,3 +414,20 @@ async fn test_select_no_projection() -> Result<()> { "); Ok(()) } + +#[tokio::test] +async fn test_select_cast_date_literal_to_timestamp_overflow() -> Result<()> { + let ctx = SessionContext::new(); + let err = ctx + .sql("SELECT CAST(DATE '9999-12-31' AS TIMESTAMP)") + .await? + .collect() + .await + .unwrap_err(); + + assert_contains!( + err.to_string(), + "Cannot cast Date32 value 2932896 to Timestamp(ns): converted value exceeds the representable i64 range" + ); + Ok(()) +} diff --git a/datafusion/core/tests/sql/unparser.rs b/datafusion/core/tests/sql/unparser.rs new file mode 100644 index 0000000000000..8b56bf67a261c --- /dev/null +++ b/datafusion/core/tests/sql/unparser.rs @@ -0,0 +1,462 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SQL Unparser Roundtrip Integration Tests +//! +//! This module tests the [`Unparser`] by running queries through a complete roundtrip: +//! the original SQL is parsed into a logical plan, unparsed back to SQL, then that +//! generated SQL is parsed and executed. The results are compared to verify semantic +//! equivalence. +//! +//! ## Test Strategy +//! +//! Uses real-world benchmark queries (TPC-H and Clickbench) to validate that: +//! 1. The unparser produces syntactically valid SQL +//! 2. The unparsed SQL is semantically equivalent (produces identical results) +//! +//! ## Query Suites +//! +//! - **TPC-H**: Standard decision-support benchmark with 22 complex analytical queries +//! - **Clickbench**: Web analytics benchmark with 43 queries against a denormalized schema +//! +//! [`Unparser`]: datafusion_sql::unparser::Unparser + +use std::fs::ReadDir; +use std::future::Future; + +use arrow::array::RecordBatch; +use datafusion::common::Result; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; +use datafusion_common::Column; +use datafusion_expr::Expr; +use datafusion_physical_plan::ExecutionPlanProperties; +use datafusion_sql::unparser::Unparser; +use datafusion_sql::unparser::dialect::DefaultDialect; +use itertools::Itertools; + +/// Paths to benchmark query files (supports running from repo root or different working directories). +const BENCHMARK_PATHS: &[&str] = &["../../benchmarks/", "./benchmarks/"]; + +/// Reads all `.sql` files from a directory and converts them to test queries. +/// +/// Skips files that: +/// - Are not regular files +/// - Don't have a `.sql` extension +/// - Contain multiple SQL statements (indicated by `;\n`) +/// +/// Multi-statement files are skipped because the unparser doesn't support +/// DDL statements like `CREATE VIEW` that appear in multi-statement Clickbench queries. +fn iterate_queries(dir: ReadDir) -> Vec<TestQuery> { + let mut queries = vec![]; + for entry in dir.flatten() { + let Ok(file_type) = entry.file_type() else { + continue; + }; + if !file_type.is_file() { + continue; + } + let path = entry.path(); + let Some(ext) = path.extension() else { + continue; + }; + if ext != "sql" { + continue; + } + let name = path.file_stem().unwrap().to_string_lossy().to_string(); + if let Ok(mut contents) = std::fs::read_to_string(entry.path()) { + // If the query contains ;\n it has DDL statements like CREATE VIEW which the unparser doesn't support; skip it + contents = contents.trim().to_string(); + if contents.contains(";\n") { + println!("Skipping query with multiple statements: {name}"); + continue; + } + queries.push(TestQuery { + sql: contents, + name, + }); + } + } + queries +} + +/// A SQL query loaded from a benchmark file for roundtrip testing. +/// +/// Each query is identified by its filename (without extension) and contains +/// the full SQL text to be tested. +struct TestQuery { + /// The SQL query text to test. + sql: String, + /// The query identifier (typically the filename without .sql extension). + name: String, +} + +/// Collect SQL for Clickbench queries. 
+fn clickbench_queries() -> Vec<TestQuery> { + let mut queries = vec![]; + for path in BENCHMARK_PATHS { + let dir = format!("{path}queries/clickbench/queries/"); + println!("Reading Clickbench queries from {dir}"); + if let Ok(dir) = std::fs::read_dir(dir) { + let read = iterate_queries(dir); + println!("Found {} Clickbench queries", read.len()); + queries.extend(read); + } + } + queries.sort_unstable_by_key(|q| { + q.name + .split('q') + .next_back() + .and_then(|num| num.parse::<usize>().ok()) + }); + queries +} + +/// Collect SQL for TPC-H queries. +fn tpch_queries() -> Vec<TestQuery> { + let mut queries = vec![]; + for path in BENCHMARK_PATHS { + let dir = format!("{path}queries/"); + println!("Reading TPC-H queries from {dir}"); + if let Ok(dir) = std::fs::read_dir(dir) { + let read = iterate_queries(dir); + queries.extend(read); + } + } + println!("Total TPC-H queries found: {}", queries.len()); + queries.sort_unstable_by_key(|q| q.name.clone()); + queries +} + +/// Create a new SessionContext for testing that has all Clickbench tables registered. +async fn clickbench_test_context() -> Result<SessionContext> { + let ctx = SessionContext::new(); + ctx.register_parquet( + "hits", + "tests/data/clickbench_hits_10.parquet", + ParquetReadOptions::default(), + ) + .await?; + // Sanity check we found the table by querying its schema, it should not be empty + // Otherwise if the path is wrong the tests will all fail in confusing ways + let df = ctx.sql("SELECT * FROM hits LIMIT 1").await?; + assert!( + !df.schema().fields().is_empty(), + "Clickbench 'hits' table not registered correctly" + ); + Ok(ctx) +} + +/// Create a new SessionContext for testing that has all TPC-H tables registered. +async fn tpch_test_context() -> Result<SessionContext> { + let ctx = SessionContext::new(); + let data_dir = "tests/data/"; + // All tables have the pattern "tpch_<table>_small.parquet" + for table in [ + "customer", "lineitem", "nation", "orders", "part", "partsupp", "region", + "supplier", + ] { + let path = format!("{data_dir}tpch_{table}_small.parquet"); + ctx.register_parquet(table, &path, ParquetReadOptions::default()) + .await?; + // Sanity check we found the table by querying its schema, it should not be empty + // Otherwise if the path is wrong the tests will all fail in confusing ways + let df = ctx.sql(&format!("SELECT * FROM {table} LIMIT 1")).await?; + assert!( + !df.schema().fields().is_empty(), + "TPC-H '{table}' table not registered correctly" + ); + } + Ok(ctx) +} + +/// Sorts record batches by all columns for deterministic comparison. +/// +/// When comparing query results, we need a canonical ordering so that +/// semantically equivalent results compare as equal. This function sorts +/// by all columns in the schema to achieve that. +async fn sort_batches( + ctx: &SessionContext, + batches: Vec<RecordBatch>, +) -> Result<Vec<RecordBatch>> { + let mut df = ctx.read_batches(batches)?; + let schema = df.schema().as_arrow().clone(); + let sort_exprs = schema + .fields() + .iter() + // Use Column directly, col() causes the column names to be normalized to lowercase + .map(|f| { + Expr::Column(Column::new_unqualified(f.name().to_string())).sort(true, false) + }) + .collect_vec(); + if !sort_exprs.is_empty() { + df = df.sort(sort_exprs)?; + } + df.collect().await +} + +/// The outcome of running a single roundtrip test. +/// +/// A successful test produces [`TestCaseResult::Success`]. +/// All other variants capture different failure modes with enough context to diagnose the issue. +enum TestCaseResult { + /// The unparsed SQL produced identical results to the original. 
+ Success, + + /// Both queries executed but produced different results. + /// + /// This indicates a semantic bug in the unparser where the generated SQL + /// has different meaning than the original. + ResultsMismatch { original: String, unparsed: String }, + + /// The unparser failed to convert the logical plan to SQL. + /// + /// This may indicate an unsupported SQL feature or a bug in the unparser. + UnparseError { original: String, error: String }, + + /// The original SQL failed to execute. + /// + /// This indicates a problem with the test setup (missing tables, + /// invalid test data) rather than an unparser issue. + ExecutionError { original: String, error: String }, + + /// The unparsed SQL failed to execute, even though the original succeeded. + /// + /// This indicates the unparser generated syntactically invalid SQL or SQL + /// that references non-existent columns/tables. + UnparsedExecutionError { + original: String, + unparsed: String, + error: String, + }, +} + +impl TestCaseResult { + /// Returns true if the test case represents a failure + /// (anything other than [`TestCaseResult::Success`]). + fn is_failure(&self) -> bool { + !matches!(self, TestCaseResult::Success) + } + + /// Formats a detailed error message for the test case into a string. + fn format_error(&self, name: &str) -> String { + match self { + TestCaseResult::Success => String::new(), + TestCaseResult::ResultsMismatch { original, unparsed } => { + format!( + "Results mismatch for {name}.\nOriginal SQL:\n{original}\n\nUnparsed SQL:\n{unparsed}" + ) + } + TestCaseResult::UnparseError { original, error } => { + format!("Unparse error for {name}: {error}\nOriginal SQL:\n{original}") + } + TestCaseResult::ExecutionError { original, error } => { + format!("Execution error for {name}: {error}\nOriginal SQL:\n{original}") + } + TestCaseResult::UnparsedExecutionError { + original, + unparsed, + error, + } => { + format!( + "Unparsed execution error for {name}: {error}\nOriginal SQL:\n{original}\n\nUnparsed SQL:\n{unparsed}" + ) + } + } + } +} + +/// Executes a roundtrip test for a single SQL query. +/// +/// This is the core test logic that: +/// 1. Parses the original SQL and creates a logical plan +/// 2. Unparses the logical plan back to SQL +/// 3. Executes both the original and unparsed queries +/// 4. Compares the results (sorting if the query has no ORDER BY) +/// +/// This always uses [`DefaultDialect`] for unparsing. +/// +/// # Arguments +/// +/// * `ctx` - Session context with tables registered +/// * `original` - The original SQL query to test +/// +/// # Returns +/// +/// A [`TestCaseResult`] indicating success or the specific failure mode. 
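+/// # Example
+///
+/// Illustrative sketch only (not compiled as a doctest); it uses just the helpers
+/// defined in this module and assumes the TPC-H test data registered by
+/// [`tpch_test_context`] is available:
+///
+/// ```ignore
+/// let ctx = tpch_test_context().await?;
+/// // Any single-statement query works here; this one is only an example.
+/// let outcome = collect_results(&ctx, "SELECT count(*) FROM lineitem").await;
+/// if outcome.is_failure() {
+///     panic!("{}", outcome.format_error("example"));
+/// }
+/// ```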
+async fn collect_results(ctx: &SessionContext, original: &str) -> TestCaseResult { + let unparser = Unparser::new(&DefaultDialect {}); + + // Parse and create logical plan from original SQL + let df = match ctx.sql(original).await { + Ok(df) => df, + Err(e) => { + return TestCaseResult::ExecutionError { + original: original.to_string(), + error: e.to_string(), + }; + } + }; + + // Unparse the logical plan back to SQL + let unparsed = match unparser.plan_to_sql(df.logical_plan()) { + Ok(sql) => format!("{sql:#}"), + Err(e) => { + return TestCaseResult::UnparseError { + original: original.to_string(), + error: e.to_string(), + }; + } + }; + + let is_sorted = match ctx.state().create_physical_plan(df.logical_plan()).await { + Ok(plan) => plan.equivalence_properties().output_ordering().is_some(), + Err(e) => { + return TestCaseResult::ExecutionError { + original: original.to_string(), + error: e.to_string(), + }; + } + }; + + // Collect results from original query + let mut expected = match df.collect().await { + Ok(batches) => batches, + Err(e) => { + return TestCaseResult::ExecutionError { + original: original.to_string(), + error: e.to_string(), + }; + } + }; + + // Parse and execute the unparsed SQL + let actual_df = match ctx.sql(&unparsed).await { + Ok(df) => df, + Err(e) => { + return TestCaseResult::UnparsedExecutionError { + original: original.to_string(), + unparsed, + error: e.to_string(), + }; + } + }; + + // Collect results from unparsed query + let mut actual = match actual_df.collect().await { + Ok(batches) => batches, + Err(e) => { + return TestCaseResult::UnparsedExecutionError { + original: original.to_string(), + unparsed, + error: e.to_string(), + }; + } + }; + + // Sort if needed for comparison + if !is_sorted { + expected = match sort_batches(ctx, expected).await { + Ok(batches) => batches, + Err(e) => { + return TestCaseResult::ExecutionError { + original: original.to_string(), + error: format!("Failed to sort expected results: {e}"), + }; + } + }; + actual = match sort_batches(ctx, actual).await { + Ok(batches) => batches, + Err(e) => { + return TestCaseResult::UnparsedExecutionError { + original: original.to_string(), + unparsed, + error: format!("Failed to sort actual results: {e}"), + }; + } + }; + } + + if expected != actual { + TestCaseResult::ResultsMismatch { + original: original.to_string(), + unparsed, + } + } else { + TestCaseResult::Success + } +} + +/// Runs roundtrip tests for a collection of queries and reports results. +/// +/// Iterates through all queries, running each through [`collect_results`]. +/// Prints colored status (green checkmark for success, red X for failure) +/// and panics at the end if any tests failed, with detailed error messages. +/// +/// # Type Parameters +/// +/// * `F` - Factory function that creates fresh session contexts +/// * `Fut` - Future type returned by the context factory +/// +/// # Panics +/// +/// Panics if any query fails the roundtrip test, displaying all failures. 
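+/// # Example
+///
+/// Minimal sketch of wiring a query suite into this runner; it mirrors the
+/// TPC-H test at the bottom of this file and uses only helpers defined above:
+///
+/// ```ignore
+/// #[tokio::test]
+/// async fn my_roundtrip_suite() {
+///     // "TPC-H" is just the label used when reporting failures.
+///     run_roundtrip_tests("TPC-H", tpch_queries(), tpch_test_context).await;
+/// }
+/// ```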
+async fn run_roundtrip_tests<F, Fut>( + suite_name: &str, + queries: Vec<TestQuery>, + create_context: F, +) where + F: Fn() -> Fut, + Fut: Future<Output = Result<SessionContext>>, +{ + let mut errors: Vec<String> = vec![]; + for sql in queries { + let ctx = match create_context().await { + Ok(ctx) => ctx, + Err(e) => { + println!("\x1b[31m✗\x1b[0m {} query: {}", suite_name, sql.name); + errors.push(format!("Failed to create context for {}: {}", sql.name, e)); + continue; + } + }; + let result = collect_results(&ctx, &sql.sql).await; + if result.is_failure() { + println!("\x1b[31m✗\x1b[0m {} query: {}", suite_name, sql.name); + errors.push(result.format_error(&sql.name)); + } else { + println!("\x1b[32m✓\x1b[0m {} query: {}", suite_name, sql.name); + } + } + if !errors.is_empty() { + panic!( + "{} {} test(s) failed:\n\n{}", + errors.len(), + suite_name, + errors.join("\n\n---\n\n") + ); + } +} + +#[tokio::test] +async fn test_clickbench_unparser_roundtrip() { + run_roundtrip_tests("Clickbench", clickbench_queries(), clickbench_test_context) + .await; +} + +#[tokio::test] +async fn test_tpch_unparser_roundtrip() { + run_roundtrip_tests("TPC-H", tpch_queries(), tpch_test_context).await; +} diff --git a/datafusion/core/tests/tpc-ds/30.sql b/datafusion/core/tests/tpc-ds/30.sql index 78f34b807e5b5..80624f49006a9 100644 --- a/datafusion/core/tests/tpc-ds/30.sql +++ b/datafusion/core/tests/tpc-ds/30.sql @@ -14,7 +14,7 @@ with customer_total_return as ,ca_state) select c_customer_id,c_salutation,c_first_name,c_last_name,c_preferred_cust_flag ,c_birth_day,c_birth_month,c_birth_year,c_birth_country,c_login,c_email_address - ,c_last_review_date_sk,ctr_total_return + ,c_last_review_date,ctr_total_return from customer_total_return ctr1 ,customer_address ,customer @@ -26,7 +26,7 @@ with customer_total_return as and ctr1.ctr_customer_sk = c_customer_sk order by c_customer_id,c_salutation,c_first_name,c_last_name,c_preferred_cust_flag ,c_birth_day,c_birth_month,c_birth_year,c_birth_country,c_login,c_email_address - ,c_last_review_date_sk,ctr_total_return + ,c_last_review_date,ctr_total_return limit 100; diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index 252d76d0f9d92..3ad74962bc2c0 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -1052,9 +1052,12 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> { for sql in &sql { let df = ctx.sql(sql).await?; let (state, plan) = df.into_parts(); - let plan = state.optimize(&plan)?; if create_physical { let _ = state.create_physical_plan(&plan).await?; + } else { + // Run the logical optimizer even if we are not creating the physical plan + // to ensure it will properly succeed + let _ = state.optimize(&plan)?; } } diff --git a/datafusion/core/tests/tracing/asserting_tracer.rs b/datafusion/core/tests/tracing/asserting_tracer.rs index 292e066e5f121..700f9f3308466 100644 --- a/datafusion/core/tests/tracing/asserting_tracer.rs +++ b/datafusion/core/tests/tracing/asserting_tracer.rs @@ -21,7 +21,7 @@ use std::ops::Deref; use std::sync::{Arc, LazyLock}; use datafusion_common::{HashMap, HashSet}; -use datafusion_common_runtime::{set_join_set_tracer, JoinSetTracer}; +use datafusion_common_runtime::{JoinSetTracer, set_join_set_tracer}; use futures::future::BoxFuture; use tokio::sync::{Mutex, MutexGuard}; diff --git a/datafusion/core/tests/tracing/traceable_object_store.rs b/datafusion/core/tests/tracing/traceable_object_store.rs index 60ef1cc5d6b6a..00aa4ea3f36d9 --- 
a/datafusion/core/tests/tracing/traceable_object_store.rs +++ b/datafusion/core/tests/tracing/traceable_object_store.rs @@ -20,8 +20,8 @@ use crate::tracing::asserting_tracer::assert_traceability; use futures::stream::BoxStream; use object_store::{ - path::Path, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, - ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, + GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore, + PutMultipartOptions, PutOptions, PutPayload, PutResult, path::Path, }; use std::fmt::{Debug, Display, Formatter}; use std::sync::Arc; diff --git a/datafusion/core/tests/user_defined/expr_planner.rs b/datafusion/core/tests/user_defined/expr_planner.rs index 07d289cab06c2..c5e5af731359f 100644 --- a/datafusion/core/tests/user_defined/expr_planner.rs +++ b/datafusion/core/tests/user_defined/expr_planner.rs @@ -26,9 +26,9 @@ use datafusion::logical_expr::Operator; use datafusion::prelude::*; use datafusion::sql::sqlparser::ast::BinaryOperator; use datafusion_common::ScalarValue; +use datafusion_expr::BinaryExpr; use datafusion_expr::expr::Alias; use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr}; -use datafusion_expr::BinaryExpr; #[derive(Debug)] struct MyCustomPlanner; @@ -77,25 +77,25 @@ async fn plan_and_collect(sql: &str) -> Result<Vec<RecordBatch>> { #[tokio::test] async fn test_custom_operators_arrow() { let actual = plan_and_collect("select 'foo'->'bar';").await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r#" +----------------------------+ | Utf8("foo") || Utf8("bar") | +----------------------------+ | foobar | +----------------------------+ - "###); + "#); } #[tokio::test] async fn test_custom_operators_long_arrow() { let actual = plan_and_collect("select 1->>2;").await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +---------------------+ | Int64(1) + Int64(2) | +---------------------+ | 3 | +---------------------+ - "###); + "); } #[tokio::test] @@ -103,13 +103,13 @@ async fn test_question_select() { let actual = plan_and_collect("select a ? 2 from (select 1 as a);") .await .unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +--------------+ | a ? Int64(2) | +--------------+ | true | +--------------+ - "###); + "); } #[tokio::test] @@ -117,11 +117,11 @@ async fn test_question_filter() { let actual = plan_and_collect("select a from (select 1 as a) where a ? 
2;") .await .unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +---+ | a | +---+ | 1 | +---+ - "###); + "); } diff --git a/datafusion/core/tests/user_defined/insert_operation.rs b/datafusion/core/tests/user_defined/insert_operation.rs index e0a3e98604ae4..7ad00dece1b24 100644 --- a/datafusion/core/tests/user_defined/insert_operation.rs +++ b/datafusion/core/tests/user_defined/insert_operation.rs @@ -25,12 +25,12 @@ use datafusion::{ }; use datafusion_catalog::{Session, TableProvider}; use datafusion_common::config::Dialect; -use datafusion_expr::{dml::InsertOp, Expr, TableType}; +use datafusion_expr::{Expr, TableType, dml::InsertOp}; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; use datafusion_physical_plan::execution_plan::SchedulingType; use datafusion_physical_plan::{ - execution_plan::{Boundedness, EmissionType}, DisplayAs, ExecutionPlan, PlanProperties, + execution_plan::{Boundedness, EmissionType}, }; #[tokio::test] diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs index 5d84cdb692830..bc9949f5d681c 100644 --- a/datafusion/core/tests/user_defined/mod.rs +++ b/datafusion/core/tests/user_defined/mod.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +/// Tests for user defined Async Scalar functions +mod user_defined_async_scalar_functions; + /// Tests for user defined Scalar functions mod user_defined_scalar_functions; @@ -33,5 +36,8 @@ mod user_defined_table_functions; /// Tests for Expression Planner mod expr_planner; +/// Tests for Relation Planner extensions +mod relation_planner; + /// Tests for insert operations mod insert_operation; diff --git a/datafusion/core/tests/user_defined/relation_planner.rs b/datafusion/core/tests/user_defined/relation_planner.rs new file mode 100644 index 0000000000000..bda9b37ebea68 --- /dev/null +++ b/datafusion/core/tests/user_defined/relation_planner.rs @@ -0,0 +1,527 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Tests for the RelationPlanner extension point + +use std::sync::Arc; + +use arrow::array::{Int64Array, RecordBatch, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::catalog::memory::MemTable; +use datafusion::common::test_util::batches_to_string; +use datafusion::prelude::*; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::Expr; +use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; +use datafusion_expr::planner::{ + PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning, +}; +use datafusion_sql::sqlparser::ast::TableFactor; +use insta::assert_snapshot; + +// ============================================================================ +// Test Planners - Example Implementations +// ============================================================================ + +// The planners in this section are deliberately minimal, static examples used +// only for tests. In real applications a `RelationPlanner` would typically +// construct richer logical plans tailored to external systems or custom +// semantics rather than hard-coded in-memory tables. +// +// For more realistic examples, see `datafusion-examples/examples/relation_planner/`: +// - `table_sample.rs`: Full TABLESAMPLE implementation (parsing → execution) +// - `pivot_unpivot.rs`: PIVOT/UNPIVOT via SQL rewriting +// - `match_recognize.rs`: MATCH_RECOGNIZE logical planning + +/// Helper to build simple static values-backed virtual tables used by the +/// example planners below. +fn plan_static_values_table( + relation: TableFactor, + table_name: &str, + column_name: &str, + values: Vec, +) -> Result { + match relation { + TableFactor::Table { name, alias, .. } + if name.to_string().eq_ignore_ascii_case(table_name) => + { + let rows = values + .into_iter() + .map(|v| vec![Expr::Literal(v, None)]) + .collect::>(); + + let plan = LogicalPlanBuilder::values(rows)? + .project(vec![col("column1").alias(column_name)])? + .build()?; + + Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + } + other => Ok(RelationPlanning::Original(other)), + } +} + +/// Example planner that provides a virtual `numbers` table with values +/// 1, 2, 3. +#[derive(Debug)] +struct NumbersPlanner; + +impl RelationPlanner for NumbersPlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result { + plan_static_values_table( + relation, + "numbers", + "number", + vec![ + ScalarValue::Int64(Some(1)), + ScalarValue::Int64(Some(2)), + ScalarValue::Int64(Some(3)), + ], + ) + } +} + +/// Example planner that provides a virtual `colors` table with three string +/// values: `red`, `green`, `blue`. +#[derive(Debug)] +struct ColorsPlanner; + +impl RelationPlanner for ColorsPlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result { + plan_static_values_table( + relation, + "colors", + "color", + vec![ + ScalarValue::Utf8(Some("red".into())), + ScalarValue::Utf8(Some("green".into())), + ScalarValue::Utf8(Some("blue".into())), + ], + ) + } +} + +/// Alternative implementation of `numbers` (returns 100, 200) used to +/// demonstrate planner precedence (last registered planner wins). 
+#[derive(Debug)] +struct AlternativeNumbersPlanner; + +impl RelationPlanner for AlternativeNumbersPlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result { + plan_static_values_table( + relation, + "numbers", + "number", + vec![ScalarValue::Int64(Some(100)), ScalarValue::Int64(Some(200))], + ) + } +} + +/// Example planner that intercepts nested joins and samples both sides (limit 2) +/// before joining, demonstrating recursive planning with `context.plan()`. +#[derive(Debug)] +struct SamplingJoinPlanner; + +impl RelationPlanner for SamplingJoinPlanner { + fn plan_relation( + &self, + relation: TableFactor, + context: &mut dyn RelationPlannerContext, + ) -> Result { + match relation { + TableFactor::NestedJoin { + table_with_joins, + alias, + .. + } if table_with_joins.joins.len() == 1 => { + // Use context.plan() to recursively plan both sides + // This ensures other planners (like NumbersPlanner) can handle them + let left = context.plan(table_with_joins.relation.clone())?; + let right = context.plan(table_with_joins.joins[0].relation.clone())?; + + // Sample each table to 2 rows + let left_sampled = + LogicalPlanBuilder::from(left).limit(0, Some(2))?.build()?; + + let right_sampled = + LogicalPlanBuilder::from(right).limit(0, Some(2))?.build()?; + + // Cross join: 2 rows × 2 rows = 4 rows (instead of 3×3=9 without sampling) + let plan = LogicalPlanBuilder::from(left_sampled) + .cross_join(right_sampled)? + .build()?; + + Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + } + other => Ok(RelationPlanning::Original(other)), + } + } +} + +/// Example planner that never handles any relation and always delegates by +/// returning `RelationPlanning::Original`. +#[derive(Debug)] +struct PassThroughPlanner; + +impl RelationPlanner for PassThroughPlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result { + // Never handles anything - always delegates + Ok(RelationPlanning::Original(relation)) + } +} + +/// Example planner that shows how planners can block specific constructs and +/// surface custom error messages by rejecting `UNNEST` relations (here framed +/// as a mock premium feature check). +#[derive(Debug)] +struct PremiumFeaturePlanner; + +impl RelationPlanner for PremiumFeaturePlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result { + match relation { + TableFactor::UNNEST { .. } => Err(datafusion_common::DataFusionError::Plan( + "UNNEST is a premium feature! Please upgrade to DataFusion Pro™ \ + to unlock advanced array operations." + .to_string(), + )), + other => Ok(RelationPlanning::Original(other)), + } + } +} + +// ============================================================================ +// Test Helpers - SQL Execution +// ============================================================================ + +/// Execute SQL and return results with better error messages. +async fn execute_sql(ctx: &SessionContext, sql: &str) -> Result> { + let df = ctx.sql(sql).await?; + df.collect().await +} + +/// Execute SQL and convert to string format for snapshot comparison. 
+async fn execute_sql_to_string(ctx: &SessionContext, sql: &str) -> String { + let batches = execute_sql(ctx, sql) + .await + .expect("SQL execution should succeed"); + batches_to_string(&batches) +} + +// ============================================================================ +// Test Helpers - Context Builders +// ============================================================================ + +/// Create a SessionContext with a catalog table containing Int64 and Utf8 columns. +/// +/// Creates a table with the specified name and sample data for fallback/integration tests. +fn create_context_with_catalog_table( + table_name: &str, + id_values: Vec, + name_values: Vec<&str>, +) -> SessionContext { + let ctx = SessionContext::new(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(id_values)), + Arc::new(StringArray::from(name_values)), + ], + ) + .unwrap(); + + let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap(); + ctx.register_table(table_name, Arc::new(table)).unwrap(); + + ctx +} + +/// Create a SessionContext with a simple single-column Int64 table. +/// +/// Useful for basic tests that need a real catalog table. +fn create_context_with_simple_table( + table_name: &str, + values: Vec, +) -> SessionContext { + let ctx = SessionContext::new(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int64, + true, + )])); + + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(values))]) + .unwrap(); + + let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap(); + ctx.register_table(table_name, Arc::new(table)).unwrap(); + + ctx +} + +// ============================================================================ +// TESTS: Ordered from Basic to Complex +// ============================================================================ + +/// Comprehensive test suite for RelationPlanner extension point. +/// Tests are ordered from simplest smoke test to most complex scenarios. +#[cfg(test)] +mod tests { + use super::*; + + /// Small extension trait to make test setup read fluently. + trait TestSessionExt { + fn with_planner(self, planner: P) -> Self; + } + + impl TestSessionExt for SessionContext { + fn with_planner(self, planner: P) -> Self { + self.register_relation_planner(Arc::new(planner)).unwrap(); + self + } + } + + /// Session context with only the `NumbersPlanner` registered. + fn ctx_with_numbers() -> SessionContext { + SessionContext::new().with_planner(NumbersPlanner) + } + + /// Session context with virtual tables (`numbers`, `colors`) and the + /// `SamplingJoinPlanner` registered for nested joins. + fn ctx_with_virtual_tables_and_sampling() -> SessionContext { + SessionContext::new() + .with_planner(NumbersPlanner) + .with_planner(ColorsPlanner) + .with_planner(SamplingJoinPlanner) + } + + // Basic smoke test: virtual table can be queried like a regular table. + #[tokio::test] + async fn virtual_table_basic_select() { + let ctx = ctx_with_numbers(); + + let result = execute_sql_to_string(&ctx, "SELECT * FROM numbers").await; + + assert_snapshot!(result, @r" + +--------+ + | number | + +--------+ + | 1 | + | 2 | + | 3 | + +--------+ + "); + } + + // Virtual table supports standard SQL operations (projection, filter, aggregation). 
+ #[tokio::test] + async fn virtual_table_filters_and_aggregation() { + let ctx = ctx_with_numbers(); + + let filtered = execute_sql_to_string( + &ctx, + "SELECT number * 10 AS scaled FROM numbers WHERE number > 1", + ) + .await; + + assert_snapshot!(filtered, @r" + +--------+ + | scaled | + +--------+ + | 20 | + | 30 | + +--------+ + "); + + let aggregated = execute_sql_to_string( + &ctx, + "SELECT COUNT(*) as count, SUM(number) as total, AVG(number) as average \ + FROM numbers", + ) + .await; + + assert_snapshot!(aggregated, @r" + +-------+-------+---------+ + | count | total | average | + +-------+-------+---------+ + | 3 | 6 | 2.0 | + +-------+-------+---------+ + "); + } + + // Multiple planners can coexist and each handles its own virtual table. + #[tokio::test] + async fn multiple_planners_virtual_tables() { + let ctx = SessionContext::new() + .with_planner(NumbersPlanner) + .with_planner(ColorsPlanner); + + let result1 = execute_sql_to_string(&ctx, "SELECT * FROM numbers").await; + assert_snapshot!(result1, @r" + +--------+ + | number | + +--------+ + | 1 | + | 2 | + | 3 | + +--------+ + "); + + let result2 = execute_sql_to_string(&ctx, "SELECT * FROM colors").await; + assert_snapshot!(result2, @r" + +-------+ + | color | + +-------+ + | red | + | green | + | blue | + +-------+ + "); + } + + // Last registered planner for the same table name takes precedence (LIFO). + #[tokio::test] + async fn lifo_precedence_last_planner_wins() { + let ctx = SessionContext::new() + .with_planner(AlternativeNumbersPlanner) + .with_planner(NumbersPlanner); + + let result = execute_sql_to_string(&ctx, "SELECT * FROM numbers").await; + + // CustomValuesPlanner registered last, should win (returns 1,2,3 not 100,200) + assert_snapshot!(result, @r" + +--------+ + | number | + +--------+ + | 1 | + | 2 | + | 3 | + +--------+ + "); + } + + // Pass-through planner delegates to the catalog without changing behavior. + #[tokio::test] + async fn delegation_pass_through_to_catalog() { + let ctx = create_context_with_simple_table("real_table", vec![42]) + .with_planner(PassThroughPlanner); + + let result = execute_sql_to_string(&ctx, "SELECT * FROM real_table").await; + + assert_snapshot!(result, @r" + +-------+ + | value | + +-------+ + | 42 | + +-------+ + "); + } + + // Catalog is used when no planner claims the relation. + #[tokio::test] + async fn catalog_fallback_when_no_planner() { + let ctx = + create_context_with_catalog_table("users", vec![1, 2], vec!["Alice", "Bob"]) + .with_planner(NumbersPlanner); + + let result = execute_sql_to_string(&ctx, "SELECT * FROM users ORDER BY id").await; + + assert_snapshot!(result, @r" + +----+-------+ + | id | name | + +----+-------+ + | 1 | Alice | + | 2 | Bob | + +----+-------+ + "); + } + + // Planners can block specific constructs and surface custom error messages. 
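+    // An Err returned from plan_relation surfaces directly as the planning error,
+    // as the test below demonstrates.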
+ #[tokio::test] + async fn error_handling_premium_feature_blocking() { + // Verify UNNEST works without planner + let ctx_without_planner = SessionContext::new(); + let result = + execute_sql(&ctx_without_planner, "SELECT * FROM UNNEST(ARRAY[1, 2, 3])") + .await + .expect("UNNEST should work by default"); + assert_eq!(result.len(), 1); + + // Same query with blocking planner registered + let ctx = SessionContext::new().with_planner(PremiumFeaturePlanner); + + // Verify UNNEST is now rejected + let error = execute_sql(&ctx, "SELECT * FROM UNNEST(ARRAY[1, 2, 3])") + .await + .expect_err("UNNEST should be rejected"); + + let error_msg = error.to_string(); + assert!( + error_msg.contains("premium feature") && error_msg.contains("DataFusion Pro"), + "Expected custom rejection message, got: {error_msg}" + ); + } + + // SamplingJoinPlanner recursively calls `context.plan()` on both sides of a + // nested join before sampling, exercising recursive relation planning. + #[tokio::test] + async fn recursive_planning_sampling_join() { + let ctx = ctx_with_virtual_tables_and_sampling(); + + let result = + execute_sql_to_string(&ctx, "SELECT * FROM (numbers JOIN colors ON true)") + .await; + + // SamplingJoinPlanner limits each side to 2 rows: 2×2=4 (not 3×3=9) + assert_snapshot!(result, @r" + +--------+-------+ + | number | color | + +--------+-------+ + | 1 | red | + | 1 | green | + | 2 | red | + | 2 | green | + +--------+-------+ + "); + } +} diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 62e8ab18b9be0..e7bd2241398ad 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -23,13 +23,13 @@ use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::mem::{size_of, size_of_val}; use std::sync::{ - atomic::{AtomicBool, Ordering}, Arc, + atomic::{AtomicBool, Ordering}, }; use arrow::array::{ - record_batch, types::UInt64Type, Array, AsArray, Int32Array, PrimitiveArray, - StringArray, StructArray, UInt64Array, + Array, AsArray, Int32Array, PrimitiveArray, StringArray, StructArray, UInt64Array, + record_batch, types::UInt64Type, }; use arrow::datatypes::{Fields, Schema}; use arrow_schema::FieldRef; @@ -56,8 +56,8 @@ use datafusion_common::{cast::as_primitive_array, exec_err}; use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ - col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, Expr, - GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, WindowFunctionDefinition, + AggregateUDFImpl, Expr, GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, + WindowFunctionDefinition, col, create_udaf, function::AccumulatorArgs, }; use datafusion_functions_aggregate::average::AvgAccumulator; @@ -69,7 +69,7 @@ async fn test_setup() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +-------+----------------------------+ | value | time | +-------+----------------------------+ @@ -79,7 +79,7 @@ async fn test_setup() { | 5.0 | 1970-01-01T00:00:00.000005 | | 5.0 | 1970-01-01T00:00:00.000005 | +-------+----------------------------+ - "###); + "); } /// Basic user defined aggregate @@ -91,13 +91,13 @@ async fn test_udaf() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + 
insta::assert_snapshot!(batches_to_string(&actual), @r" +----------------------------+ | time_sum(t.time) | +----------------------------+ | 1970-01-01T00:00:00.000019 | +----------------------------+ - "###); + "); // normal aggregates call update_batch assert!(test_state.update_batch()); @@ -112,7 +112,7 @@ async fn test_udaf_as_window() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +----------------------------+ | time_sum | +----------------------------+ @@ -122,7 +122,7 @@ async fn test_udaf_as_window() { | 1970-01-01T00:00:00.000019 | | 1970-01-01T00:00:00.000019 | +----------------------------+ - "###); + "); // aggregate over the entire window function call update_batch assert!(test_state.update_batch()); @@ -137,7 +137,7 @@ async fn test_udaf_as_window_with_frame() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +----------------------------+ | time_sum | +----------------------------+ @@ -147,7 +147,7 @@ async fn test_udaf_as_window_with_frame() { | 1970-01-01T00:00:00.000014 | | 1970-01-01T00:00:00.000010 | +----------------------------+ - "###); + "); // user defined aggregates with window frame should be calling retract batch assert!(test_state.update_batch()); @@ -164,7 +164,10 @@ async fn test_udaf_as_window_with_frame_without_retract_batch() { let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t"; // Note if this query ever does start working let err = execute(&ctx, sql).await.unwrap_err(); - assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: time_sum(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING"); + assert_contains!( + err.to_string(), + "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: time_sum(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING" + ); } /// Basic query for with a udaf returning a structure @@ -175,13 +178,13 @@ async fn test_udaf_returning_struct() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +------------------------------------------------+ | first(t.value,t.time) | +------------------------------------------------+ | {value: 2.0, time: 1970-01-01T00:00:00.000002} | +------------------------------------------------+ - "###); + "); } /// Demonstrate extracting the fields from a structure using a subquery @@ -192,13 +195,13 @@ async fn test_udaf_returning_struct_subquery() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +-----------------+----------------------------+ | sq.first[value] | sq.first[time] | +-----------------+----------------------------+ | 2.0 | 1970-01-01T00:00:00.000002 | +-----------------+----------------------------+ - "###); + "); } #[tokio::test] @@ -212,13 +215,13 @@ async fn test_udaf_shadows_builtin_fn() { // compute with builtin `sum` aggregator let actual = execute(&ctx, sql).await.unwrap(); - 
insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r#" +---------------------------------------+ | sum(arrow_cast(t.time,Utf8("Int64"))) | +---------------------------------------+ | 19000 | +---------------------------------------+ - "###); + "#); // Register `TimeSum` with name `sum`. This will shadow the builtin one TimeSum::register(&mut ctx, test_state.clone(), "sum"); @@ -226,13 +229,13 @@ async fn test_udaf_shadows_builtin_fn() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +----------------------------+ | sum(t.time) | +----------------------------+ | 1970-01-01T00:00:00.000019 | +----------------------------+ - "###); + "); } async fn execute(ctx: &SessionContext, sql: &str) -> Result> { @@ -272,13 +275,13 @@ async fn simple_udaf() -> Result<()> { let result = ctx.sql("SELECT MY_AVG(a) FROM t").await?.collect().await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +-------------+ | my_avg(t.a) | +-------------+ | 3.0 | +-------------+ - "###); + "); Ok(()) } @@ -329,9 +332,10 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { // doesn't work as it was registered as non lowercase let err = ctx.sql("SELECT MY_AVG(i) FROM t").await.unwrap_err(); - assert!(err - .to_string() - .contains("Error during planning: Invalid function \'my_avg\'")); + assert!( + err.to_string() + .contains("Error during planning: Invalid function \'my_avg\'") + ); // Can call it if you put quotes let result = ctx @@ -340,13 +344,13 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +-------------+ | MY_AVG(t.i) | +-------------+ | 1.0 | +-------------+ - "###); + "); Ok(()) } @@ -372,13 +376,13 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +------------+ | dummy(t.i) | +------------+ | 1.0 | +------------+ - "###); + "); let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?; @@ -449,13 +453,13 @@ async fn test_parameterized_aggregate_udf() -> Result<()> { let actual = DataFrame::new(ctx.state(), plan).collect().await?; - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +------+---+---+ | text | a | b | +------+---+---+ | foo | 1 | 2 | +------+---+---+ - "###); + "); ctx.deregister_table("t")?; Ok(()) @@ -569,6 +573,7 @@ impl TimeSum { Self { sum: 0, test_state } } + #[expect(clippy::needless_pass_by_value)] fn register(ctx: &mut SessionContext, test_state: Arc, name: &str) { let timestamp_type = DataType::Timestamp(TimeUnit::Nanosecond, None); let input_type = vec![timestamp_type.clone()]; @@ -760,11 +765,11 @@ impl Accumulator for FirstSelector { // Update the actual values for (value, time) in v.iter().zip(t.iter()) { - if let (Some(time), Some(value)) = (time, value) { - if time < self.time { - self.value = value; - self.time = time; - } + if let (Some(time), Some(value)) = (time, value) + && time < self.time + { + self.value = value; + self.time = time; } 
} diff --git a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs new file mode 100644 index 0000000000000..168d81fc6b44c --- /dev/null +++ b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{Int32Array, RecordBatch, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use async_trait::async_trait; +use datafusion::prelude::*; +use datafusion_common::{Result, assert_batches_eq}; +use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +// This test checks the case where batch_size doesn't evenly divide +// the number of rows. +#[tokio::test] +async fn test_async_udf_with_non_modular_batch_size() -> Result<()> { + let num_rows = 3; + let batch_size = 2; + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("prompt", DataType::Utf8, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from((0..num_rows).collect::>())), + Arc::new(StringArray::from( + (0..num_rows) + .map(|i| format!("prompt{i}")) + .collect::>(), + )), + ], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("test_table", batch)?; + + ctx.register_udf( + AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(batch_size))) + .into_scalar_udf(), + ); + + let df = ctx + .sql("SELECT id, test_async_udf(prompt) as result FROM test_table") + .await?; + + let result = df.collect().await?; + + assert_batches_eq!( + &[ + "+----+---------+", + "| id | result |", + "+----+---------+", + "| 0 | prompt0 |", + "| 1 | prompt1 |", + "| 2 | prompt2 |", + "+----+---------+" + ], + &result + ); + + Ok(()) +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +struct TestAsyncUDFImpl { + batch_size: usize, + signature: Signature, +} + +impl TestAsyncUDFImpl { + fn new(batch_size: usize) -> Self { + Self { + batch_size, + signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile), + } + } +} + +impl ScalarUDFImpl for TestAsyncUDFImpl { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "test_async_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + panic!("Call invoke_async_with_args instead") + } +} + +#[async_trait] +impl AsyncScalarUDFImpl for TestAsyncUDFImpl { + fn ideal_batch_size(&self) -> Option { + 
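+        // Re-chunking input to this size means that, with three input rows and a
+        // batch size of two, the UDF is expected to see one full chunk and one
+        // partial chunk — the non-modular case this test targets.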
Some(self.batch_size) + } + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result { + let arg1 = &args.args[0]; + let results = call_external_service(arg1.clone()).await?; + Ok(results) + } +} + +/// Simulates calling an async external service +async fn call_external_service(arg1: ColumnarValue) -> Result { + Ok(arg1) +} diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index ffe0ba021edb3..d53e076739608 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -70,7 +70,7 @@ use arrow::{ use datafusion::execution::session_state::SessionStateBuilder; use datafusion::{ common::cast::as_int64_array, - common::{arrow_datafusion_err, internal_err, DFSchemaRef}, + common::{DFSchemaRef, arrow_datafusion_err}, error::{DataFusionError, Result}, execution::{ context::{QueryPlanner, SessionState, TaskContext}, @@ -91,10 +91,10 @@ use datafusion::{ }; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::ScalarValue; +use datafusion_common::{ScalarValue, assert_eq_or_internal_err, assert_or_internal_err}; use datafusion_expr::{FetchType, InvariantLevel, Projection, SortExpr}; -use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; +use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use async_trait::async_trait; @@ -161,7 +161,7 @@ async fn run_and_compare_query(ctx: SessionContext, description: &str) -> Result insta::with_settings!({ description => description, }, { - insta::assert_snapshot!(actual, @r###" + insta::assert_snapshot!(actual, @r" +-------------+---------+ | customer_id | revenue | +-------------+---------+ @@ -169,7 +169,7 @@ async fn run_and_compare_query(ctx: SessionContext, description: &str) -> Result | jorge | 200 | | andy | 150 | +-------------+---------+ - "###); + "); }); } @@ -188,13 +188,13 @@ async fn run_and_compare_query_with_analyzer_rule( insta::with_settings!({ description => description, }, { - insta::assert_snapshot!(actual, @r###" + insta::assert_snapshot!(actual, @r" +------------+--------------------------+ | UInt64(42) | arrow_typeof(UInt64(42)) | +------------+--------------------------+ | 42 | UInt64 | +------------+--------------------------+ - "###); + "); }); Ok(()) @@ -212,7 +212,7 @@ async fn run_and_compare_query_with_auto_schemas( insta::with_settings!({ description => description, }, { - insta::assert_snapshot!(actual, @r###" + insta::assert_snapshot!(actual, @r" +----------+----------+ | column_1 | column_2 | +----------+----------+ @@ -220,7 +220,7 @@ async fn run_and_compare_query_with_auto_schemas( | jorge | 200 | | andy | 150 | +----------+----------+ - "###); + "); }); Ok(()) @@ -433,21 +433,21 @@ impl OptimizerRule for OptimizerMakeExtensionNodeInvalid { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result, DataFusionError> { - if let LogicalPlan::Extension(Extension { node }) = &plan { - if let Some(prev) = node.as_any().downcast_ref::() { - return Ok(Transformed::yes(LogicalPlan::Extension(Extension { - node: Arc::new(TopKPlanNode { - k: prev.k, - input: prev.input.clone(), - expr: prev.expr.clone(), - // In a real use case, this rewriter could have change the number of inputs, etc - invariant_mock: Some(InvariantMock { - should_fail_invariant: true, - kind: 
InvariantLevel::Always, - }), + if let LogicalPlan::Extension(Extension { node }) = &plan + && let Some(prev) = node.as_any().downcast_ref::() + { + return Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(TopKPlanNode { + k: prev.k, + input: prev.input.clone(), + expr: prev.expr.clone(), + // In a real use case, this rewriter could have change the number of inputs, etc + invariant_mock: Some(InvariantMock { + should_fail_invariant: true, + kind: InvariantLevel::Always, }), - }))); - } + }), + }))); }; Ok(Transformed::no(plan)) @@ -515,23 +515,18 @@ impl OptimizerRule for TopKOptimizerRule { return Ok(Transformed::no(plan)); }; - if let LogicalPlan::Sort(Sort { - ref expr, - ref input, - .. - }) = limit.input.as_ref() + if let LogicalPlan::Sort(Sort { expr, input, .. }) = limit.input.as_ref() + && expr.len() == 1 { - if expr.len() == 1 { - // we found a sort with a single sort expr, replace with a a TopK - return Ok(Transformed::yes(LogicalPlan::Extension(Extension { - node: Arc::new(TopKPlanNode { - k: fetch, - input: input.as_ref().clone(), - expr: expr[0].clone(), - invariant_mock: self.invariant_mock.clone(), - }), - }))); - } + // we found a sort with a single sort expr, replace with a a TopK + return Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(TopKPlanNode { + k: fetch, + input: input.as_ref().clone(), + expr: expr[0].clone(), + invariant_mock: self.invariant_mock.clone(), + }), + }))); } Ok(Transformed::no(plan)) @@ -585,9 +580,10 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { kind, }) = self.invariant_mock.clone() { - if should_fail_invariant && check == kind { - return internal_err!("node fails check, such as improper inputs"); - } + assert_or_internal_err!( + !(should_fail_invariant && check == kind), + "node fails check, such as improper inputs" + ); } Ok(()) } @@ -733,9 +729,11 @@ impl ExecutionPlan for TopKExec { partition: usize, context: Arc, ) -> Result { - if 0 != partition { - return internal_err!("TopKExec invalid partition {partition}"); - } + assert_eq_or_internal_err!( + partition, + 0, + "TopKExec invalid partition {partition}" + ); Ok(Box::pin(TopKReader { input: self.input.execute(partition, context)?, diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 3ca8f846aa5e5..b86cd94a8a9b7 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -20,11 +20,11 @@ use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use arrow::array::{as_string_array, create_array, record_batch, Int8Array, UInt64Array}; use arrow::array::{ - builder::BooleanBuilder, cast::AsArray, Array, ArrayRef, Float32Array, Float64Array, - Int32Array, RecordBatch, StringArray, + Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, + builder::BooleanBuilder, cast::AsArray, }; +use arrow::array::{Int8Array, UInt64Array, as_string_array, create_array, record_batch}; use arrow::compute::kernels::numeric::add; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::extension::{Bool8, CanonicalExtensionType, ExtensionType}; @@ -38,15 +38,17 @@ use datafusion_common::metadata::FieldMetadata; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::utils::take_function_args; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, 
assert_contains, exec_datafusion_err, - exec_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, + DFSchema, DataFusionError, Result, ScalarValue, assert_batches_eq, + assert_batches_sorted_eq, assert_contains, exec_datafusion_err, exec_err, + not_impl_err, plan_err, }; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - lit_with_metadata, Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, - LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, - ScalarUDF, ScalarUDFImpl, Signature, Volatility, + Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, + OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, lit_with_metadata, }; +use datafusion_expr_common::signature::TypeSignature; use datafusion_functions_nested::range::range_udf; use parking_lot::Mutex; use regex::Regex; @@ -63,13 +65,13 @@ async fn csv_query_custom_udf_with_cast() -> Result<()> { let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100"; let actual = plan_and_collect(&ctx, sql).await?; - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +------------------------------------------+ | avg(custom_sqrt(aggregate_test_100.c11)) | +------------------------------------------+ | 0.6584408483418835 | +------------------------------------------+ - "###); + "); Ok(()) } @@ -82,13 +84,13 @@ async fn csv_query_avg_sqrt() -> Result<()> { let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; let actual = plan_and_collect(&ctx, sql).await?; - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +------------------------------------------+ | avg(custom_sqrt(aggregate_test_100.c12)) | +------------------------------------------+ | 0.6706002946036459 | +------------------------------------------+ - "###); + "); Ok(()) } @@ -153,7 +155,7 @@ async fn scalar_udf() -> Result<()> { let result = DataFrame::new(ctx.state(), plan).collect().await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +-----+-----+-----------------+ | a | b | my_add(t.a,t.b) | +-----+-----+-----------------+ @@ -162,7 +164,7 @@ async fn scalar_udf() -> Result<()> { | 10 | 12 | 22 | | 100 | 120 | 220 | +-----+-----+-----------------+ - "###); + "); let batch = &result[0]; let a = as_int32_array(batch.column(0))?; @@ -279,7 +281,7 @@ async fn scalar_udf_zero_params() -> Result<()> { ctx.register_udf(ScalarUDF::from(get_100_udf)); let result = plan_and_collect(&ctx, "select get_100() a from t").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +-----+ | a | +-----+ @@ -288,22 +290,22 @@ async fn scalar_udf_zero_params() -> Result<()> { | 100 | | 100 | +-----+ - "###); + "); let result = plan_and_collect(&ctx, "select get_100() a").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +-----+ | a | +-----+ | 100 | +-----+ - "###); + "); let result = plan_and_collect(&ctx, "select get_100() from t where a=999").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" ++ ++ - "###); + "); Ok(()) } @@ -330,13 +332,13 @@ async fn 
scalar_udf_override_built_in_scalar_function() -> Result<()> { // Make sure that the UDF is used instead of the built-in function let result = plan_and_collect(&ctx, "select abs(a) a from t").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +---+ | a | +---+ | 1 | +---+ - "###); + "); Ok(()) } @@ -425,20 +427,21 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t") .await .unwrap_err(); - assert!(err - .to_string() - .contains("Error during planning: Invalid function \'my_func\'")); + assert!( + err.to_string() + .contains("Error during planning: Invalid function \'my_func\'") + ); // Can call it if you put quotes let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +--------------+ | MY_FUNC(t.i) | +--------------+ | 1 | +--------------+ - "###); + "); Ok(()) } @@ -469,13 +472,13 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { ctx.register_udf(udf); let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +------------+ | dummy(t.i) | +------------+ | 1 | +------------+ - "###); + "); let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?; insta::assert_snapshot!(batches_to_string(&alias_result), @r" @@ -945,6 +948,7 @@ struct ScalarFunctionWrapper { expr: Expr, signature: Signature, return_type: DataType, + defaults: Vec>, } impl ScalarUDFImpl for ScalarFunctionWrapper { @@ -973,7 +977,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { args: Vec, _info: &dyn SimplifyInfo, ) -> Result { - let replacement = Self::replacement(&self.expr, &args)?; + let replacement = Self::replacement(&self.expr, &args, &self.defaults)?; Ok(ExprSimplifyResult::Simplified(replacement)) } @@ -981,7 +985,11 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { impl ScalarFunctionWrapper { // replaces placeholders with actual arguments - fn replacement(expr: &Expr, args: &[Expr]) -> Result { + fn replacement( + expr: &Expr, + args: &[Expr], + defaults: &[Option], + ) -> Result { let result = expr.clone().transform(|e| { let r = match e { Expr::Placeholder(placeholder) => { @@ -989,11 +997,19 @@ impl ScalarFunctionWrapper { Self::parse_placeholder_identifier(&placeholder.id)?; if placeholder_position < args.len() { Transformed::yes(args[placeholder_position].clone()) - } else { + } else if placeholder_position >= defaults.len() { exec_err!( - "Function argument {} not provided, argument missing!", + "Invalid placeholder, out of range: {}", placeholder.id )? 
+ } else { + match defaults[placeholder_position] { + Some(ref default) => Transformed::yes(default.clone()), + None => exec_err!( + "Function argument {} not provided, argument missing!", + placeholder.id + )?, + } } } _ => Transformed::no(e), @@ -1021,6 +1037,32 @@ impl TryFrom for ScalarFunctionWrapper { type Error = DataFusionError; fn try_from(definition: CreateFunction) -> std::result::Result { + let args = definition.args.unwrap_or_default(); + let defaults: Vec> = + args.iter().map(|a| a.default_expr.clone()).collect(); + let signature: Signature = match defaults.iter().position(|v| v.is_some()) { + Some(pos) => { + let mut type_signatures: Vec = vec![]; + // Generate all valid signatures + for n in pos..defaults.len() + 1 { + if n == 0 { + type_signatures.push(TypeSignature::Nullary) + } else { + type_signatures.push(TypeSignature::Exact( + args.iter().take(n).map(|a| a.data_type.clone()).collect(), + )) + } + } + Signature::one_of( + type_signatures, + definition.params.behavior.unwrap_or(Volatility::Volatile), + ) + } + None => Signature::exact( + args.iter().map(|a| a.data_type.clone()).collect(), + definition.params.behavior.unwrap_or(Volatility::Volatile), + ), + }; Ok(Self { name: definition.name, expr: definition @@ -1030,15 +1072,8 @@ impl TryFrom for ScalarFunctionWrapper { return_type: definition .return_type .expect("Return type has to be defined!"), - signature: Signature::exact( - definition - .args - .unwrap_or_default() - .into_iter() - .map(|a| a.data_type) - .collect(), - definition.params.behavior.unwrap_or(Volatility::Volatile), - ), + signature, + defaults, }) } } @@ -1061,10 +1096,11 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> { // Create the `better_add` function dynamically via CREATE FUNCTION statement assert!(ctx.sql(sql).await.is_ok()); // try to `drop function` when sql options have allow ddl disabled - assert!(ctx - .sql_with_options("drop function better_add", options) - .await - .is_err()); + assert!( + ctx.sql_with_options("drop function better_add", options) + .await + .is_err() + ); let result = ctx .sql("select better_add(2.0, 2.0)") @@ -1109,6 +1145,180 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> { "#; assert!(ctx.sql(bad_definition_sql).await.is_err()); + // FIXME: Definitions with invalid placeholders are allowed, fail at runtime + let bad_expression_sql = r#" + CREATE FUNCTION better_add(DOUBLE, DOUBLE) + RETURNS DOUBLE + RETURN $1 + $3 + "#; + assert!(ctx.sql(bad_expression_sql).await.is_ok()); + + let err = ctx + .sql("select better_add(2.0, 2.0)") + .await? + .collect() + .await + .expect_err("unknown placeholder"); + let expected = "Optimizer rule 'simplify_expressions' failed\ncaused by\nExecution error: Invalid placeholder, out of range: $3"; + assert!(expected.starts_with(&err.strip_backtrace())); + + Ok(()) +} + +#[tokio::test] +async fn create_scalar_function_from_sql_statement_named_arguments() -> Result<()> { + let function_factory = Arc::new(CustomFunctionFactory::default()); + let ctx = SessionContext::new().with_function_factory(function_factory.clone()); + + let sql = r#" + CREATE FUNCTION better_add(a DOUBLE, b DOUBLE) + RETURNS DOUBLE + RETURN $a + $b + "#; + + assert!(ctx.sql(sql).await.is_ok()); + + let result = ctx + .sql("select better_add(2.0, 2.0)") + .await? 
+ .collect() + .await?; + + assert_batches_eq!( + &[ + "+-----------------------------------+", + "| better_add(Float64(2),Float64(2)) |", + "+-----------------------------------+", + "| 4.0 |", + "+-----------------------------------+", + ], + &result + ); + + // cannot mix named and positional style + let bad_expression_sql = r#" + CREATE FUNCTION bad_expression_fun(DOUBLE, b DOUBLE) + RETURNS DOUBLE + RETURN $1 + $b + "#; + let err = ctx + .sql(bad_expression_sql) + .await + .expect_err("cannot mix named and positional style"); + let expected = "Error during planning: All function arguments must use either named or positional style."; + assert!(expected.starts_with(&err.strip_backtrace())); + + Ok(()) +} + +#[tokio::test] +async fn create_scalar_function_from_sql_statement_default_arguments() -> Result<()> { + let function_factory = Arc::new(CustomFunctionFactory::default()); + let ctx = SessionContext::new().with_function_factory(function_factory.clone()); + + let sql = r#" + CREATE FUNCTION better_add(a DOUBLE = 2.0, b DOUBLE = 2.0) + RETURNS DOUBLE + RETURN $a + $b + "#; + + assert!(ctx.sql(sql).await.is_ok()); + + // Check all function arity supported + let result = ctx.sql("select better_add()").await?.collect().await?; + + assert_batches_eq!( + &[ + "+--------------+", + "| better_add() |", + "+--------------+", + "| 4.0 |", + "+--------------+", + ], + &result + ); + + let result = ctx.sql("select better_add(2.0)").await?.collect().await?; + + assert_batches_eq!( + &[ + "+------------------------+", + "| better_add(Float64(2)) |", + "+------------------------+", + "| 4.0 |", + "+------------------------+", + ], + &result + ); + + let result = ctx + .sql("select better_add(2.0, 2.0)") + .await? + .collect() + .await?; + + assert_batches_eq!( + &[ + "+-----------------------------------+", + "| better_add(Float64(2),Float64(2)) |", + "+-----------------------------------+", + "| 4.0 |", + "+-----------------------------------+", + ], + &result + ); + + assert!(ctx.sql("select better_add(2.0, 2.0, 2.0)").await.is_err()); + assert!(ctx.sql("drop function better_add").await.is_ok()); + + // works with positional style + let sql = r#" + CREATE FUNCTION better_add(DOUBLE, DOUBLE = 2.0) + RETURNS DOUBLE + RETURN $1 + $2 + "#; + assert!(ctx.sql(sql).await.is_ok()); + + assert!(ctx.sql("select better_add()").await.is_err()); + let result = ctx.sql("select better_add(2.0)").await?.collect().await?; + assert_batches_eq!( + &[ + "+------------------------+", + "| better_add(Float64(2)) |", + "+------------------------+", + "| 4.0 |", + "+------------------------+", + ], + &result + ); + + // non-default argument cannot follow default argument + let bad_expression_sql = r#" + CREATE FUNCTION bad_expression_fun(a DOUBLE = 2.0, b DOUBLE) + RETURNS DOUBLE + RETURN $a + $b + "#; + let err = ctx + .sql(bad_expression_sql) + .await + .expect_err("non-default argument cannot follow default argument"); + let expected = + "Error during planning: Non-default arguments cannot follow default arguments."; + assert!(expected.starts_with(&err.strip_backtrace())); + + // FIXME: The `DEFAULT` syntax does not work with positional params + let bad_expression_sql = r#" + CREATE FUNCTION bad_expression_fun(DOUBLE, DOUBLE DEFAULT 2.0) + RETURNS DOUBLE + RETURN $1 + $2 + "#; + let err = ctx + .sql(bad_expression_sql) + .await + .expect_err("sqlparser error"); + let expected = + "SQL error: ParserError(\"Expected: ), found: 2.0 at Line: 2, Column: 63\")"; + assert!(expected.starts_with(&err.strip_backtrace())); 
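+    // In both the named and positional cases above, the accepted arities follow from
+    // ScalarFunctionWrapper::try_from, which expands default arguments into a
+    // Signature::one_of covering every argument count from the first defaulted
+    // position up to the full argument list.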
Ok(()) } diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index 2c6611f382cea..8be8609c62480 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -21,17 +21,17 @@ use std::path::Path; use std::sync::Arc; use arrow::array::Int64Array; -use arrow::csv::reader::Format; use arrow::csv::ReaderBuilder; +use arrow::csv::reader::Format; use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::test_util::batches_to_string; -use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::TableProvider; +use datafusion::datasource::memory::MemorySourceConfig; use datafusion::error::Result; use datafusion::execution::TaskContext; -use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::physical_plan::{ExecutionPlan, collect}; use datafusion::prelude::SessionContext; use datafusion_catalog::Session; use datafusion_catalog::TableFunctionImpl; @@ -55,7 +55,7 @@ async fn test_simple_read_csv_udtf() -> Result<()> { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&rbs), @r###" + insta::assert_snapshot!(batches_to_string(&rbs), @r" +-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+ | n_nationkey | n_name | n_regionkey | n_comment | +-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+ @@ -65,7 +65,7 @@ async fn test_simple_read_csv_udtf() -> Result<()> { | 4 | EGYPT | 4 | y above the carefully unusual theodolites. final dugouts are quickly across the furiously regular d | | 5 | ETHIOPIA | 0 | ven packages wake quickly. regu | +-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+ - "###); + "); // just run, return all rows let rbs = ctx @@ -74,7 +74,7 @@ async fn test_simple_read_csv_udtf() -> Result<()> { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&rbs), @r###" + insta::assert_snapshot!(batches_to_string(&rbs), @r" +-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+ | n_nationkey | n_name | n_regionkey | n_comment | +-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+ @@ -89,7 +89,7 @@ async fn test_simple_read_csv_udtf() -> Result<()> { | 9 | INDONESIA | 2 | slyly express asymptotes. regular deposits haggle slyly. carefully ironic hockey players sleep blithely. carefull | | 10 | IRAN | 4 | efully alongside of the slyly final dependencies. 
| +-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+ - "###); + "); Ok(()) } @@ -205,7 +205,7 @@ impl TableFunctionImpl for SimpleCsvTableFunc { let mut filepath = String::new(); for expr in exprs { match expr { - Expr::Literal(ScalarValue::Utf8(Some(ref path)), _) => { + Expr::Literal(ScalarValue::Utf8(Some(path)), _) => { filepath.clone_from(path); } expr => new_exprs.push(expr.clone()), diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 33607ebc0d2cc..57baf271c5913 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -19,8 +19,8 @@ //! user defined window functions use arrow::array::{ - record_batch, Array, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray, - UInt64Array, + Array, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray, UInt64Array, + record_batch, }; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::FieldRef; @@ -38,8 +38,8 @@ use datafusion_functions_window_common::{ expr::ExpressionArgs, field::WindowUDFFieldArgs, }; use datafusion_physical_expr::{ - expressions::{col, lit}, PhysicalExpr, + expressions::{col, lit}, }; use std::collections::HashMap; use std::hash::{Hash, Hasher}; @@ -47,8 +47,8 @@ use std::{ any::Any, ops::Range, sync::{ - atomic::{AtomicUsize, Ordering}, Arc, + atomic::{AtomicUsize, Ordering}, }, }; @@ -62,8 +62,7 @@ const UNBOUNDED_WINDOW_QUERY_WITH_ALIAS: &str = "SELECT x, y, val, \ from t ORDER BY x, y"; /// A query with a window function evaluated over a moving window -const BOUNDED_WINDOW_QUERY: &str = - "SELECT x, y, val, \ +const BOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \ odd_counter(val) OVER (PARTITION BY x ORDER BY y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) \ from t ORDER BY x, y"; @@ -75,22 +74,22 @@ async fn test_setup() { let sql = "SELECT * from t order by x, y"; let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+ - | x | y | val | - +---+---+-----+ - | 1 | a | 0 | - | 1 | b | 1 | - | 1 | c | 2 | - | 2 | d | 3 | - | 2 | e | 4 | - | 2 | f | 5 | - | 2 | g | 6 | - | 2 | h | 6 | - | 2 | i | 6 | - | 2 | j | 6 | - +---+---+-----+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+ + | x | y | val | + +---+---+-----+ + | 1 | a | 0 | + | 1 | b | 1 | + | 1 | c | 2 | + | 2 | d | 3 | + | 2 | e | 4 | + | 2 | f | 5 | + | 2 | g | 6 | + | 2 | h | 6 | + | 2 | i | 6 | + | 2 | j | 6 | + +---+---+-----+ + "); } /// Basic user defined window function @@ -101,22 +100,22 @@ async fn test_udwf() { let actual = execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 2 | - | 2 | e | 4 | 2 | - | 2 | f | 5 | 2 | - | 2 | g | 6 | 2 | - | 2 | h | 6 | 2 | - | 2 | i | 6 | 2 | - | 
2 | j | 6 | 2 | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 1 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 2 | + | 2 | e | 4 | 2 | + | 2 | f | 5 | 2 | + | 2 | g | 6 | 2 | + | 2 | h | 6 | 2 | + | 2 | i | 6 | 2 | + | 2 | j | 6 | 2 | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + "); // evaluated on two distinct batches assert_eq!(test_state.evaluate_all_called(), 2); @@ -175,22 +174,22 @@ async fn test_udwf_bounded_window_ignores_frame() { // Since the UDWF doesn't say it needs the window frame, the frame is ignored let actual = execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 2 | - | 2 | e | 4 | 2 | - | 2 | f | 5 | 2 | - | 2 | g | 6 | 2 | - | 2 | h | 6 | 2 | - | 2 | i | 6 | 2 | - | 2 | j | 6 | 2 | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 1 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 2 | + | 2 | e | 4 | 2 | + | 2 | f | 5 | 2 | + | 2 | g | 6 | 2 | + | 2 | h | 6 | 2 | + | 2 | i | 6 | 2 | + | 2 | j | 6 | 2 | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + "); // evaluated on 2 distinct batches (when x=1 and x=2) assert_eq!(test_state.evaluate_called(), 0); @@ -205,22 +204,22 @@ async fn test_udwf_bounded_window() { let actual = execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1 | - | 1 | b | 1 
| 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 1 | - | 2 | e | 4 | 2 | - | 2 | f | 5 | 1 | - | 2 | g | 6 | 1 | - | 2 | h | 6 | 0 | - | 2 | i | 6 | 0 | - | 2 | j | 6 | 0 | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 1 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 1 | + | 2 | e | 4 | 2 | + | 2 | f | 5 | 1 | + | 2 | g | 6 | 1 | + | 2 | h | 6 | 0 | + | 2 | i | 6 | 0 | + | 2 | j | 6 | 0 | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + "); // Evaluate is called for each input rows assert_eq!(test_state.evaluate_called(), 10); @@ -237,22 +236,22 @@ async fn test_stateful_udwf() { let actual = execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 0 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 1 | - | 2 | e | 4 | 1 | - | 2 | f | 5 | 2 | - | 2 | g | 6 | 2 | - | 2 | h | 6 | 2 | - | 2 | i | 6 | 2 | - | 2 | j | 6 | 2 | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 0 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 1 | + | 2 | e | 4 | 1 | + | 2 | f | 5 | 2 | + | 2 | g | 6 | 2 | + | 2 | h | 6 | 2 | + | 2 | i | 6 | 2 | + | 2 | j | 6 | 2 | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + "); assert_eq!(test_state.evaluate_called(), 10); assert_eq!(test_state.evaluate_all_called(), 0); @@ -268,22 +267,22 @@ async fn test_stateful_udwf_bounded_window() { let actual = execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | - 
+---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 1 | - | 2 | e | 4 | 2 | - | 2 | f | 5 | 1 | - | 2 | g | 6 | 1 | - | 2 | h | 6 | 0 | - | 2 | i | 6 | 0 | - | 2 | j | 6 | 0 | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 1 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 1 | + | 2 | e | 4 | 2 | + | 2 | f | 5 | 1 | + | 2 | g | 6 | 1 | + | 2 | h | 6 | 0 | + | 2 | i | 6 | 0 | + | 2 | j | 6 | 0 | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + "); // Evaluate and update_state is called for each input row assert_eq!(test_state.evaluate_called(), 10); @@ -298,22 +297,22 @@ async fn test_udwf_query_include_rank() { let actual = execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 3 | - | 1 | b | 1 | 2 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 7 | - | 2 | e | 4 | 6 | - | 2 | f | 5 | 5 | - | 2 | g | 6 | 4 | - | 2 | h | 6 | 3 | - | 2 | i | 6 | 2 | - | 2 | j | 6 | 1 | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 3 | + | 1 | b | 1 | 2 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 7 | + | 2 | e | 4 | 6 | + | 2 | f | 5 | 5 | + | 2 | g | 6 | 4 | + | 2 | h | 6 | 3 | + | 2 | i | 6 | 2 | + | 2 | j | 6 | 1 | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + "); assert_eq!(test_state.evaluate_called(), 0); assert_eq!(test_state.evaluate_all_called(), 0); @@ -329,22 +328,22 @@ async fn test_udwf_bounded_query_include_rank() { let actual = execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | x | y | val 
| odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 3 | - | 1 | b | 1 | 2 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 7 | - | 2 | e | 4 | 6 | - | 2 | f | 5 | 5 | - | 2 | g | 6 | 4 | - | 2 | h | 6 | 3 | - | 2 | i | 6 | 2 | - | 2 | j | 6 | 1 | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 3 | + | 1 | b | 1 | 2 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 7 | + | 2 | e | 4 | 6 | + | 2 | f | 5 | 5 | + | 2 | g | 6 | 4 | + | 2 | h | 6 | 3 | + | 2 | i | 6 | 2 | + | 2 | j | 6 | 1 | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + "); assert_eq!(test_state.evaluate_called(), 0); assert_eq!(test_state.evaluate_all_called(), 0); @@ -362,22 +361,22 @@ async fn test_udwf_bounded_window_returns_null() { let actual = execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 1 | - | 2 | e | 4 | 2 | - | 2 | f | 5 | 1 | - | 2 | g | 6 | 1 | - | 2 | h | 6 | | - | 2 | i | 6 | | - | 2 | j | 6 | | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 1 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 1 | + | 2 | e | 4 | 2 | + | 2 | f | 5 | 1 | + | 2 | g | 6 | 1 | + | 2 | h | 6 | | + | 2 | i | 6 | | + | 2 | j | 6 | | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + "); // Evaluate is called for each input rows assert_eq!(test_state.evaluate_called(), 10); @@ -616,7 +615,9 @@ impl PartitionEvaluator for OddCounter { ranks_in_partition: &[Range], ) -> Result { self.test_state.inc_evaluate_all_with_rank_called(); - println!("evaluate_all_with_rank, values: {num_rows:#?}, ranks_in_partitions: {ranks_in_partition:?}"); + println!( + "evaluate_all_with_rank, values: 
{num_rows:#?}, ranks_in_partitions: {ranks_in_partition:?}" + ); // when evaluating with ranks, just return the inverse rank instead let array: Int64Array = ranks_in_partition .iter() diff --git a/datafusion/datasource-arrow/src/file_format.rs b/datafusion/datasource-arrow/src/file_format.rs index 3b85640804219..9997d23d4c61f 100644 --- a/datafusion/datasource-arrow/src/file_format.rs +++ b/datafusion/datasource-arrow/src/file_format.rs @@ -20,30 +20,31 @@ //! Works with files following the [Arrow IPC format](https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format) use std::any::Any; -use std::borrow::Cow; use std::collections::HashMap; use std::fmt::{self, Debug}; +use std::io::{Seek, SeekFrom}; use std::sync::Arc; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::ArrowError; use arrow::ipc::convert::fb_to_schema; -use arrow::ipc::reader::FileReader; +use arrow::ipc::reader::{FileReader, StreamReader}; use arrow::ipc::writer::IpcWriteOptions; -use arrow::ipc::{root_as_message, CompressionType}; +use arrow::ipc::{CompressionType, root_as_message}; use datafusion_common::error::Result; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - internal_datafusion_err, not_impl_err, DataFusionError, GetExt, Statistics, - DEFAULT_ARROW_EXTENSION, + DEFAULT_ARROW_EXTENSION, DataFusionError, GetExt, Statistics, + internal_datafusion_err, not_impl_err, }; use datafusion_common_runtime::{JoinSet, SpawnedTask}; +use datafusion_datasource::TableSchema; use datafusion_datasource::display::FileGroupDisplay; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; use datafusion_datasource::sink::{DataSink, DataSinkExec}; use datafusion_datasource::write::{ - get_writer_schema, ObjectWriterBuilder, SharedBuffer, + ObjectWriterBuilder, SharedBuffer, get_writer_schema, }; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; @@ -59,9 +60,11 @@ use datafusion_datasource::source::DataSourceExec; use datafusion_datasource::write::demux::DemuxedStreamReceiver; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_session::Session; -use futures::stream::BoxStream; use futures::StreamExt; -use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; +use futures::stream::BoxStream; +use object_store::{ + GetOptions, GetRange, GetResultPayload, ObjectMeta, ObjectStore, path::Path, +}; use tokio::io::AsyncWriteExt; /// Initial writing buffer size. Note this is just a size hint for efficiency. It @@ -71,8 +74,8 @@ const INITIAL_BUFFER_BYTES: usize = 1048576; /// If the buffered Arrow data exceeds this size, it is flushed to object store const BUFFER_FLUSH_BYTES: usize = 1024000; +/// Factory struct used to create [`ArrowFormat`] #[derive(Default, Debug)] -/// Factory struct used to create [ArrowFormat] pub struct ArrowFormatFactory; impl ArrowFormatFactory { @@ -107,7 +110,7 @@ impl GetExt for ArrowFormatFactory { } } -/// Arrow `FileFormat` implementation. +/// Arrow [`FileFormat`] implementation. #[derive(Default, Debug)] pub struct ArrowFormat; @@ -150,12 +153,25 @@ impl FileFormat for ArrowFormat { let schema = match r.payload { #[cfg(not(target_arch = "wasm32"))] GetResultPayload::File(mut file, _) => { - let reader = FileReader::try_new(&mut file, None)?; - reader.schema() - } - GetResultPayload::Stream(stream) => { - infer_schema_from_file_stream(stream).await? 
+ match FileReader::try_new(&mut file, None) { + Ok(reader) => reader.schema(), + Err(file_error) => { + // not in the file format, but FileReader read some bytes + // while trying to parse the file and so we need to rewind + // it to the beginning of the file + file.seek(SeekFrom::Start(0))?; + match StreamReader::try_new(&mut file, None) { + Ok(reader) => reader.schema(), + Err(stream_error) => { + return Err(internal_datafusion_err!( + "Failed to parse Arrow file as either file format or stream format. File format error: {file_error}. Stream format error: {stream_error}" + )); + } + } + } + } } + GetResultPayload::Stream(stream) => infer_stream_schema(stream).await?, }; schemas.push(schema.as_ref().clone()); } @@ -175,10 +191,40 @@ impl FileFormat for ArrowFormat { async fn create_physical_plan( &self, - _state: &dyn Session, + state: &dyn Session, conf: FileScanConfig, ) -> Result> { - let source = Arc::new(ArrowSource::default()); + let object_store = state.runtime_env().object_store(&conf.object_store_url)?; + let object_location = &conf + .file_groups + .first() + .ok_or_else(|| internal_datafusion_err!("No files found in file group"))? + .files() + .first() + .ok_or_else(|| internal_datafusion_err!("No files found in file group"))? + .object_meta + .location; + + let table_schema = TableSchema::new( + Arc::clone(conf.file_schema()), + conf.table_partition_cols().clone(), + ); + + let mut source: Arc = + match is_object_in_arrow_ipc_file_format(object_store, object_location).await + { + Ok(true) => Arc::new(ArrowSource::new_file_source(table_schema)), + Ok(false) => Arc::new(ArrowSource::new_stream_file_source(table_schema)), + Err(e) => Err(e)?, + }; + + // Preserve projection from the original file source + if let Some(projection) = conf.file_source.projection() + && let Some(new_source) = source.try_pushdown_projection(projection)? + { + source = new_source; + } + let config = FileScanConfigBuilder::from(conf) .with_source(source) .build(); @@ -202,12 +248,12 @@ impl FileFormat for ArrowFormat { Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _) } - fn file_source(&self) -> Arc { - Arc::new(ArrowSource::default()) + fn file_source(&self, table_schema: TableSchema) -> Arc { + Arc::new(ArrowSource::new_file_source(table_schema)) } } -/// Implements [`FileSink`] for writing to arrow_ipc files +/// Implements [`FileSink`] for Arrow IPC files struct ArrowFileSink { config: FileSinkConfig, } @@ -344,101 +390,167 @@ impl DataSink for ArrowFileSink { } } +// Custom implementation of inferring schema. Should eventually be moved upstream to arrow-rs. +// See + const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1']; const CONTINUATION_MARKER: [u8; 4] = [0xff; 4]; -/// Custom implementation of inferring schema. Should eventually be moved upstream to arrow-rs. -/// See -async fn infer_schema_from_file_stream( +async fn infer_stream_schema( mut stream: BoxStream<'static, object_store::Result>, ) -> Result { - // Expected format: - // - 6 bytes - // - 2 bytes - // - 4 bytes, not present below v0.15.0 - // - 4 bytes - // - // - - // So in first read we need at least all known sized sections, - // which is 6 + 2 + 4 + 4 = 16 bytes. 
- let bytes = collect_at_least_n_bytes(&mut stream, 16, None).await?; - - // Files should start with these magic bytes - if bytes[0..6] != ARROW_MAGIC { - return Err(ArrowError::ParseError( - "Arrow file does not contain correct header".to_string(), - ))?; - } - - // Since continuation marker bytes added in later versions - let (meta_len, rest_of_bytes_start_index) = if bytes[8..12] == CONTINUATION_MARKER { - (&bytes[12..16], 16) + // IPC streaming format. + // See https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format + // + // + // + // ... + // + // + // ... + // + // ... + // + // ... + // + // + + // The streaming format is made up of a sequence of encapsulated messages. + // See https://arrow.apache.org/docs/format/Columnar.html#encapsulated-message-format + // + // (added in v0.15.0) + // + // + // + // + // + // The first message is the schema. + + // IPC file format is a wrapper around the streaming format with indexing information. + // See https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format + // + // + // + // + //