Orchideen Klassifikation

Ein erstes Anwendungsbeispiel, das die Orchideen-Beispiele von Dietmar Jakely mit Hilfe von maschinellen Mitteln automatisch zu erkennen und richtig zuzuordenen lernt.

Detailiertere Informationen zur Reproduktion dieser Resultate auf dem eigenen Rechner sind dem File README bzw. dem betreffenden GitLab Repository zu entnehmen.

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
In [2]:
import os
import torch, torchvision
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import matplotlib.pyplot as plt
from scipy import misc
from sklearn.metrics import confusion_matrix, classification_report
from fastai.conv_learner import *
#from fastai.dataset import *
#from fastai.metrics import *
from fastai.plots import *

plt.rcParams['figure.figsize'] = (10,10)
In [3]:
# Anzeigen, ob Rechenunterstützung durch die Grafikkarte verfügbar ist
torch.cuda.is_available()
Out[3]:
True

Schnittstelle zum Einlesen der Orchideenauswahl

Da ich sehr viel Zeit darauf verwenden musste, mich mit ärgerliche Fehler der fastai-Library herumzuschlagen, die das Einlesen von Tabellen betreffen, hab ich schließlich einen eigenen Ersatz dafür geschrieben, um unsere Orchideendaten vernünftig nutzen zu können.:

In [4]:
class OrchideenData(ImageClassifierData):
    
    @classmethod
    def from_csv(cls, path, orch_dir, trn_csv , val_csv, tfms=(None,None), bs=64, num_workers=8):
        """ Auswahl der Bilder an Hand von vorbereiteten CSV-Listen

        Arguments:
            path: a root path of the data (for storing trained models, precomputed values, etc)
            orch_dir: Verzeichnis mit Orchideendaten
            trn_csv: CSV-File der Trainingsauswahl
            val_csv: CSV-File der Validierungsauswahl
            bs: batch size
            tfms: transformations (for data augmentations). e.g. output of `tfms_from_model`
            trn_name: a name of the folder that contains training images.
            val_name:  a name of the folder that contains validation images.
            num_workers: number of workers

        Returns:
            ImageClassifierData
        """
        trn,val = [cls.parse_csv(orch_dir, o) for o in (trn_csv, val_csv)]
        datasets = cls.get_ds(FilesIndexArrayDataset, trn, val, tfms, path=path, test=None)
        return cls(path, datasets, bs, num_workers, classes=trn[2])
    
    @staticmethod
    def parse_csv(orch_dir, cvs):
        with open(os.path.join(PATH, orch_dir, cvs)) as handle:
            fields = [line.strip().split(',') for line in handle][1:]
        fnames = [os.path.join(orch_dir, fname) for fname,label in fields]
        all_labels = list(set([label for fname,label in fields]))
        idxs = [all_labels.index(label) for fname,label in fields]
        label_arr = np.array(idxs, dtype=int)
        return fnames, label_arr, all_labels

Der eigentliche Lernprozess:

Der hier vorliegende Lösungsansatz bedient sich einer Technik, die als Transfer Learning bekannt ist. Man benutzt dazu fertige, für ihre diesbezügliche Eignung bekannte Modelle -- im konkreten Fall: ResNet152 --, die im Vorfeld bereits mit Millionen von Bildern aus dem ImageNet trainiert wurden, und ergänzt sie nur um kleine Adaptionen an die neue Problemstellung. Auch wenn diese vortrainierten Modelle noch nie mit den hier zu erlernenden Unterscheidungen konfrontiert waren, verfügen sie doch bereits über einen umfangreichen Bestand an bereits erworbenen Differenzierungsmustern, auf die sie beim Trainieren der neunen Aufgabenstellung zurückgreifen können.

In [ ]:
PATH='data/orchideen' # Arbeitsverzeichnis mit den genutzten Daten, generierte Modellen, etc.

