Как распознавать объекты 600 классов, используя 9 миллионов изображений из Open Images

AI

Если вы собираетесь создать классификатор изображений и вам нужна база для обучения, то вам понадобится лишь Google Open Images.

Этот датасет состоит более чем из 30 миллионов изображений и 15 миллионов ограничительных рамок. Это 18 терабайтов изображений!

К тому же Open Images является самым доступным из всех других датасетов изображений такого масштаба. К примеру, довольно известный ImageNet требует лицензию.

Однако работа с этими данными в обычных условиях не такая уж и простая. Нужно скачивать и обрабатывать большое количество метаданных и выделять место для их хранения (или запросить разрешение на получение бакета в Google Cloud).

С другой стороны, в открытом доступе довольно мало готовых датасетов для обучения, потому что процесс их создания и распространения достаточно сложный.

Мы построим и распространим несложный end-to-end пайплайн машинного обучения, используя Open Images, и создадим собственный датасет для любой из 600 меток классов, имеющихся в Open Images.

Мы продемонстрируем наше творение при помощи создания “open sandwiches”. Это будет простой воспроизводимый классификатор, который ответит на вопрос: является ли гамбургер сэндвичем?

Код находится на GitHub.

Скачивание данных

Для начала нам нужно скачать все необходимые данные.

Это одна из главных сложностей при работе с Google Open Images (да и с любым другим внешним датасетом). Так как существующих простых способов скачать эти данные нет, нам придётся писать скрипт, который сделает работу за нас.

Я написал скрипт на Python, который ищет в метаданных изображений датасета Open Images ключевые слова. В случае их обнаружения он ищет ссылки на оригиналы соответствующих изображений на Flickr, а затем скачивает их.

Ещё одно доказательство мощи Python: всё это можно сделать всего в 50 строчек кода:

import sys
import os
import pandas as pd
import requests

from tqdm import tqdm
import ratelim
from checkpoints import checkpoints
checkpoints.enable()

def download(categories):
    # Загрузка метаданных
    kwargs = {'header': None, 'names': ['LabelID', 'LabelName']}
    orig_url = "https://storage.googleapis.com/openimages/2018_04/class-descriptions-boxable.csv"
    class_names = pd.read_csv(orig_url, **kwargs)
    orig_url = "https://storage.googleapis.com/openimages/2018_04/train/train-annotations-bbox.csv"
    train_boxed = pd.read_csv(orig_url)
    orig_url = "https://storage.googleapis.com/openimages/2018_04/train/train-images-boxable-with-rotation.csv"
    image_ids = pd.read_csv(orig_url)

    # Получение ID категорий для данных категорий и выбор тех изображений, которые содержат рамки-ограничители для объектов этих категорий
    label_map = dict(class_names.set_index('LabelName').loc[categories, 'LabelID']
                     .to_frame().reset_index().set_index('LabelID')['LabelName'])
    label_values = set(label_map.keys())
    relevant_training_images = train_boxed[train_boxed.LabelName.isin(label_values)]

    # Начать с наиболее релевантных результатов, если они есть, иначе начать с нуля
    relevant_flickr_urls = (relevant_training_images.set_index('ImageID')
                            .join(image_ids.set_index('ImageID'))
                            .loc[:, 'OriginalURL'])
    relevant_flickr_img_metadata = (relevant_training_images.set_index('ImageID').loc[relevant_flickr_urls.index]
                                    .pipe(lambda df: df.assign(LabelValue=df.LabelName.map(lambda v: label_map[v]))))
    remaining_todo = len(relevant_flickr_urls) if checkpoints.results is None else
        len(relevant_flickr_urls) - len(checkpoints.results)

    # Загрузка изображений
    with tqdm(total=remaining_todo) as progress_bar:
        relevant_image_requests = relevant_flickr_urls.safe_map(lambda url: _download_image(url, progress_bar))
        progress_bar.close()

    # Запись изображений в файлы и добавление их в пакет
    if not os.path.isdir("temp/"):
        os.mkdir("temp/")
    for ((_, r), (_, url), (_, meta)) in zip(relevant_image_requests.iteritems(), relevant_flickr_urls.iteritems(),
                                             relevant_flickr_img_metadata.iterrows()):
        image_name = url.split("/")[-1]
        image_label = meta['LabelValue']

        _write_image_file(r, image_name)

