Start training satellite image classification models using the Bluesight API
# Base URL of the Bluesight API; every request in this notebook is built from it.
API_URL = "https://api.bluesight.ai"
# Attached to every image as "gsd" -- presumably ground sample distance in
# metres per pixel; TODO confirm the unit the API expects.
GSD = 3.0
# Side length of the square image chips; used to reshape the flat pixel data.
IMG_WIDTH = 80
# Band names attached to every image and written into the HDF5 records.
BANDS = ["red", "green", "blue"]
# Location of the dataset fetched with gdown.
DATA_URL = "https://drive.google.com/uc?id=1f-M8N_R0nA4n6bDV1Hld31CnbR_rinjB"
# Local file the dataset is downloaded to and then parsed as JSON.
DATA_OUTPUT = "shipsnet.json"
# Passed as random_state to train_test_split so the split is reproducible.
RANDOM_SEED = 42
!pip install scikit-learn gdown h5py requests-toolbelt
import json
from collections import Counter
import gdown
import numpy as np
# Fetch the dataset archive and parse it as JSON.
gdown.download(DATA_URL, DATA_OUTPUT, quiet=False)

with open(DATA_OUTPUT, "rb") as f:
    data = json.load(f)

print(f"Keys: {list(data.keys())}")

# Pull out the fields used by the rest of the notebook. `data` is rebound
# last so the other two lookups still see the full dictionary.
labels = data["labels"]
locations = data["locations"]
data = data["data"]

print(f"Labels distribution: {dict(Counter(labels))}")

# Each sample is stored flat; reshape into (sample, channel, row, column).
pixels = (
    np.array(data)
    .astype("uint8")
    .reshape(-1, 3, IMG_WIDTH, IMG_WIDTH)
)
Downloading...
From (original): https://drive.google.com/uc?id=1f-M8N_R0nA4n6bDV1Hld31CnbR_rinjB
From (redirected): https://drive.google.com/uc?id=1f-M8N_R0nA4n6bDV1Hld31CnbR_rinjB&confirm=t&uuid=61a02ee9-af08-4bc7-a1ed-3ede065033d5
To: /home/furiousteabag/Projects/docs/guides/shipsnet.json
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344M/344M [00:32<00:00, 10.5MB/s]
``````output
Keys: ['data', 'labels', 'locations', 'scene_ids']
Labels distribution: {1: 1000, 0: 3000}
import matplotlib.pyplot as plt
# Show the first four samples of each class on a 2x4 grid.
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
fig.suptitle("Ships (top row) vs Non-Ships (bottom row)")

# Sample indices per class: label 1 is a ship, label 0 is not.
ship_indices = [idx for idx, lbl in enumerate(labels) if lbl == 1]
non_ship_indices = [idx for idx, lbl in enumerate(labels) if lbl == 0]

grid_rows = (
    (axes[0], ship_indices, "Ship"),
    (axes[1], non_ship_indices, "Non-Ship"),
)
for row_axes, sample_indices, caption in grid_rows:
    for col, ax in enumerate(row_axes):
        # Pixels are stored channel-first; move the channel axis last for imshow.
        chip = pixels[sample_indices[col]].transpose([1, 2, 0])
        ax.imshow(chip)
        ax.axis('off')
        ax.set_title(f"{caption} {col+1}")

plt.tight_layout()
plt.show()
from datetime import datetime
from sklearn.model_selection import train_test_split
# Wrap every chip in the dictionary layout sent to the API.
# Each location is unpacked as (lon, lat) and stored as [lat, lon].
# NOTE(review): confirm this is the coordinate order the API expects.
images = [
    {
        "bands": BANDS,
        "gsd": GSD,
        "pixels": chip.tolist(),
        "point": [lat, lon],
    }
    for chip, (lon, lat) in zip(pixels, locations)
]

# Hold out 20% of the samples, keeping the class ratio equal in both parts.
X_train, X_test, y_train, y_test = train_test_split(
    images,
    labels,
    test_size=0.2,
    stratify=labels,
    random_state=RANDOM_SEED,
)