sz=256 # 'image size' auf die Bilder bei der Verarbeitung herunterskaliert werden.
bs=4   # 'batch size' anzahl der gleichzeitig zu verarbeitenden Bilder

       # Beide Werte sind ausgesprochen nieder angesetzt, da die Berechnungen auf einer
       # relativ leistungsschwache Grafikkarte (GTX750Ti, 2GB) erfolgte. Ein Ãœberschreiten
       # der GPU-Speicherapazität zieht unweiglich Abstürze nach sich, die einen Neustart 
       # des Jupyter Kernels erfordern.
       #
       # Die verwendeten CNNs werden im Vorfeld mit mind. 224px Bildgröße vortrainiert, was 
       # als vernünftiger Anhaltspunkt für die zu verwendende Größe dienen sollte.
    
#arch=resnet34
arch=resnet152

!rm -rf ./{PATH}/tmp
data = OrchideenData.from_csv(PATH, orch_dir='orchideen',  
                              trn_csv='min50_training.csv', 
                              val_csv='min50_validation.csv',
                              tfms=tfms_from_model(arch, sz, aug_tfms=transforms_side_on, max_zoom=1.5))

learn = ConvLearner.pretrained(arch, data, precompute=True)

Nach der Vorbereitung des zu verwendenden neuronalen Netzwerks, wird in einer Schleife wiederholt der gesamte Umfang der zum Trainieren bestimmten Bilder genutzt, um das Modell Schritt für Schritt in seiner Unterscheidungs- und Zuordnungsleistung zu optimieren.

Der Rückgabewert der sgn. Loss function dient dabei als Orientierung. Darin drückt sich aus, wie weit es gelungen ist, sich dem erwünschten Ziel zu nähern -- in unserem Fall also, die Bilder möglicht korrekt zuzordnen. Je kleiner der entsprechende Wert ausfällt, umso besser.

Die lange Zahlenkolonne, die bei jedem vollen Durchlauf durch alle Bilder (=epoche) um eine Zeile erweitert wird, enthält zwei solche Loss-Angaben:

Die erste Spalte nach der Durchlaufsnummer zeigt den Loss-Wert bezogen auf die Gesamtheit der Trainings-Bilder.

Die zweite Spalte enthält einen ganz ähnlich eruierten Wert, der sich aber in diesem Fall auf eine Validierungsgruppe von Bilder bezieht -- in unserem Fall 20% der Gesamtheit, die bereits im Vorfeld zufällig aus jeder Orchideenart herausgegriffen wurden. Diese Validierungsbilder bekommt der eigentliche Lernprozess nie zu Gesicht. Sie dienen ausschließlich der Überprüfung, wie weit die Erkennung tatsächlich auch auf andere Bilder außerhalb der Lernmenge übertragbar ist.

Wichtig ist ist die Beziehung der Werteentwicklung in diesen beiden Spalten vor allem deshalb, um sgn. overfitting zu erkennen. Dabei handelt es sich um ein Phänomen, das sich bei zu intensivem Training einstellt. Dabei wird zwar weiter die Annäherung an den Trainingsdatensatz vorangetrieben -- der Loss-Wer in der ersten Spalte nimmt als Schritt für Schritt weiter ab --, aber die Generalisierbarkeit der repräsentierten Beschreibung nimmt zusehends ab. Es werden also nur mehr die tatsächlichen Lernbeispiele immer präziser erkannt, nicht aber fremde Daten, die man an Hand ihrer Ähnlichkeit bestimmen wollte -- deshalb steigt der Wert in der zweiten Spalte wieder an.

Die letzte Spalte der Ausgabe steht für die accuracy, was bereits als Prozentangabe für die Treffergenauigkeit der Bilder gelesen werden kann, wenn man das Komma um zwei Stellen nach rechts verschiebt.

Wichtig an dieser Art des Lernens bzw. der Optimierung von Modellen ist der Umstand, das dabei auch stochastische Momente im Spiel sind. Es wird also zumindest jener Ausgangspunkt zufällig gewählt, von dem aus man dann in unzähligen Versuchen auf eine besseren Lösungen hinzuarbeiten versucht. Die Enwicklung der Werte und das resultierende Modell nehmen also bei jedem erneuten Programmdurchlauf eine andere Form an.

Da in diesem ersten Trainingsschritt vorerst einmal nur jener kleine Teil trainiert wird, der zum bereits vortrainierte Model hinzugefügt wurde, geht das in diesem Fall ausgesprochen schnell. Die Ergebnisse wirken aber trotzdem bereits nach einigen Iterationen bzw. nach Sekunden relativ brauchbar. Darin zeigt sich der Vorzug des Transfer Learnings.

