Guides
Train a ship classifier
Guides
Train a ship classifier
Start training satellite image classification models using the Bluesight API
This guide shows how to train the Clay Foundation Model on a classification task (a ships vs. non-ships dataset) and run inference with it using our API.
# Base URL that every request in this guide is sent to (/files, /training/jobs, /inference).
API_URL = "https://api.bluesight.ai"
# Value sent as "gsd" with every image; presumably ground sample distance in metres — TODO confirm units.
GSD = 3.0
# Side length in pixels of each square image; used to reshape the flat pixel arrays.
IMG_WIDTH = 80
# Band names attached to every image payload and training record.
BANDS = ["red", "green", "blue"]
# Google Drive location of the dataset JSON fetched with gdown.
DATA_URL = "https://drive.google.com/uc?id=1f-M8N_R0nA4n6bDV1Hld31CnbR_rinjB"
# Local file name the dataset is downloaded to.
DATA_OUTPUT = "shipsnet.json"
# Seed passed to train_test_split so the train/test split is reproducible.
RANDOM_SEED = 42
!pip install scikit-learn gdown h5py requests-toolbelt
Downloading Data
import json
from collections import Counter
import gdown
import numpy as np
# Fetch the dataset JSON from Google Drive into the working directory.
gdown.download(DATA_URL, DATA_OUTPUT, quiet=False)

# Parse the whole file; the top-level object is a dict of parallel lists.
with open(DATA_OUTPUT, "rb") as fh:
    shipsnet = json.load(fh)
print(f"Keys: {list(shipsnet.keys())}")

data = shipsnet["data"]
labels = shipsnet["labels"]
locations = shipsnet["locations"]
print(f"Labels distribution: {dict(Counter(labels))}")

# Each sample is a flat list of values; reshape to (samples, 3, IMG_WIDTH, IMG_WIDTH).
pixels = np.array(data).astype("uint8").reshape(-1, 3, IMG_WIDTH, IMG_WIDTH)
Downloading...
From (original): https://drive.google.com/uc?id=1f-M8N_R0nA4n6bDV1Hld31CnbR_rinjB
From (redirected): https://drive.google.com/uc?id=1f-M8N_R0nA4n6bDV1Hld31CnbR_rinjB&confirm=t&uuid=61a02ee9-af08-4bc7-a1ed-3ede065033d5
To: /home/furiousteabag/Projects/docs/guides/shipsnet.json
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344M/344M [00:32<00:00, 10.5MB/s]
``````output
Keys: ['data', 'labels', 'locations', 'scene_ids']
Labels distribution: {1: 1000, 0: 3000}
import matplotlib.pyplot as plt
# 2x4 preview grid: first row shows label-1 samples, second row label-0 samples.
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
fig.suptitle("Ships (top row) vs Non-Ships (bottom row)")

# Positions of each class inside the dataset.
ship_indices = [idx for idx, lbl in enumerate(labels) if lbl == 1]
non_ship_indices = [idx for idx, lbl in enumerate(labels) if lbl == 0]

for row_axes, sample_indices, caption in (
    (axes[0], ship_indices, "Ship"),
    (axes[1], non_ship_indices, "Non-Ship"),
):
    for col, ax in enumerate(row_axes):
        # Stored layout is (band, row, col); imshow expects (row, col, band).
        ax.imshow(pixels[sample_indices[col]].transpose([1, 2, 0]))
        ax.axis("off")
        ax.set_title(f"{caption} {col + 1}")

plt.tight_layout()
plt.show()
Preparing Data
from datetime import datetime
from sklearn.model_selection import train_test_split
# Build one payload dict per image. Each locations entry is unpacked as
# (lon, lat) and written to the payload as [lat, lon].
images = [
    {
        "bands": BANDS,
        "gsd": GSD,
        "pixels": chip.tolist(),
        "point": [lat, lon],
    }
    for chip, (lon, lat) in zip(pixels, locations)
]

# Stratified 80/20 split so both subsets keep the same class ratio.
X_train, X_test, y_train, y_test = train_test_split(
    images,
    labels,
    test_size=0.2,
    stratify=labels,
    random_state=RANDOM_SEED,
)

