
Remove TensorFlow dependency from multimodal image preprocessing #547

@KumarADITHYA123


Current Behavior
Right now in gemma/multimodal/image.py, the code uses TensorFlow's tf.image.decode_jpeg to handle JPEG decoding. This means anyone using the multimodal features has to install the full TensorFlow package, which is pretty heavy (several hundred MB), just for that one image decoding function.

There's actually a TODO comment at line 73 that says: # TODO(eyvinec): we should remove tf dependency.

Expected Behavior
For a library that's built on JAX and Flax, it would make way more sense to use a lighter image library like PIL (Pillow) for the JPEG decoding. This would let people work with multimodal models without needing to install TensorFlow at all.

Why This Matters
- Smaller install size: Removing TensorFlow as a dependency would make the install much lighter.
- Fewer conflicts: TensorFlow can pull in CUDA versions that conflict with other libraries, including JAX.
- Better dev experience: People who just want to use Gemma with JAX shouldn't need to install another whole ML framework.

What I'm Thinking
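Replace the tf.image.decode_jpeg call with PIL's Image.open (followed by .convert("RGB"), so grayscale or CMYK JPEGs still come out with 3 channels) and convert the result to a JAX array. The output format would stay the same (shape (H, W, 3), dtype uint8, which is what tf.image.decode_jpeg returns with channels=3), so this shouldn't break anything for existing users.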
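A minimal sketch of what the replacement could look like, assuming the current code passes raw JPEG bytes to tf.image.decode_jpeg with channels=3. The helper name decode_jpeg is hypothetical and not an existing Gemma function. It returns a NumPy uint8 array of shape (H, W, 3):

```python
import io

import numpy as np
from PIL import Image


def decode_jpeg(data: bytes) -> np.ndarray:
  """Hypothetical helper: decode JPEG bytes into an (H, W, 3) uint8 array.

  Intended to match the shape and dtype of tf.image.decode_jpeg(data, channels=3).
  """
  with Image.open(io.BytesIO(data)) as img:
    # Force 3 RGB channels (grayscale, CMYK, etc. get converted), like channels=3.
    rgb = img.convert("RGB")
    # np.array copies the pixels into a writable (H, W, 3) uint8 array.
    return np.array(rgb, dtype=np.uint8)


if __name__ == "__main__":
  # Build a small in-memory JPEG; PIL sizes are (width, height).
  buf = io.BytesIO()
  Image.new("RGB", (4, 2), color=(255, 0, 0)).save(buf, format="JPEG")
  arr = decode_jpeg(buf.getvalue())
  print(arr.shape, arr.dtype)  # (2, 4, 3) uint8
```

Wrapping the result with jax.numpy.asarray gives a JAX array if downstream code expects one. One thing worth checking in a PR: TensorFlow and Pillow may link different libjpeg builds, so decoded pixel values could differ slightly even when shape and dtype match. Comparing against the TF output on a few test images with a small tolerance would catch that.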

I'd be happy to put together a PR for this if it sounds good. I'm currently working on a GSoC project proposal around Gemma and came across this while setting up my environment.

Environment
- Gemma: main branch
- Python: 3.10+

Let me know what you think!