In [5]:
%time learn.fit(0.001, 100)
100%|██████████| 40/40 [01:25<00:00,  2.13s/it]
100%|██████████| 8/8 [00:17<00:00,  2.19s/it]
[0.      2.5065  1.68534 0.48607]                         
[1.      1.90732 1.28775 0.6151 ]                         
[2.      1.56498 1.10071 0.66016]                         
[3.      1.31294 1.00829 0.68633]                         
[4.      1.14712 0.93324 0.69232]                         
[5.      1.01354 0.88905 0.69232]                         
[6.      0.90292 0.84985 0.70807]                          
[7.      0.82832 0.8233  0.71211]                          
[8.      0.77817 0.80727 0.71016]                          
[9.      0.71774 0.77149 0.72174]                          
[10.       0.67557  0.75501  0.72409]                      
[11.       0.64143  0.75605  0.73581]                      
[12.       0.6081   0.73747  0.74375]                      
[13.       0.56598  0.72258  0.74388]                      
[14.       0.56517  0.71846  0.75169]                      
[15.       0.53267  0.71183  0.75169]                      
[16.       0.50914  0.70521  0.75963]                      
[17.       0.48546  0.68661  0.77148]                      
[18.       0.46863  0.68898  0.75768]                      
[19.       0.44812  0.68415  0.77357]                      
[20.       0.43131  0.69449  0.75768]                      
[21.       0.41765  0.675    0.77565]                      
[22.       0.40506  0.67262  0.7694 ]                      
[23.       0.3811   0.67793  0.76758]                      
[24.       0.37473  0.67115  0.76966]                      
[25.       0.36226  0.67329  0.7737 ]                      
[26.       0.35818  0.66688  0.77175]                      
[27.       0.35036  0.67039  0.77773]                      
[28.       0.33249  0.66922  0.77175]                      
[29.       0.32346  0.67433  0.77187]                      
[30.       0.32484  0.67039  0.76784]                      
[31.       0.30659  0.66391  0.76575]                      
[32.       0.29591  0.65469  0.77357]                      
[33.       0.28678  0.66046  0.7776 ]                      
[34.       0.28071  0.65566  0.77552]                      
[35.       0.27094  0.64568  0.77943]                      
[36.       0.26069  0.6534   0.77383]                      
[37.       0.26108  0.65455  0.7737 ]                      
[38.       0.25595  0.65237  0.77943]                      
[39.       0.24575  0.66826  0.77175]                      
[40.       0.24675  0.64711  0.77969]                      
[41.       0.24062  0.65506  0.7737 ]                      
[42.       0.24735  0.64829  0.77773]                      
[43.       0.23553  0.65526  0.76589]                      
[44.       0.22304  0.65454  0.77969]                      
[45.       0.21657  0.65794  0.77383]                      
[46.       0.23356  0.6467   0.77969]                      
[47.       0.21297  0.64734  0.77773]                      
[48.       0.20685  0.64698  0.77578]                      
[49.       0.2072   0.65067  0.77773]                      
[50.       0.203    0.6466   0.78359]                      
[51.       0.19557  0.64324  0.77969]                      
[52.       0.19299  0.65399  0.7875 ]                      
[53.       0.18559  0.64613  0.77773]                      
[54.       0.18795  0.65584  0.77578]                      
[55.       0.18084  0.659    0.77787]                      
[56.       0.17933  0.66387  0.77578]                      
[57.       0.1764   0.66722  0.77201]                      
[58.       0.16661  0.66499  0.77787]                      
[59.       0.16908  0.65999  0.77982]                      
[60.       0.16633  0.65795  0.78177]                      
[61.       0.16629  0.66325  0.78177]                      
[62.       0.155    0.66546  0.77982]                      
[63.       0.15075  0.66003  0.78359]                      
[64.       0.15206  0.65596  0.77578]                      
[65.       0.15785  0.6541   0.78555]                      
[66.       0.15369  0.66289  0.77969]                      
[67.       0.15994  0.65542  0.7776 ]                      
[68.       0.15221  0.66038  0.78359]                      
[69.       0.15282  0.66434  0.78945]                      
[70.       0.15069  0.65743  0.78359]                      
[71.       0.14619  0.67043  0.79349]                      
[72.       0.14566  0.66335  0.78568]                      
[73.       0.1419   0.65691  0.78958]                      
[74.       0.13645  0.6589   0.78177]                      
[75.       0.12895  0.66028  0.77982]                      
[76.       0.12469  0.66832  0.77591]                      
[77.       0.12437  0.66174  0.77591]                      
[78.       0.13171  0.66875  0.77383]                      
[79.       0.12837  0.65834  0.78177]                      
[80.       0.12722  0.65219  0.78359]                      
[81.       0.1201   0.6609   0.78164]                      
[82.       0.12187  0.66192  0.78177]                      
[83.       0.11541  0.67174  0.77396]                      
[84.       0.11297  0.6663   0.77773]                      
[85.       0.11027  0.67269  0.77578]                      
[86.       0.11089  0.66834  0.77591]                      
[87.       0.11239  0.671    0.77578]                      
[88.       0.11218  0.66953  0.76992]                      
[89.       0.10985  0.66209  0.77565]                      
[90.       0.11518  0.67395  0.77982]                      
[91.       0.10626  0.67528  0.77578]                      
[92.       0.10696  0.66961  0.77591]                      
[93.       0.11097  0.67646  0.77396]                      
[94.       0.10606  0.67019  0.78385]                      
[95.       0.10149  0.66935  0.78555]                      
[96.       0.09706  0.67136  0.78359]                       
[97.       0.09736  0.67411  0.77773]                       
[98.       0.09673  0.67095  0.77787]                       
[99.       0.09531  0.67506  0.78164]                       