print(f"Training set size: {len(X_train)}")
print(f"Test set shape: {len(X_test)}")
print(f"Training labels shape: {len(y_train)}")
print(f"Test labels shape: {len(y_test)}")
Training set size: 3200
Test set shape: 800
Training labels shape: 3200
Test labels shape: 800
from pathlib import Path
import h5py
import numpy as np
from tqdm import tqdm
# Write the training split to an HDF5 file holding one structured record per sample.
num_samples = len(X_train)
num_bands = len(BANDS)
pixel_shape = (num_bands, IMG_WIDTH, IMG_WIDTH)

# Compound dtype describing a single training record.
dt = np.dtype(
    [
        ("bands", "S10", (num_bands,)),  # Adjust '10' if your band names are longer
        ("gsd", "float32"),
        ("pixels", "float32", pixel_shape),
        ("platform", "S20"),  # Adjust '20' if your platform names are longer
        ("point", "float32", (2,)),
        ("timestamp", "int64"),
        ("label", "int64"),
    ]
)

# The band names are identical for every record, so encode them once.
encoded_bands = [band.encode("ascii", "ignore") for band in BANDS]

output_file = Path("ships_train_data.h5")
with h5py.File(output_file, "w") as f:
    dataset = f.create_dataset(
        "data",
        (num_samples,),
        dtype=dt,
        compression="gzip",
        compression_opts=1,
        chunks=(min(1, num_samples),),
    )
    for i, (X, y) in tqdm(enumerate(zip(X_train, y_train))):
        dataset[i] = (
            encoded_bands,
            GSD,
            np.array(X["pixels"]).astype("float32"),
            "".encode("ascii", "ignore"),  # platform is left empty in this guide
            # Cast to float32 to match the "point" field; an integer cast here
            # would drop the fractional part of the coordinates.
            np.array(X["point"]).astype("float32"),
            0,  # timestamp is written as 0 in this guide
            np.array(y).astype("int64"),
        )
