Aprender de los datos de los gráficos usando Keras y Tensorflow

0 2

Motivación :

Hay muchos datos que pueden ser representados en forma de un gráfico en aplicaciones del mundo real como en las redes de citas, redes sociales (gráfico de seguidores, red de amigos,…), redes biológicas o telecomunicaciones. El uso de funciones de extracción de gráficos puede mejorar el rendimiento de los modelos predictivos al confiar en el flujo de información entre nodos vecinos. Sin embargo, la representación de los datos de los gráficos no es sencilla, especialmente si no pretendemos implementar características artesanales, ya que la mayoría de los modelos ML esperan una entrada de tamaño fijo o lineal, lo que no es el caso de los datos de los gráficos, en este artículo exploraremos algunas formas de tratar con los gráficos genéricos para realizar una clasificación de nodos basada en representaciones gráficas aprendidas directamente de los datos.

Conjunto de datos :

El conjunto de datos de la red de citación Cora servirá de base para las implementaciones y experimentos a lo largo de este puesto. Cada nodo representa un documento científico y los bordes entre nodos representan una relación de citación entre los dos documentos. Cada nodo está representado por un conjunto de características binarias (“bolsa de palabras”), así como por un conjunto de bordes que lo vinculan con otros nodos. Cada Nodo también está representado por características de palabra binaria que indican la presencia de una palabra correspondiente. En total hay 1433 características binarias (Sparse) para cada nodo. En lo que sigue sólo utilizamos 140 muestras para el entrenamiento y el resto para la prueba de validación.

Configuración del problema :

Problema : Asignar una etiqueta de clase a los nodos de un gráfico con pocas muestras de entrenamiento. Hipótesis de Intuición : Los nodos que están cerca en el gráfico tienen más probabilidades de tener etiquetas similares. Solución Encuentre una manera de extraer características del gráfico para ayudar a clasificar los nuevos nodos.

Enfoque propuesto:

Primero experimentamos con el modelo más simple que aprende a predecir clases de nodos usando sólo las características binarias y descartando toda la información de las gráficas, este modelo es una Red Neural totalmente conectada que toma como entrada las características binarias y produce las probabilidades de clase para cada nodo.

Modelo base Precisión : 53,28%.

Esta es la precisión inicial que trataremos de mejorar añadiendo características basadas en gráficos.

Adición de funciones de gráficos :

Una forma de aprender automáticamente las características de las gráficas es incrustar cada nodo en un vector entrenando a una red en la tarea auxiliar de predecir la inversa de la longitud del trayecto más corto entre dos nodos de entrada, como se detalla en la figura y en el fragmento de código que se muestra a continuación:

El siguiente paso es utilizar la integración de nodos preentrenados como entrada al modelo de clasificación. También añadimos una entrada adicional que es el promedio de las características binarias de los nodos vecinos usando la distancia de los vectores de inserción aprendidos.

La red de clasificación resultante se describe en la figura siguiente:

Incrustación de gráficos >73,06%.

Podemos ver que la adición de características de gráficos aprendidos como entrada a la > 53,28% a 73,06%

Mejora de la función de aprendizaje de gráficos :

Podemos buscar mejorar aún más el modelo anterior empujando el pre-entrenamiento más allá y usando las características binarias en la red de incrustación de nodos y luego reutilizando los pesos pre-entrenados de las características binarias además del vector de incrustación de nodos. Esto resulta en un modelo que se basa en representaciones más útiles de las características binarias aprendidas de la estructura del gráfico.

Mejoras en la incrustación de gráficos >76,35%.

Esta mejora adicional añade un poco de precisión en comparación con el enfoque anterior.

Conclusión :

En este post hemos visto que podemos aprender representaciones útiles a partir de datos estructurados de gráficos y luego utilizar estas representaciones para mejorar el rendimiento de generalización de un nodo > 53,28% a 76,35%

El código para reproducir los resultados está disponible aquí : https:/github.comCVxTzgraph_classification

Siéntase libre de comentar si tiene alguna sugerencia o si necesita algunos consejos para ejecutar el código en su máquina

Leave A Reply

Your email address will not be published.