@ratelim.patient(5, 5)
def _download_image(url, pbar):
    """Скачивание изображения по ссылке с частотой не более 1 изображения в секунду"""
    r = requests.get(url)
    r.raise_for_status()
    pbar.update(1)
    return r

def _write_image_file(r, image_name):
    """Загрузка изображения в файл"""
    filename = f"temp/{image_name}"
    with open(filename, "wb") as f:
        f.write(r.content)        
        
if __name__ == '__main__':
  categories = sys.argv[1:]
  download(categories)

Этот скрипт позволяет скачать ту часть изображений, которая содержит выделенные объекты, относящиеся к какой-либо из выбранной нами категорий:

$ git clone https://github.com/quiltdata/open-images.git
$ cd open-images/
$ conda env create -f environment.yml
$ source activate quilt-open-images-dev
$ cd src/openimager/
$ python openimager.py "Sandwiches" "Hamburgers"

Категории расположены в иерархическом порядке, то есть sandwich и hamburger — подклассы food (при этом hamburger не является подклассом sandwich).

Мы можем визуализировать объектную модель при помощи радиальной карты в Vega:

Интерактивную карту с аннотациями и её код можно посмотреть здесь.

Не у каждой категории в Open Images есть картинки с выделенными объектами этого класса. Однако этот скрипт позволит скачать датасет с объектами любого из 600 классов, для которых есть ограничительные рамки объектов на изображениях. Выбирайте любой:

football, toy, bird, cat, vase, hair dryer, kangaroo, knife, briefcase, pencil case, tennis ball, nail, high heels, sushi, skyscraper, tree, truck, violin, wine, wheel, whale, pizza cutter, bread, helicopter, lemon, dog, elephant, shark, flower, furniture, airplane, spoon, bench, swan, peanut, camera, flute, helmet, pomegranate, crown

Для этой статьи нам понадобятся всего два: hamburger иsandwich.

Очищай и обрезай

После того, как мы запустили скрипт и загрузили наши изображения в определённую директорию, мы можем просмотреть их при помощи matplotlib:

import matplotlib.pyplot as plt
from matplotlib.image import imread
%matplotlib inline
import os

fig, axarr = plt.subplots(1, 5, figsize=(24, 4))
for i, img in enumerate(os.listdir('../data/images/')[:5]):
  axarr[i].imshow(imread('../data/images/' + img))
Пять примеров изображений с сэндвичами и гамбургерами

Эти изображения не совсем пригодны для обучения. Они сочетают в себе все неудобства датасетов с изображениями из открытых источников в интернете.

Даже на этой небольшой части датасета видно, что у всех изображений отличаются ширина и высота, а также объекты классов видны лишь частично.

В данном случае у нас даже не загрузилось само изображение. Вместо него мы скачали заполнитель, обозначающий, что изображение было удалено.

После загрузки мы получаем несколько тысяч изображений вроде этих. Следующим шагом будет обрезка изображения по ограничительным рамкам для того, чтобы оставить в нём лишь часть с гамбургерами и сэндвичами.

Вот другой набор изображений, на этот раз с видимыми ограничительными рамками.

Заметьте, что датасет содержит как иллюстрации из меню, так и изображения с несколькими объектами

Эту работу выполняет блокнот Jupyter из демо-репозитория GitHub.

Также нужно изменить метаданные изображений так, чтобы они соответствовали выходным изображениям, и избавиться от удалённых изображений.