CPU times: user 45.5 s, sys: 32.2 s, total: 1min 17s
Wall time: 47.7 s
In [6]:
learn.sched.plot_loss()

Im Anschluss oder auch in Verbindung mit diesem elementaren Trainingsvorgang können noch einige weitere sinnvolle Verfeinerungen vorgenommen werden, um die Erkennungsgenauigkeit ein weiter zu steigern.

Eine solche Methode besteht darin, das Netz nicht ausschließlich nur mit den unveränderten Lernbeispielen zu trainieren, sondern die betreffenden Bilder bewusst mehrfach einzulesen, und dabei jeweils in verschiedenen Winkeln zu rotieren, zu spiegeln oder skalierte Ausschnitte daraus zu wählen, um auch die damit verbundenen abgeleiteten Ansichten besser zu erkennen. Diese Technik bezeichnet man als Augmentierung.

In userem Fall hat das allerdings keinen nennenswerten Gewinn gebracht, weshalb es hier auskommentiert ist, um die Verarbeitung nicht sinnlos zu verlangsamen.

In [18]:
#learn.precompute=False
#%time learn.fit(0.001, 10)

Abschießend wird nun der Lernprozess bzw. die Optimierung der Parameter, die ja bisher nur den hinzugefügten Teil betroffen haben, einige Iterationen lang auf das gesamte Netz ausgedehnt. Damit wird der Rechenaufwand gravierend größer, aber doch auch wieder eine geringe Qualitätssteigerung in der Erkennungsrate bewirkt.

In [8]:
learn.unfreeze()
learn.bn_freeze(True)
%time learn.fit([1e-5,1e-4,1e-3], 8)
[0.      0.67639 0.68888 0.77187]                          
[1.      0.64899 0.68073 0.7737 ]                          
[2.      0.65994 0.66683 0.77747]                          
[3.      0.60465 0.66886 0.78138]                          
[4.      0.56981 0.67575 0.7776 ]                          
[5.      0.54843 0.66921 0.77982]                          
[6.      0.54309 0.66202 0.77591]                          
[7.      0.55333 0.6624  0.77396]                          

CPU times: user 14min 5s, sys: 6min 49s, total: 20min 54s
Wall time: 13min 59s
In [9]:
learn.sched.plot_loss()

Das fertig trainierte Netz kann abschließend in einer Datei abgespeichert werden, um es in Zukunft ohne langwierigen Trainingsaufwand für div. Anwendungen nutzen zu können.

