Start training satellite image classification models using the Bluesight API
# Bluesight API endpoint and the JSON content-type header.
API_URL = "https://api.bluesight.ai"
HEADERS = {"Content-Type": "application/json"}

# STAC catalog endpoint and the collection that is searched for imagery.
STAC_API = "https://earth-search.aws.element84.com/v1"
COLLECTION = "sentinel-2-l2a"

# Point of interest, expressed as latitude / longitude.
lat = 37.30939
lon = -8.57207

# Date range of the search; joined later as "start/end".
start = "2018-07-01"
end = "2018-09-01"
!pip install pystac-client geopandas stackstac h5py requests-toolbelt
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pystac_client
import stackstac
from rasterio.enums import Resampling
from shapely import Point
# Search the STAC catalog for scenes that cover the point of interest.
catalog = pystac_client.Client.open(STAC_API)
# A tiny box around the point is used as the spatial filter.
point_bbox = (lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5)
search = catalog.search(
    collections=[COLLECTION],
    datetime=f"{start}/{end}",
    bbox=point_bbox,
    max_items=100,
    query={"eo:cloud_cover": {"lt": 80}},
)
all_items = search.item_collection()

# Keep only the first item returned for each calendar date.
items = []
dates = []
for candidate in all_items:
    acquired_on = candidate.datetime.date()
    if acquired_on in dates:
        continue
    items.append(candidate)
    dates.append(acquired_on)

print(f"Found {len(items)} items")
# Reproject the point of interest from EPSG:4326 into the CRS of the first item.
epsg = items[0].properties["proj:epsg"]
poi_lonlat = gpd.GeoDataFrame(
    pd.DataFrame(),
    crs="EPSG:4326",
    geometry=[Point(lon, lat)],
)
poidf = poi_lonlat.to_crs(epsg)
coords = poidf.iloc[0].geometry.coords[0]

# Square window of `size` pixels at resolution `gsd`, centred on the point.
# NOTE(review): `gsd` is presumably in metres (projected CRS units) - confirm.
size = 256
gsd = 10
half_extent = (size * gsd) // 2
center_x, center_y = coords[0], coords[1]
bounds = (
    center_x - half_extent,
    center_y - half_extent,
    center_x + half_extent,
    center_y + half_extent,
)

# Build the lazy stack for the four requested bands and load it right away.
stack = stackstac.stack(
    items,
    assets=["blue", "green", "red", "nir"],
    bounds=bounds,
    snap_bounds=False,
    epsg=epsg,
    resolution=gsd,
    resampling=Resampling.nearest,
    dtype="float64",
    rescale=False,
    fill_value=0,
).compute()