print(f"Training set size: {len(X_train)}")
print(f"Test set shape: {len(X_test)}")
print(f"Training labels shape: {len(y_train)}")
print(f"Test labels shape: {len(y_test)}")
Training set size: 3200
Test set shape: 800
Training labels shape: 3200
Test labels shape: 800
from pathlib import Path
import h5py
import numpy as np
from tqdm import tqdm
# Serialize the training split into a single HDF5 file of compound records.
num_samples = len(X_train)
num_bands = len(BANDS)
pixel_shape = (num_bands, IMG_WIDTH, IMG_WIDTH)

# Record layout for one training sample.
dt = np.dtype(
    [
        ("bands", "S10", (num_bands,)),  # Adjust '10' if your band names are longer
        ("gsd", "float32"),
        ("pixels", "float32", pixel_shape),
        ("platform", "S20"),  # Adjust '20' if your platform names are longer
        ("point", "float32", (2,)),
        ("timestamp", "int64"),
        ("label", "int64"),
    ]
)

output_file = Path("ships_train_data.h5")

# The band names are identical for every record, so encode them once.
encoded_bands = [band.encode("ascii", "ignore") for band in BANDS]

with h5py.File(output_file, "w") as f:
    dataset = f.create_dataset(
        "data",
        (num_samples,),
        dtype=dt,
        compression="gzip",
        compression_opts=1,
        # One record per chunk (evaluates to 0 only when there are no samples).
        chunks=(min(1, num_samples),),
    )
    for i, (X, y) in tqdm(enumerate(zip(X_train, y_train))):
        dataset[i] = (
            encoded_bands,
            GSD,
            np.array(X["pixels"]).astype("float32"),
            b"",  # platform is left empty
            # Cast to the field's own float32 type. The previous int64 cast
            # dropped the fractional part of both coordinates before storage.
            np.array(X["point"]).astype("float32"),
            0,  # timestamp placeholder
            np.array(y).astype("int64"),
        )
3200it [00:05, 542.14it/s]
import requests
import os
from pprint import pprint
from requests_toolbelt.multipart.encoder import MultipartEncoder
def upload_file(file_path):
    """Upload a local file to the API's ``/files`` endpoint.

    Args:
        file_path: Path of the file to send as a multipart form upload.

    Returns:
        The decoded JSON body when the API answers with status 200,
        otherwise ``None`` after printing the status code and response text.
    """
    file_name = os.path.basename(file_path)

    # The encoder reads from the handle while the request body is sent, so the
    # request must finish inside the `with` block. Leaving the block closes the
    # handle, which the previous version never did.
    with open(file_path, 'rb') as file_handle:
        encoder = MultipartEncoder(
            fields={'file': (file_name, file_handle, 'application/octet-stream')}
        )
        response = requests.post(
            f"{API_URL}/files",
            data=encoder,
            headers={'Content-Type': encoder.content_type}
        )

    if response.status_code == 200:
        print("File uploaded successfully!")
        return response.json()

    print(f"Failed to upload file. Status code: {response.status_code}")
    print(f"Response: {response.text}")
    return None
%%time
# Upload the HDF5 training file written earlier and print the file object the
# API returns. `file_object` is None when the upload fails (see upload_file).
file_object = upload_file(output_file)
pprint(file_object, sort_dicts=False)
File uploaded successfully!
{'id': 'file-v54z7pws',
'bytes': 72732113,
'created_at': 1722454847,
'filename': 'ships_train_data.h5'}
CPU times: user 394 ms, sys: 90.4 ms, total: 485 ms
Wall time: 35.8 s
def create_training_job(task, training_file, validation_file=None):
    """Submit a training job to the API's ``/training/jobs`` endpoint.

    Args:
        task: Task name sent as ``"task"`` in the request body.
        training_file: Value sent as ``"training_file"``; the notebook passes
            the ``id`` of an uploaded file.
        validation_file: Optional value sent as ``"validation_file"``;
            ``None`` is sent as-is when omitted.

    Returns:
        The decoded JSON body when the API answers with status 200,
        otherwise ``None`` after printing the status code and response text.
    """
    payload = {"task": task, "training_file": training_file, "validation_file": validation_file}
    response = requests.post(
        f"{API_URL}/training/jobs",
        json=payload,
        headers={'Content-Type': "application/json"}
    )

    if response.status_code == 200:
        print("Training job submitted successfully!")
        return response.json()

    # Message typo fixed: "sumbit" -> "submit".
    print(f"Failed to submit training job. Status code: {response.status_code}")
    print(f"Response: {response.text}")
    return None