3200it [00:05, 542.14it/s]
Uploading Data to Bluesight
import requests
import os
from pprint import pprint
from requests_toolbelt.multipart.encoder import MultipartEncoder
def upload_file(file_path):
    """Upload a local file to the ``/files`` endpoint as multipart form data.

    Args:
        file_path: Path of the file to upload; its base name is sent as the
            file name.

    Returns:
        The decoded JSON response when the server answers with HTTP 200,
        otherwise ``None`` (the status code and response body are printed).
    """
    file_name = os.path.basename(file_path)
    # The encoder is consumed by requests.post, so the file has to stay open
    # until the call returns; the with-block then closes the handle.
    with open(file_path, "rb") as fh:
        encoder = MultipartEncoder(
            fields={"file": (file_name, fh, "application/octet-stream")}
        )
        response = requests.post(
            f"{API_URL}/files",
            data=encoder,
            headers={"Content-Type": encoder.content_type},
        )
    if response.status_code == 200:
        print("File uploaded successfully!")
        return response.json()
    else:
        print(f"Failed to upload file. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
%%time
# Upload the HDF5 training file and show the file object returned by the API.
# NOTE: upload_file returns None on a non-200 response, so file_object may be None here.
file_object = upload_file(output_file)
pprint(file_object, sort_dicts=False)
File uploaded successfully!
{'id': 'file-v54z7pws',
'bytes': 72732113,
'created_at': 1722454847,
'filename': 'ships_train_data.h5'}
CPU times: user 394 ms, sys: 90.4 ms, total: 485 ms
Wall time: 35.8 s
Starting Training
Now we need to convert the data into the format accepted by our API. Essentially, it is the raw pixel data from all bands together with some metadata.
def create_training_job(task, training_file, validation_file=None):
    """Submit a training job to the ``/training/jobs`` endpoint.

    Args:
        task: Task name sent as ``"task"`` (the guide uses ``"classification"``).
        training_file: Id of an uploaded file to train on.
        validation_file: Optional id of an uploaded validation file; sent as
            ``None`` when omitted.

    Returns:
        The decoded JSON job description when the server answers with HTTP 200,
        otherwise ``None`` (the status code and response body are printed).
    """
    payload = {"task": task, "training_file": training_file, "validation_file": validation_file}
    response = requests.post(
        f"{API_URL}/training/jobs",
        json=payload,
        headers={"Content-Type": "application/json"},
    )
    if response.status_code == 200:
        print("Training job submitted successfully!")
        return response.json()
    else:
        print(f"Failed to submit training job. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
# Submit a classification job that trains on the uploaded file (no validation file is passed).
# NOTE: file_object is None if the upload failed, in which case file_object["id"] raises a TypeError.
training_job = create_training_job("classification", file_object["id"])
pprint(training_job, sort_dicts=False)
Training job submitted successfully!
{'task': 'classification',
'training_file': 'file-v54z7pws',
'validation_file': None,
'hyperparameters': None,
'id': 'trainingjob-o4rqy3iw',
'created_at': 1722454848,
'status': 'initializing',
'error': None,
'trained_model': None,
'finished_at': None}
def retrieve_training_job(job):
    """Fetch the current state of a training job.

    Args:
        job: Id of the training job, appended to the ``/training/jobs/`` URL.

    Returns:
        The decoded JSON job description on HTTP 200, otherwise ``None``
        after printing the status code and response body.
    """
    job_url = f"{API_URL}/training/jobs/{job}"
    response = requests.get(job_url)

    # Report the failure first so the success path reads straight through.
    if response.status_code != 200:
        print(f"Failed to retrieve training job. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
    return response.json()
%%time
import time
# Poll the job every 3 seconds until it reaches a terminal status.
job_id = training_job["id"]
status = None
while True:
    latest_job = retrieve_training_job(job_id)
    if latest_job is None:
        # retrieve_training_job already printed the HTTP details. Stop with a
        # clear error and keep the last good training_job intact.
        raise RuntimeError(f"Could not retrieve training job {job_id}")
    training_job = latest_job
    status = training_job["status"]
    print(status)
    if status in ("succeeded", "failed", "cancelled"):
        break
    # Wait only between polls, not after the final status.
    time.sleep(3)
pprint(training_job)
initializing
initializing
initializing
running
running
running
running
running
running
running
succeeded
{'created_at': 1722454848,
'error': None,
'finished_at': 1722454883,
'hyperparameters': None,
'id': 'trainingjob-o4rqy3iw',
'status': 'succeeded',
'task': 'classification',
'trained_model': 'model:classification-xvxor2y2',
'training_file': 'file-v54z7pws',
'validation_file': None}
CPU times: user 76.4 ms, sys: 22.5 ms, total: 98.9 ms
Wall time: 39.1 s
Running Inference
For inference, we only need the images.
def run_trained_model_inference(model, images):
    """Run a trained model on a batch of images via the ``/inference`` endpoint.

    Args:
        model: Identifier of the trained model, sent as ``"model"``.
        images: Image payloads, sent as ``"images"``.

    Returns:
        The decoded JSON response on HTTP 200, otherwise ``None`` after
        printing the status code and response body.
    """
    response = requests.post(
        f"{API_URL}/inference",
        json={"model": model, "images": images},
        headers={"Content-Type": "application/json"},
    )

    succeeded = response.status_code == 200
    if not succeeded:
        print(f"Failed to run inference. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
    return response.json()
%%time
# Classify the held-out test images with the model named in the finished training job.
# NOTE: run_trained_model_inference returns None on a non-200 response.
y_pred = run_trained_model_inference(training_job["trained_model"], X_test)
CPU times: user 1.69 s, sys: 687 ms, total: 2.37 s
Wall time: 35.2 s
from sklearn.metrics import classification_report, accuracy_score
# Keep only the predicted labels from the inference response.
y_pred = y_pred["labels"]

# Element-wise comparison of ground truth against predictions.
truth, predicted = np.asarray(y_test), np.asarray(y_pred)
match = (truth == predicted).sum()

print(f"Matched {match} out of {len(y_test)} correctly")
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")
print(classification_report(y_test, y_pred))
Matched 779 out of 800 correctly
Accuracy: 0.97375
precision recall f1-score support
0 0.98 0.98 0.98 600
1 0.94 0.95 0.95 200
accuracy 0.97 800
macro avg 0.96 0.97 0.97 800
weighted avg 0.97 0.97 0.97 800