После выполнения кода на диске появится папка с названием images_cropped, которая хранит обрезанные изображения.

Построение модели

После того, как мы загрузили и почистили данные, можно приступить к обучению модели.

Мы будем обучать свёрточную нейронную сеть (CNN).

Свёрточная нейронная сеть — особый тип нейронной сети, которая шаг за шагом формирует признаки, исходя из пикселей в изображениях. Затем в изображении выявляются эти признаки и на их основе генерируется результат классификации.

Эта архитектура использует преимущества локальности. То есть она основана на предположении, что пиксель имеет гораздо больше общего с рядом находящимися пикселями, чем с теми, что находятся дальше.

CNN также обладает другими привлекательными для нас свойствами: устойчивостью к шумам и масштабной инвариантностью (в определённой степени). Это улучшает качество классификации.

Если вы не знаете, как работают CNN, то можете посмотреть видео Брэндона Рорера о нейронных сетях.

Обучив простую свёрточную нейронную сеть, мы видим, что даже она выдаёт достаточно хороший результат. Для определения и обучения модели я использую Keras.

Для начала мы отсортируем изображения по директориям:

images_cropped/
    sandwich/
        some_image.jpg
        some_other_image.jpg
        ...
    hamburger/
        yet_another_image.jpg
        ...

Затем установим указатель Keras на эту папку, используя следующий код:

from keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    rescale=1/255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

test_datagen = ImageDataGenerator(
    rescale=1/255
)

train_generator = train_datagen.flow_from_directory(
    '../data/images_cropped/quilt/open_images/',
    target_size=(128, 128),
    batch_size=16,
    class_mode='binary'
)

validation_generator = test_datagen.flow_from_directory(
    '../data/images_cropped/quilt/open_images/',
    target_size=(128, 128),
    batch_size=16,
    class_mode='binary'
)

Keras исследует папку с входными данными и определит, что в данной задаче нужно классифицировать объекты по двум классам. Библиотека установит имена классов, основываясь на названиях дочерних папок, и создаст “генераторы изображений” из файлов этих папок.

Мы не просто возвращаем изображения. Мы возвращаем случайно увеличенные, скошенные и обрезанные части изображений (при помощи train_datagen.flow_from_directory).

Это демонстрация аугментации данных в действии.

Аугментация данных — практика, в которой классификатору для обучения даются случайно обрезанные и искажённые версии входных данных. Это помогает решить проблему маленького размера датасета. То есть мы можем обучать нашу модель на одном и том же изображении несколько раз, внося в него незначительные изменения.

После того, как мы определили входные данные, нужно определить саму модель:

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras.losses import binary_crossentropy
from keras.callbacks import EarlyStopping
from keras.optimizers import RMSprop


model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), input_shape=(128, 128, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())  # преобразовывает наши трёхмерные признаки в одномерные
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))

model.compile(loss=binary_crossentropy,
              optimizer=RMSprop(lr=0.0005),  # половина значения lr по умолчанию
              metrics=['accuracy'])

Это модель простой свёрточной нейронной сети. Она содержит три свёрточных слоя: единственный плотно связанный слой постпроцессинга до выходного слоя, сильная регуляризация с помощью исключения и активация relu.

Эти слои работают вместе как единое целое, минимизируя переобучение, что особенно важно из-за маленького размера датасета.

Последний шаг — непосредственно обучение модели.

import pathlib

sample_size = len(list(pathlib.Path('../data/images_cropped/').rglob('./*')))
batch_size = 16

hist = model.fit_generator( 
    train_generator, 
    steps_per_epoch=sample_size // batch_size, 
    epochs=50,
    validation_data=validation_generator, 
    validation_steps=round(sample_size * 0.2) // batch_size, 
    callbacks=[EarlyStopping(monitor='val_loss', min_delta=0, patience=4)]
) 

model.save("clf.h5")