# Start a classification training job on the uploaded file; no validation
# file is passed. `training_job` is None when the submission fails.
training_job = create_training_job("classification", file_object["id"])
pprint(training_job, sort_dicts=False)
Training job submitted successfully!
{'task': 'classification',
'training_file': 'file-v54z7pws',
'validation_file': None,
'hyperparameters': None,
'id': 'trainingjob-o4rqy3iw',
'created_at': 1722454848,
'status': 'initializing',
'error': None,
'trained_model': None,
'finished_at': None}
def retrieve_training_job(job):
    """Fetch the current state of a training job.

    Args:
        job: Identifier of the training job, inserted into the request path.

    Returns:
        The decoded JSON body when the API answers with status 200,
        otherwise ``None`` after printing the status code and response text.
    """
    job_url = f"{API_URL}/training/jobs/{job}"
    response = requests.get(job_url)

    # Report the failure first so the success path reads straight through.
    if response.status_code != 200:
        print(f"Failed to retrieve training job. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None

    return response.json()
%%time
import time
# Poll the training job every 3 seconds until it reaches a terminal status.
TERMINAL_STATUSES = ("succeeded", "failed", "cancelled")

# Keep the id separately so a failed poll cannot discard it.
job_id = training_job["id"]
status = None
while True:
    latest_job = retrieve_training_job(job_id)
    if latest_job is None:
        # retrieve_training_job already printed the HTTP details. Stop with a
        # clear error and leave `training_job` at its last known good state,
        # rather than overwriting it with None and failing on a subscript.
        raise RuntimeError(f"Could not retrieve training job {job_id}")
    training_job = latest_job
    status = training_job["status"]
    print(status)
    if status in TERMINAL_STATUSES:
        break  # no need to wait once the job is finished
    time.sleep(3)
pprint(training_job)
initializing
initializing
initializing
running
running
running
running
running
running
running
succeeded
{'created_at': 1722454848,
'error': None,
'finished_at': 1722454883,
'hyperparameters': None,
'id': 'trainingjob-o4rqy3iw',
'status': 'succeeded',
'task': 'classification',
'trained_model': 'model:classification-xvxor2y2',
'training_file': 'file-v54z7pws',
'validation_file': None}
CPU times: user 76.4 ms, sys: 22.5 ms, total: 98.9 ms
Wall time: 39.1 s
def run_trained_model_inference(model, images):
    """Run a trained model on a batch of images via the ``/inference`` endpoint.

    Args:
        model: Value sent as ``"model"`` in the request body; the notebook
            passes the ``trained_model`` field of a finished training job.
        images: Value sent as ``"images"``; the notebook passes a list of
            image dictionaries.

    Returns:
        The decoded JSON body when the API answers with status 200,
        otherwise ``None`` after printing the status code and response text.
    """
    request_body = {"model": model, "images": images}
    request_headers = {'Content-Type': "application/json"}
    response = requests.post(
        f"{API_URL}/inference",
        json=request_body,
        headers=request_headers,
    )

    # Report the failure first so the success path reads straight through.
    if response.status_code != 200:
        print(f"Failed to run inference. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None

    return response.json()
%%time
# Classify the held-out test images with the model produced by the finished
# training job. The result is None when the request fails
# (see run_trained_model_inference).
y_pred = run_trained_model_inference(training_job["trained_model"], X_test)
CPU times: user 1.69 s, sys: 687 ms, total: 2.37 s
Wall time: 35.2 s
from sklearn.metrics import classification_report, accuracy_score
# The predicted classes are read from the "labels" entry of the response.
y_pred = y_pred["labels"]

# Count the positions where the prediction equals the ground truth.
truth = np.array(y_test)
predicted = np.array(y_pred)
match = np.sum(truth == predicted)

print(f"Matched {match} out of {len(y_test)} correctly")
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")
print(classification_report(y_test, y_pred))
Matched 779 out of 800 correctly
Accuracy: 0.97375
precision recall f1-score support
0 0.98 0.98 0.98 600
1 0.94 0.95 0.95 200
accuracy 0.97 800
macro avg 0.96 0.97 0.97 800
weighted avg 0.97 0.97 0.97 800