# One RGB panel per time step, six panels per row.
rgb_view = stack.sel(band=["red", "green", "blue"])
rgb_view.plot.imshow(row="time", rgb="band", vmin=0, vmax=2000, col_wrap=6)
plt.show()
Found 12 items
# One integer class label per stacked image, indexed by position in the stack.
# NOTE(review): the meaning of classes 0 / 1 / 2 is not stated here - confirm.
labels = np.array(
    [0, 1, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
)

# Stack positions written to the training file ...
train_idxs = [0, 1, 3, 4, 7, 8, 9]
# ... and positions held out for inference.
test_idxs = [2, 5, 6, 10, 11]
from pathlib import Path
import h5py
import numpy as np
from tqdm import tqdm
# Describe a single training record as a NumPy structured dtype. The band
# count and pixel array shape are taken from the first image of the stack.
num_samples = len(train_idxs)
first_image = stack[0]
num_bands = len(first_image.band.values)
pixel_shape = first_image.values.shape
record_fields = [
    ("bands", "S10", (num_bands,)),  # Adjust '10' if your band names are longer
    ("gsd", "float32"),
    ("pixels", "float32", pixel_shape),
    ("platform", "S20"),  # Adjust '20' if your platform names are longer
    ("point", "float32", (2,)),
    ("timestamp", "int64"),
    ("label", "int64"),
]
dt = np.dtype(record_fields)

# Write one record per training index into a gzip-compressed dataset "data".
output_file = Path("forest_fire_train_data.h5")
with h5py.File(output_file, "w") as f:
    dataset = f.create_dataset(
        "data",
        (num_samples,),
        dtype=dt,
        compression="gzip",
        compression_opts=1,
        chunks=(min(1, num_samples),),
    )
    for row, sample_idx in enumerate(tqdm(train_idxs)):
        image = stack[sample_idx]
        band_names = [name.encode("ascii", "ignore") for name in image.band.values]
        # Acquisition time as whole seconds since the Unix epoch.
        epoch_seconds = image.time.values.astype("datetime64[s]").astype(np.int64).item()
        dataset[row] = (
            band_names,
            image.gsd.values.item(),
            image.values.astype("float32"),
            COLLECTION.encode("ascii", "ignore"),
            [lat, lon],
            epoch_seconds,
            labels[sample_idx].item(),
        )
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 49.55it/s]
import requests
import os
from pprint import pprint
from requests_toolbelt.multipart.encoder import MultipartEncoder
def upload_file(file_path):
    """Upload a local file to the ``/files`` endpoint as multipart form data.

    Args:
        file_path: Path (``str`` or ``os.PathLike``) of the file to upload.

    Returns:
        The decoded JSON body of the response when the server answers with
        HTTP 200; otherwise ``None``, after printing the status code and the
        response text.
    """
    file_name = os.path.basename(file_path)
    # The encoder reads from the handle lazily while ``requests.post`` sends
    # the body, so the ``with`` block has to enclose the POST. Using a context
    # manager guarantees the handle is closed even if the request raises.
    with open(file_path, "rb") as file_handle:
        encoder = MultipartEncoder(
            fields={"file": (file_name, file_handle, "application/octet-stream")}
        )
        response = requests.post(
            f"{API_URL}/files",
            data=encoder,
            headers={"Content-Type": encoder.content_type},
        )
    if response.status_code == 200:
        print("File uploaded successfully!")
        return response.json()
    else:
        print(f"Failed to upload file. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
%%time
# Upload the HDF5 training file; the result is the parsed JSON response,
# or None when the upload did not return HTTP 200.
file_object = upload_file(output_file)
pprint(file_object, sort_dicts=False)
File uploaded successfully!
{'id': 'file-lw3zjxrg',
'bytes': 3669052,
'created_at': 1722451315,
'filename': 'forest_fire_train_data.h5'}
CPU times: user 16.5 ms, sys: 7.79 ms, total: 24.3 ms
Wall time: 6.66 s
def create_training_job(task, training_file, validation_file=None):
    """Submit a training job to the ``/training/jobs`` endpoint.

    Args:
        task: Task name sent in the payload (e.g. ``"classification"``).
        training_file: Identifier of the uploaded training file.
        validation_file: Optional identifier of an uploaded validation file;
            sent as ``null`` when omitted.

    Returns:
        The decoded JSON body of the response when the server answers with
        HTTP 200; otherwise ``None``, after printing the status code and the
        response text.
    """
    payload = {
        "task": task,
        "training_file": training_file,
        "validation_file": validation_file,
    }
    response = requests.post(
        f"{API_URL}/training/jobs",
        json=payload,
        headers={"Content-Type": "application/json"},
    )
    if response.status_code == 200:
        print("Training job submitted successfully!")
        return response.json()
    else:
        # Message typo fixed: "sumbit" -> "submit".
        print(f"Failed to submit training job. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
# Submit a "classification" job for the uploaded file; no validation file is passed.
# NOTE(review): file_object is None when the upload failed, in which case the
# ["id"] lookup below raises TypeError.
training_job = create_training_job("classification", file_object["id"])
pprint(training_job, sort_dicts=False)
Training job submitted successfully!
{'task': 'classification',
'training_file': 'file-lw3zjxrg',
'validation_file': None,
'hyperparameters': None,
'id': 'trainingjob-5ms9ja6h',
'created_at': 1722451316,
'status': 'initializing',
'error': None,
'trained_model': None,
'finished_at': None}
def retrieve_training_job(job):
    """Fetch the record of a training job from ``/training/jobs/{job}``.

    Args:
        job: Identifier of the training job.

    Returns:
        The decoded JSON body on HTTP 200; otherwise ``None``, after printing
        the status code and the response text.
    """
    job_url = f"{API_URL}/training/jobs/{job}"
    response = requests.get(job_url)
    if response.status_code != 200:
        print(f"Failed to retrieve training job. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
    return response.json()
%%time
import time
# Poll the training job every 3 seconds until it reaches a terminal status,
# printing each observed status, then print the final job record.
job_id = training_job["id"]
status = None
while True:
    latest_job = retrieve_training_job(job_id)
    if latest_job is None:
        # retrieve_training_job returns None on a non-200 response. Stop with
        # a clear error instead of overwriting training_job with None and
        # failing on the ["status"] lookup.
        raise RuntimeError(f"Could not retrieve training job {job_id}; polling aborted.")
    training_job = latest_job
    status = training_job["status"]
    print(status)
    if status in ("succeeded", "failed", "cancelled"):
        # Terminal status reached: leave without the extra 3 second wait.
        break
    time.sleep(3)
pprint(training_job)
initializing
initializing
initializing
running
succeeded
{'created_at': 1722451316,
'error': None,
'finished_at': 1722451329,
'hyperparameters': None,
'id': 'trainingjob-5ms9ja6h',
'status': 'succeeded',
'task': 'classification',
'trained_model': 'model:classification-lljzufea',
'training_file': 'file-lw3zjxrg',
'validation_file': None}
CPU times: user 26.1 ms, sys: 4.69 ms, total: 30.8 ms
Wall time: 18.2 s
# Build one JSON-serialisable inference record per held-out image, using the
# same field names as the training records (without the label).
X_test = [
    {
        "bands": image.band.values.tolist(),
        "gsd": image.gsd.values.item(),
        "pixels": image.values.tolist(),
        "platform": COLLECTION,
        "point": [lat, lon],
        # Acquisition time as whole seconds since the Unix epoch.
        "timestamp": image.time.values.astype('datetime64[s]').astype(np.int64).item(),
    }
    for image in stack[test_idxs]
]
def run_trained_model_inference(model, images):
    """Send images to the ``/inference`` endpoint for a trained model.

    Args:
        model: Identifier of the trained model.
        images: List of image records to send in the request payload.

    Returns:
        The decoded JSON body on HTTP 200; otherwise ``None``, after printing
        the status code and the response text.
    """
    request_body = {"model": model, "images": images}
    json_headers = {'Content-Type': "application/json"}
    response = requests.post(f"{API_URL}/inference", json=request_body, headers=json_headers)
    if response.status_code != 200:
        print(f"Failed to run inference. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
    return response.json()
%%time
# Run inference on the held-out records with the model named in the last
# polled job record; the result is None when the request did not return HTTP 200.
y_pred = run_trained_model_inference(training_job["trained_model"], X_test)
pprint(y_pred)
{'labels': [0, 1, 1, 2, 2]}
CPU times: user 256 ms, sys: 23.9 ms, total: 279 ms
Wall time: 11.3 s
# Compare the predicted labels with the known labels of the held-out images.
y_true = labels[test_idxs]
# Rebinds y_pred from the response dict to an array of its "labels" entry.
y_pred = np.array(y_pred["labels"])
hits = y_true == y_pred
match = np.sum(hits)
print(f"Matched {match} out of {len(y_true)} correctly")
Matched 5 out of 5 correctly