Этот код определяет размер шага эпохи на основе количества обучающих изображений и размера батча (16 в нашем случае). Затем он обучает модель в течении 50 эпох.

Скорее всего, выполнение кода прервётся раньше из-за функции EarlyStopping. Она возвращает наиболее хорошо обученную модель ранее, чем пройдёт 50 эпох, если увидит, что качество после прохождения текущей эпохи незначительно отличается от качества модели на предыдущих четырех эпохах.

Мы выбрали такое большое значение для ожидания, так как потеря валидаций значительно варьируется.

Точность модели, обученной таким способом, примерно 75%:

precision    recall  f1-score   support

           0       0.90      0.59      0.71      1399
           1       0.64      0.92      0.75      1109

   micro avg       0.73      0.73      0.73      2508
   macro avg       0.77      0.75      0.73      2508
weighted avg       0.78      0.73      0.73      2508

Стоит заметить, что наша модель слишком неуверенная при классификации гамбургеров (класс с меткой 0) и, наоборот, чересчур уверенная при классификации сэндвичей (класс с меткой 1).

Таким образом, на 90% изображений, помеченных как “гамбургеры”, действительно изображены гамбургеры, но правильно классифицировано было только 59%.

С другой стороны, лишь 64% изображений, помеченных моделью как “сэндвичи”, действительно содержат сэндвичи. Но верно классифицировано было 92% всех изображений с сэндвичами.

Результат довольно близок к точности в 80%, достигнутой Франсуа Шолле при применении схожей модели на датасете почти такого же размера в решении классической проблемы классификации котов и собак.

Разница, скорее всего, объясняется большим количеством помех и шумов в датасете Google Open Images V4.

Помимо фотографий, датасет включает в себя иллюстрации. Эти художественные свободы усложняют задачу классификации. Поэтому при создании модели, возможно, будет целесообразным убрать эти изображения.

Распространение модели

Теперь, когда мы собрали датасет и обучили модель, было бы обидно не поделиться нашей работой с другими людьми.

Проекты, использующие машинное обучение, должны быть воспроизводимыми. Для этого рекомендую придерживаться следующей стратегии:

  • Необходимо разделить зависимости на компоненты данных, кода и среды.
  • Версия зависимости данных контролирует описание модели и тренировочный датасет. Эти данные должны быть сохранены на хранилище BLOB-объектов, например, на Amazon S3 с Quilt T4.
  • Зависимости кода контролируют код, использованный для обучения модели (используйте git).
  • Зависимости среды контролируют среду обучения модели. В промышленной среде это, вероятнее всего, будет Docker-файл, но локально можно использовать pip или conda.
  • Для того, чтобы передать кому-то переобучаемую копию модели, отправьте ему соответствующий кортеж {данные, код, среда}.

Следование этим принципам позволит получить всё, что нужно, для обучения своей копии модели этим небольшим куском кода:

git clone https://github.com/quiltdata/open-images.git
conda env create -f open-images/environment.yml
source activate quilt-open-images-dev
python -c "import t4; t4.Package.install('quilt/open_images', dest='open-images/', registry='s3://quilt-example')"

Заключение

Итак, мы продемонстрировали пайплайн машинного обучения, рассмотрели все этапы, начиная с загрузки и преобразования датасета до обучения модели. Также мы поделились этой моделью таким образом, чтобы каждый мог её построить и обучить на своих данных.

Так как пользовательские датасеты сложны для создания и распространения, со временем появилась группа датасетов, которые применяются везде. И не потому, что они действительно качественные, а потому, что они освобождают нас от создания новых датасетов.

Например, недавно выпущенный курс машинного обучения от Google использует California Housing Dataset. Эти данные устарели примерно на 20 лет!

Давайте расширять горизонты и использовать пользовательские актуальные датасеты. Это гораздо легче, чем вы думаете!


Специально для сайта ITWORLD.UZ. Новость взята с сайта NOP::Nuances of programming