keras.ops.image.map_coordinates breaks with TF backend #21420

Closed
@tristan-deep

Description

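keras.ops.image.map_coordinates produces incorrect output with the TF backend when some of the coordinates are out of bounds. The same script gives the expected result with the JAX and PyTorch backends. It happens on both CPU and GPU.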

Versions
TF: 2.19.0
Keras: 3.10.0
JAX: 0.5.2
PyTorch: 2.6.0+cu124

import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # "jax" and "torch" give the expected output

import keras
import matplotlib.pyplot as plt
import numpy as np

# Toggle out-of-bounds coordinates.
# With True, the TF backend produces the wrong output.
use_out_of_bounds = True

n_rho, n_theta = 64, 64
polar_img = np.zeros((n_rho, n_theta), dtype=np.float32)

for i in range(n_rho):
    if 8 < i < 16 or 24 < i < 32 or 40 < i < 48:
        polar_img[i, :] = 1.0

if use_out_of_bounds:
    # intentionally out of bounds, which breaks the TF backend
    rho_idx = np.linspace(-5, n_rho + 5, n_rho)
    theta_idx = np.linspace(-5, n_theta + 5, n_theta)
else:
    rho_idx = np.linspace(0, n_rho - 1, n_rho)
    theta_idx = np.linspace(0, n_theta - 1, n_theta)

rho_grid, theta_grid = np.meshgrid(rho_idx, theta_idx, indexing="ij")
coords = np.stack([rho_grid, theta_grid], axis=0).reshape(2, -1).astype(np.float32)
coords = keras.ops.convert_to_tensor(coords)

out = keras.ops.image.map_coordinates(
    polar_img, coordinates=coords, order=0, fill_mode="constant", fill_value=0.0
)
out = keras.ops.convert_to_numpy(out).reshape(n_rho, n_theta)

plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(polar_img, aspect="auto", cmap="gray")
plt.title("Input")
plt.subplot(1, 2, 2)
plt.imshow(out, aspect="auto", cmap="gray")
plt.title("map_coordinates Output")
plt.tight_layout()
plt.savefig("test.png")

Expected output (currently produced only with the torch and jax backends):

Image

Actual output with the TF backend:

Image
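
For reference, the expected behaviour for order=0 with fill_mode="constant" can be sketched in plain NumPy, which gives a backend-independent result to compare against. This is only an illustration, not the Keras implementation: the helper name `nearest_constant_reference` is hypothetical, and both the rounding rule (halves round up) and the bounds test (applied after rounding) are assumptions that may differ from a given backend at exact half-integer or edge coordinates.

```python
import numpy as np


def nearest_constant_reference(image, coords, fill_value=0.0):
    """Nearest-neighbour sampling with a constant fill outside the image.

    image:  array of shape (d0, d1, ...)
    coords: array of shape (image.ndim, N), one row per image axis
    """
    image = np.asarray(image)
    coords = np.asarray(coords, dtype=np.float64)

    # Assumption: round to the nearest index, with halves rounding up.
    idx = np.floor(coords + 0.5).astype(np.int64)

    # Column vector of axis sizes so it broadcasts against (ndim, N).
    shape = np.array(image.shape, dtype=np.int64).reshape(-1, 1)

    # A sample is valid only if its index is in range on every axis.
    in_bounds = np.all((idx >= 0) & (idx < shape), axis=0)

    # Clip so the gather itself never reads outside the array.
    safe = np.clip(idx, 0, shape - 1)
    gathered = image[tuple(safe)]

    # Replace every out-of-bounds sample with the fill value.
    return np.where(in_bounds, gathered, fill_value).astype(image.dtype)
```

With the script above, `nearest_constant_reference(polar_img, keras.ops.convert_to_numpy(coords)).reshape(n_rho, n_theta)` should have a zero border wherever the coordinates fall outside the image, which is what a correct backend is expected to return away from half-integer coordinates.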
