Meta-Learning & Multimodal Learning Library for TensorFlow 2.0
metaX is a Python library of deep neural networks and datasets for meta-learning and multi-view learning, based on TensorFlow 2.0.
We provide…
dataset/
data_generator.py (Omniglot, mini-ImageNet) (Completed)
KTS_data_generator.py (Completed)
FLOWER_data_generator.py (In progress)
KMSCOCO_data_generator.py (In progress)
KVQA_data_generator.py (In progress)
CropDisease.py (Completed)
EuroSAT.py (Completed)
ISIC.py (Completed)
ChestX.py (Completed)
data/
raw_data/
model/
LearningType.py
metric_based/
Relation_network.py (In progress)
Prototypical_network.py (In progress)
Siamese_network.py (Completed)
model_based/
MANN.py (Completed)
SNAIL.py
optimization_based/
MAML.py (Completed)
MetaSGD.py
Reptile.py (In progress)
heterogeneous_data_analysis/
image_text_embeding.py (In progress)
Vis_LSTM.py (In progress)
Modified_mCNN.py (In progress)
train.py
utils.py (accuracy, mse)
pip install metax
Results
Among the various optimization-based meta-learning algorithms for few-shot learning, MAML (model-agnostic meta-learning) has been highly popular due to its strong performance on several benchmarks. The idea is to train a meta-learner that finds an initialization useful for fast learning across different tasks, and then adapts to each specific task quickly and efficiently.
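The initialize-then-adapt idea can be sketched on a toy problem. The block below is a minimal illustration in plain Python, not metaX's implementation: each task is a 1-D regression `y = a * x` with a random slope, the model is a single weight `w`, and the meta-gradient is computed analytically through one inner gradient step. The names `make_task`, `mse_grad`, `mse_hess`, and `maml_step`, as well as the learning rates and task distribution, are assumptions chosen for illustration.

```python
import random


def make_task(rng):
    """Sample a toy task y = a * x and split its inputs into support/query."""
    a = rng.uniform(-2.0, 2.0)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(10)]
    return a, xs[:5], xs[5:]


def mse_grad(w, a, xs):
    """d/dw of mean((w*x - a*x)^2) over the inputs xs."""
    return 2.0 * (w - a) * sum(x * x for x in xs) / len(xs)


def mse_hess(xs):
    """Second derivative of the same loss with respect to w."""
    return 2.0 * sum(x * x for x in xs) / len(xs)


def maml_step(w, tasks, inner_lr=0.1, outer_lr=0.05):
    """One meta-update of the shared initialization w over a batch of tasks."""
    meta_grad = 0.0
    for a, support, query in tasks:
        # Inner loop: adapt to the task with one gradient step on the support set.
        w_adapted = w - inner_lr * mse_grad(w, a, support)
        # Chain rule through the inner step (the second-order MAML term).
        d_adapted_dw = 1.0 - inner_lr * mse_hess(support)
        # Outer loss is measured on the query set at the adapted weight.
        meta_grad += mse_grad(w_adapted, a, query) * d_adapted_dw
    return w - outer_lr * meta_grad / len(tasks)


rng = random.Random(0)
w = 5.0  # deliberately poor starting initialization
for _ in range(500):
    w = maml_step(w, [make_task(rng) for _ in range(8)])
```

In this toy setting the learned `w` drifts toward the centre of the slope distribution, which is the point from which a single inner step reaches any sampled task most cheaply. Dropping the `d_adapted_dw` factor would give the first-order approximation of MAML.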
Usage
from metaX.model.optimization_based.MAML import ModelAgnosticMetaLearning
from metaX.model.optimization_based.MAML import OmniglotModel
from metaX.datasets import OmniglotDatabase
# 1. Preprocess the Dataset
database = OmniglotDatabase(
    # Forward slashes avoid accidental escapes such as "\r" in the path
    raw_data_address="dataset/raw_data/omniglot",
    random_seed=47,
    num_train_classes=1200,
    num_val_classes=100)
# 2. Create the learner model
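network_cls = OmniglotModel
# 3. Wrap the meta-learning method (MAML) around the learner model and dataset
# `args` is assumed to be a previously parsed namespace of hyperparameters
# (it must provide at least `epochs` and `iterations`, used below)
maml = ModelAgnosticMetaLearning(args, database, network_cls)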
# 4. Meta-Train
maml.meta_train(epochs=args.epochs)
# 5. Meta-Test
maml.meta_test(iterations=args.iterations)
# 6. Load the trained model
maml.load_model(epochs=args.epochs)
# 7. Predict with support set
print(maml.predict_with_support(meta_test_path='/dataset/data/omniglot/test'))
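The usage snippet refers to `args` without defining it. One way to supply it is an `argparse` namespace; the sketch below is an assumption, not metaX's actual command-line interface. Only `epochs` and `iterations` are taken from the snippet, the default values are placeholders, and `ModelAgnosticMetaLearning` may require further fields that are not shown here.

```python
import argparse

# Hypothetical parser: only the two attributes referenced in the usage
# snippet are defined, and the defaults are placeholder values.
parser = argparse.ArgumentParser(description="MAML on Omniglot (sketch)")
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--iterations", type=int, default=50)

# An empty list keeps the sketch self-contained; a real script would call
# parser.parse_args() to read the command line instead.
args = parser.parse_args([])
```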