diff --git a/jax_tpu_embedding/sparsecore/lib/core/BUILD b/jax_tpu_embedding/sparsecore/lib/core/BUILD index e966c92d..73aafac0 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/BUILD +++ b/jax_tpu_embedding/sparsecore/lib/core/BUILD @@ -17,7 +17,10 @@ load("//third_party/bazel/python:pytype.bzl", "pytype_strict_contrib_test", "pyt package( default_applicable_licenses = ["//:license"], - default_visibility = ["//jax_tpu_embedding/sparsecore:__subpackages__"], + default_visibility = [ + "//jax_tpu_embedding/sparsecore:__subpackages__", + "//smartass/brain/configure/jax:__subpackages__", + ], ) cc_library(