Tensorflow не может получить `image.shape` из метода в` dataset.map (mapFn) `

10

Я пытаюсь сделать tensorflowэквивалент torch.transforms.Resize(TRAIN_IMAGE_SIZE), который изменяет размер наименьшего размера изображения до TRAIN_IMAGE_SIZE. Что-то вроде этого

def transforms(filename):
  parts = tf.strings.split(filename, '/')
  label = parts[-2]

  image = tf.io.read_file(filename)
  image = tf.image.decode_jpeg(image)
  image = tf.image.convert_image_dtype(image, tf.float32)

  # this doesn't work with Dataset.map() because image.shape=(None,None,3) from Dataset.map()
  image = largest_sq_crop(image) 

  image = tf.image.resize(image, (256,256))
  return image, label

list_ds = tf.data.Dataset.list_files('{}/*/*'.format(DATASET_PATH))
images_ds = list_ds.map(transforms).batch(4)

Простой ответ здесь: Tensorflow: Обрезать самую большую центральную квадратную область изображения

Но когда я использую метод с tf.data.Dataset.map(transforms), я получаю shape=(None,None,3)изнутри largest_sq_crop(image). Метод работает нормально, когда я его называю нормально.

Майкл
источник
1
Я считаю, что проблема связана с тем фактом, что EagerTensorsони недоступны внутри, Dataset.map()поэтому форма неизвестна. есть ли обходной путь?
Майкл
Можете ли вы включить определение largest_sq_crop?
Якуб

Ответы:

1

Я нашел ответ. Это было связано с тем, что мой метод изменения размера работал нормально с нетерпеливым выполнением, например, tf.executing_eagerly()==Trueно не работал при использовании внутри dataset.map(). Видимо, в этой среде исполнения tf.executing_eagerly()==False.

Моя ошибка заключалась в том, что я распаковывал форму изображения, чтобы получить размеры для масштабирования. Кажется, что выполнение графика Tensorflow не поддерживает доступ к tensor.shapeкортежу.

  # wrong
  b,h,w,c = img.shape
  print("ERR> ", h,w,c)
  # ERR>  None None 3

  # also wrong
  b = img.shape[0]
  h = img.shape[1]
  w = img.shape[2]
  c = img.shape[3]
  print("ERR> ", h,w,c)
  # ERR>  None None 3

  # but this works!!!
  shape = tf.shape(img)
  b = shape[0]
  h = shape[1]
  w = shape[2]
  c = shape[3]
  img = tf.reshape( img, (-1,h,w,c))
  print("OK> ", h,w,c)
  # OK>  Tensor("strided_slice_2:0", shape=(), dtype=int32) Tensor("strided_slice_3:0", shape=(), dtype=int32) Tensor("strided_slice_4:0", shape=(), dtype=int32)

Я использовал размеры формы вниз по течению в моей dataset.map()функции, и она вызвала следующее исключение, потому что она получала Noneвместо значения.

TypeError: Failed to convert object of type <class 'tuple'> to Tensor. Contents: (-1, None, None, 3). Consider casting elements to a supported type.

Когда я переключился на ручную распаковку формы tf.shape(), все работало нормально.

Майкл
источник