In [10]:
learn.save('min50')

Auswertung der Daten

Als Beispiel wurden hier nur jene Arten herausgegriffen, von denen mehr als 50 Abbildungen verfügbar waren, da das Erlernenen an Hand von noch weniger Beispielen ganz spezielle Lösungsansätze erfordern würde.

Alle Angaben zur tatsächlichen maschinellen Klassifizierungsleistung beziehen sich auf jene Validierungsstichprobe (20% der verfügbaren Gesamtheit), mit der der Lernprozess vorher nicht in Berührung kam.

In [11]:
log_probs_tta,y = learn.TTA()
probs_tta = np.mean(np.exp(log_probs_tta),0)
preds_tta = np.argmax(probs_tta, axis=1)
                                             
In [5]:
# eine leichet erweiterte darstellung der confusion matrix
def plot_orchideen_confusion_matrix(cm, classes, y, order, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    (This function is copied from the scikit docs.)
    """
    count = np.unique(y, return_counts=True)[1][np.array(order)]

    plt.figure()
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, [f'{l} [{nr}]' for l,nr in zip(classes, count)])
    
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, f'{cm[i, j]}\n{cm[i,j]*100/count[i]:.3}%'if i==j else cm[i, j], 
                 horizontalalignment="center", verticalalignment="top",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('Korrekte Zuordnung [Anzahl an Überprüfungsbildern]')
    plt.xlabel('Ermittelte Zuordnung')
    
In [13]:
classes = np.array(data.classes)
order = classes.argsort()
labels = classes[order]

print(classification_report(y, preds_tta, order, labels, digits=3))
cm = confusion_matrix(y, preds_tta, order)
out = plot_orchideen_confusion_matrix(cm, labels, y, order)
                            precision    recall  f1-score   support

            Ophrys apifera      0.933     0.875     0.903        16
           Ophrys dinarica      0.867     0.929     0.897        14
           Ophrys exaltata      1.000     0.636     0.778        11
         Ophrys fusca s.l.      1.000     1.000     1.000        11
    Ophrys holoserica s.l.      0.774     0.714     0.743        91
       Ophrys holoserica-3      0.817     0.826     0.822        92
          Ophrys incubacea      0.885     0.793     0.836        29
         Ophrys istriensis      0.851     0.905     0.878        95
Ophrys passionis garganica      0.700     0.778     0.737        18
      Ophrys rhodostephane      0.765     0.867     0.812        15
     Ophrys sphegodes s.l.      0.851     0.941     0.894        85
         Ophrys tommasinii      0.765     0.684     0.722        19
       Ophrys zinsmeisteri      1.000     0.667     0.800        12

               avg / total      0.835     0.833     0.831       508

In [14]:
res = ImageModelResults(data.val_ds, np.mean(log_probs_tta, 0))

Am besten erkannte Beispiele:

In [15]:
for nr in order:
    print(f'Am besten erkannte Beispiele aus der Kategorie: "{classes[nr]}":')
    res.plot_by_correct(nr, True)
    plt.show()
Am besten erkannte Beispiele aus der Kategorie: "Ophrys apifera":
Am besten erkannte Beispiele aus der Kategorie: "Ophrys dinarica":
Am besten erkannte Beispiele aus der Kategorie: "Ophrys exaltata":
Am besten erkannte Beispiele aus der Kategorie: "Ophrys fusca s.l.":
Am besten erkannte Beispiele aus der Kategorie: "Ophrys holoserica s.l.":
Am besten erkannte Beispiele aus der Kategorie: "Ophrys holoserica-3":
Am besten erkannte Beispiele aus der Kategorie: "Ophrys incubacea":
Am besten erkannte Beispiele aus der Kategorie: "Ophrys istriensis":
Am besten erkannte Beispiele aus der Kategorie: "Ophrys passionis garganica":
Am besten erkannte Beispiele aus der Kategorie: "Ophrys rhodostephane":
Am besten erkannte Beispiele aus der Kategorie: "Ophrys sphegodes s.l.":
Am besten erkannte Beispiele aus der Kategorie: "Ophrys tommasinii":
Am besten erkannte Beispiele aus der Kategorie: "Ophrys zinsmeisteri":

Nicht erkannte Beispiele:

In [16]:
for nr in order:
    print(f'Nicht erkannte Beispiele aus der Kategorie: "{classes[nr]}":')
    res.plot_by_correct(nr, False)
    plt.show()
Nicht erkannte Beispiele aus der Kategorie: "Ophrys apifera":
Nicht erkannte Beispiele aus der Kategorie: "Ophrys dinarica":
Nicht erkannte Beispiele aus der Kategorie: "Ophrys exaltata":
Nicht erkannte Beispiele aus der Kategorie: "Ophrys fusca s.l.":
Nicht erkannte Beispiele aus der Kategorie: "Ophrys holoserica s.l.":
Nicht erkannte Beispiele aus der Kategorie: "Ophrys holoserica-3":
Nicht erkannte Beispiele aus der Kategorie: "Ophrys incubacea":
Nicht erkannte Beispiele aus der Kategorie: "Ophrys istriensis":
Nicht erkannte Beispiele aus der Kategorie: "Ophrys passionis garganica":
Nicht erkannte Beispiele aus der Kategorie: "Ophrys rhodostephane":
Nicht erkannte Beispiele aus der Kategorie: "Ophrys sphegodes s.l.":
Nicht erkannte Beispiele aus der Kategorie: "Ophrys tommasinii":
Nicht erkannte Beispiele aus der Kategorie: "Ophrys zinsmeisteri":

Vergleich mit ganz einfachen nicht-vortrainierten CNNs

Wie Eingangs erwähnt, habe ich bei der Umsetzung ein bereits vortrainiertes Modell von ausgesprochen hoher Komplexitität verwendet. CNNs von solcher Dichte sind überhaupt erst seit zwei-drei Jahren umsetzbar und haben die Bilderkennung massiv verbessert. (Eine recht verständliche Einführung in die damit verbundenen Probleme und revolutionierenden Fortschritte, gibt der Vortrag: "Really Deep Neural Networks with PyTorch" von David Dao auf der PyCon 2017)

Derartiges zu nutzen ist heute mit den richtigen Werkzeugen fast einfacher als ein ganz simples CNN von Grund auf neu zu definieren. Mich hat es trotzem interessiert, ob man auch mit bewusster Beschränkung auf primitivere historische Mittel, in unserem konkreten Fall auch das Auslangen finden könnte bzw. mit welchen Konsequenzen das verbunden wäre?

Aus diesem Grund hab ich die ganze Geschichte noch ein zweites mal mit anderen Mitteln umzusetzten versucht. Ganz zurück, zu völlig unausgegorenen Eigenkreationen od. CNNs der uralten LeNet-Architektur, die nur für unvergleichbar kleinere Bildauflösung vorgesehen war, wollte ich dabei aber auch nicht gehen. Ich hab daher als realistischen Kompromiss ein völlig untrainiertes AlexNet herangezogen. Eine Architektur, die zwar einerseits noch relativ überschaubar und einfach wirkt, aber trotzdem schon länger zur Klassifizierung ähnlicher Abbildungen herangezogen wird.

In [6]:
PATH='data/orchideen' 
sz=224 #32 
bs=32
In [19]:
data = OrchideenData.from_csv(PATH, orch_dir='orchideen',  
                              trn_csv='min50_training.csv', 
                              val_csv='min50_validation.csv',
                              tfms=tfms_from_stats(imagenet_stats, sz, aug_tfms=transforms_side_on, max_zoom=1.5))
In [20]:
model = torchvision.models.alexnet(pretrained=False)
In [9]:
model.classifier = nn.Sequential(*children(model.classifier)[:-1], 
                                nn.Linear(4096, data.c),
                                nn.LogSoftmax())
model
Out[9]:
AlexNet(
  (features): Sequential(
    (0): Conv2d (3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1))
    (3): Conv2d (64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1))
    (6): Conv2d (192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d (384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1))
  )
  (classifier): Sequential(
    (0): Dropout(p=0.5)
    (1): Linear(in_features=9216, out_features=4096)
    (2): ReLU(inplace)
    (3): Dropout(p=0.5)
    (4): Linear(in_features=4096, out_features=4096)
    (5): ReLU(inplace)
    (6): Linear(in_features=4096, out_features=13)
    (7): LogSoftmax()
  )
)
In [10]:
learn = ConvLearner.from_model_data(model, data)

Das Lernen dauert in diesem Fall natürlich beträchtlich länger!

Trotzdem lässt sich auf diesem Weg nach ca. zwei Stunden Rechenzeit ein Modell gewinnen, das dem oben genutzten in unserem konkreten Fall nur unwesentlich nachsteht.

Das ist vorallem auch deshalb eine recht positive Erfahrung, weil es ja nicht immer Sinn macht, mit völlig überdimmnsionierten Modellen zu arbeiten. Speziell, wenn man an die Verwendung auf kleinen Bastelcomputern, Handys oder Download und Ausführung direkt im WebBrowser denkt, haben derartige deutlich schlankere Lösungen auch ihren Reiz.

In [11]:
%time learn.fit(0.01, 100)
[0.      2.35193 2.26714 0.17773]                         
[1.      2.23907 2.20231 0.18555]                         
[2.      2.20186 2.19362 0.2849 ]                         
[3.      2.18122 2.21383 0.18555]                         
[4.      2.14985 2.15283 0.23333]                         
[5.      2.12558 2.18938 0.23294]                         
[6.      2.10134 2.20002 0.22943]                         
[7.      2.06027 1.97681 0.27253]                         
[8.      2.04262 1.89512 0.37565]                         
[9.      1.97603 1.91148 0.27135]                         
[10.       1.9179   1.79693  0.37031]                     
[11.       1.8145   1.66127  0.39466]                     
[12.       1.76291  1.83579  0.39062]                     
[13.       1.7018   1.49796  0.4737 ]                     
[14.       1.59387  1.40678  0.48021]                     
[15.       1.54726  1.46585  0.4905 ]                     
[16.       1.48664  1.46586  0.41185]                     
[17.       1.44927  1.38803  0.5224 ]                     
[18.       1.41789  1.3088   0.51536]                     
[19.       1.38069  1.32803  0.49596]                     
[20.       1.3508   1.33673  0.47891]                     
[21.       1.31967  1.29982  0.51615]                     
[22.       1.28381  1.29113  0.52292]                     
[23.       1.21429  1.06028  0.61133]                     
[24.       1.20558  1.12861  0.58021]                     
[25.       1.18626  1.12356  0.54518]                     
[26.       1.17092  1.3369   0.4763 ]                     
[27.       1.17109  1.11939  0.57213]                     
[28.       1.08522  1.393    0.46862]                     
[29.       1.04803  1.10663  0.54544]                     
[30.       0.99272  0.97654  0.6375 ]                      
[31.       0.98132  1.47321  0.50963]                      
[32.       0.95903  0.94641  0.61419]                      
[33.       0.96755  1.00522  0.58359]                      
[34.       0.92496  1.05894  0.57669]                      
[35.       0.96323  1.06304  0.55143]                      
[36.       0.89839  0.94097  0.63047]                      
[37.       0.94449  0.84961  0.69323]                      
[38.       0.88594  0.80417  0.68346]                      
[39.       0.87102  0.95973  0.6043 ]                      
[40.       0.851    1.09625  0.58112]                      
[41.       0.81827  1.13737  0.61719]                      
[42.       0.81123  0.74041  0.73268]                      
[43.       0.7916   0.75562  0.73086]                      
[44.       0.75999  0.83552  0.66341]                      
[45.       0.75681  1.22658  0.55482]                      
[46.       0.74097  1.03218  0.61328]                      
[47.       0.71115  1.10284  0.59596]                      
[48.       0.72575  0.81467  0.70703]                      
[49.       0.71454  0.89614  0.62083]                      
[50.       0.7138   0.93396  0.65013]                      
[51.       0.67992  0.96313  0.65977]                      
[52.       0.66471  0.97546  0.64805]                      
[53.       0.64008  0.92751  0.63516]                      
[54.       0.63909  0.88024  0.67917]                      
[55.       0.63963  1.15245  0.62305]                      
[56.       0.61512  0.86292  0.68763]                      
[57.       0.60124  0.86827  0.66771]                      
[58.       0.58709  0.72077  0.75781]                      
[59.       0.62366  0.68985  0.72565]                      
[60.       0.59532  0.6745   0.74596]                      
[61.       0.58632  0.76809  0.70286]                      
[62.       0.56237  0.97128  0.69102]                      
[63.       0.54478  0.65137  0.77904]                      
[64.       0.54881  0.7811   0.74167]                      
[65.       0.5351   0.91797  0.68112]                      
[66.       0.54641  0.70134  0.7793 ]                      
[67.       0.52298  0.84115  0.72057]                      
[68.       0.54966  0.83184  0.69935]                      
[69.       0.56372  1.05793  0.62656]                      
[70.       0.56535  0.70963  0.7224 ]                      
[71.       0.51146  0.83713  0.7125 ]                      
[72.       0.49952  0.83276  0.7207 ]                      
[73.       0.49117  0.68466  0.76237]                      
[74.       0.4892   0.86336  0.72148]                      
[75.       0.48008  0.84076  0.69727]                      
[76.       0.4449   0.70237  0.7582 ]                      
[77.       0.45315  0.86819  0.68958]                      
[78.       0.46774  0.97137  0.7151 ]                      
[79.       0.44471  0.72345  0.75026]                      
[80.       0.44702  0.9268   0.7168 ]                      
[81.       0.4389   0.81723  0.73073]                      
[82.       0.43411  0.77568  0.73867]                      
[83.       0.4248   1.04373  0.68867]                      
[84.       0.39792  0.74455  0.75156]                      
[85.       0.36711  0.99055  0.72253]                      
[86.       0.41475  1.0117   0.68177]                      
[87.       0.40114  0.83547  0.75612]                      
[88.       0.39389  0.82129  0.76758]                      
[89.       0.38407  0.90651  0.74857]                      
[90.       0.40229  0.91124  0.70716]                      
[91.       0.39865  0.78058  0.76224]                      
[92.       0.36953  0.75479  0.76185]                      
[93.       0.36349  0.7233   0.79505]                      
[94.       0.37917  0.76275  0.7918 ]                      
[95.       0.38822  0.74937  0.75833]                      
[96.       0.37964  0.75262  0.75898]                      
[97.       0.38147  0.75006  0.74427]                      
[98.       0.37896  0.98357  0.72318]                      
[99.       0.35446  0.83973  0.74102]                      

CPU times: user 1h 29min 27s, sys: 9min 33s, total: 1h 39min
Wall time: 27min 48s
In [12]:
learn.sched.plot_loss()
In [13]:
learn.save('min50-alexnet')
In [14]:
log_probs_tta,y = learn.TTA()
probs_tta = np.mean(np.exp(log_probs_tta),0)
preds_tta = np.argmax(probs_tta, axis=1)
                                             
In [15]:
classes = np.array(data.classes)
order = classes.argsort()
labels = classes[order]

print(classification_report(y, preds_tta, order, labels, digits=3))
cm = confusion_matrix(y, preds_tta, order)
out = plot_orchideen_confusion_matrix(cm, labels, y, order)
                            precision    recall  f1-score   support

            Ophrys apifera      0.800     1.000     0.889        16
           Ophrys dinarica      0.714     0.714     0.714        14
           Ophrys exaltata      0.857     0.545     0.667        11
         Ophrys fusca s.l.      0.917     1.000     0.957        11
    Ophrys holoserica s.l.      0.605     0.791     0.686        91
       Ophrys holoserica-3      0.860     0.533     0.658        92
          Ophrys incubacea      0.893     0.862     0.877        29
         Ophrys istriensis      0.832     0.832     0.832        95
Ophrys passionis garganica      0.778     0.778     0.778        18
      Ophrys rhodostephane      0.846     0.733     0.786        15
     Ophrys sphegodes s.l.      0.889     0.941     0.914        85
         Ophrys tommasinii      0.727     0.842     0.780        19
       Ophrys zinsmeisteri      0.615     0.667     0.640        12

               avg / total      0.797     0.781     0.778       508

TODO: Alle Exemplare lernen und mit Hybriden vergleichen

TODO: Nutzung der Modelle im WebBrowser mit Hilfe von WebDNN