Guides
Train a forest fire classifier
Start training satellite image classification models using the Bluesight API
This guide shows how to train the Clay Foundation Model for a classification task and run inference with it using our API.
Copy
# Base URL of the Bluesight REST API; the endpoint URLs below are built from it.
API_URL = "https://api.bluesight.ai"
# Default JSON request headers.
HEADERS = {"Content-Type": "application/json"}

# STAC catalog and collection that the imagery search runs against.
STAC_API = "https://earth-search.aws.element84.com/v1"
COLLECTION = "sentinel-2-l2a"

# Point of interest as latitude / longitude in degrees (EPSG:4326).
lat = 37.30939
lon = -8.57207

# Inclusive date range of the imagery search.
start = "2018-07-01"
end = "2018-09-01"
Copy
!pip install pystac-client geopandas stackstac h5py requests-toolbelt
Downloading Data
This step is copied verbatim from the data preparation step of the Clay example notebook, so you probably don’t need it if you have your own data. If that’s the case, jump to the next section.
For a more detailed explanation of what is going on, see the original source.
Copy
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pystac_client
import stackstac
from rasterio.enums import Resampling
from shapely import Point
# Search the STAC catalog for scenes covering the point of interest within
# the date range, skipping scenes with 80% cloud cover or more.
catalog = pystac_client.Client.open(STAC_API)
search = catalog.search(
    collections=[COLLECTION],
    datetime=f"{start}/{end}",
    bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),
    max_items=100,
    query={"eo:cloud_cover": {"lt": 80}},
)
all_items = search.item_collection()

# Keep only the first item returned for each calendar date.
items = []
dates = []
for item in all_items:
    item_date = item.datetime.date()
    if item_date not in dates:
        items.append(item)
        dates.append(item_date)
print(f"Found {len(items)} items")

# Fail with a clear message instead of an IndexError on items[0] below.
if not items:
    raise RuntimeError(
        f"No {COLLECTION} items found at ({lat}, {lon}) between {start} and {end}"
    )

# Reproject the point of interest into the projection of the first item.
epsg = items[0].properties["proj:epsg"]
poidf = gpd.GeoDataFrame(
    pd.DataFrame(),
    crs="EPSG:4326",
    geometry=[Point(lon, lat)],
).to_crs(epsg)
coords = poidf.iloc[0].geometry.coords[0]

# Square window of `size` pixels at `gsd` resolution, centred on the point.
size = 256
gsd = 10
bounds = (
    coords[0] - (size * gsd) // 2,
    coords[1] - (size * gsd) // 2,
    coords[0] + (size * gsd) // 2,
    coords[1] + (size * gsd) // 2,
)

# Build the data cube (one entry per item along the first axis) and load it.
stack = stackstac.stack(
    items,
    bounds=bounds,
    snap_bounds=False,
    epsg=epsg,
    resolution=gsd,
    dtype="float64",
    rescale=False,
    fill_value=0,
    assets=["blue", "green", "red", "nir"],
    resampling=Resampling.nearest,
)
stack = stack.compute()

# Plot an RGB preview of every date.
stack.sel(band=["red", "green", "blue"]).plot.imshow(row="time", rgb="band", vmin=0, vmax=2000, col_wrap=6)
plt.show()
Copy
Found 12 items
Copy
# One class label per downloaded image, indexed the same way as `stack`.
# NOTE(review): the meaning of classes 0/1/2 is not stated in this guide;
# presumably it follows the upstream Clay example — confirm.
labels = np.array(
    [0, 1, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
)

# Manual split of the image indices into a training and a test subset.
train_idxs = [0, 1, 3, 4, 7, 8, 9]
test_idxs = [2, 5, 6, 10, 11]
Preparing Data
Now we need to convert the data into the format accepted by our API. Essentially, it is the raw pixel data from all bands plus some metadata.
Copy
from pathlib import Path
import h5py
import numpy as np
from tqdm import tqdm
# Dataset dimensions; band count and pixel shape are read from the first image.
num_samples = len(train_idxs)
num_bands = len(stack[0].band.values)
pixel_shape = stack[0].values.shape

# Structured record layout: one HDF5 row per training image.
dt = np.dtype(
    [
        ("bands", "S10", (num_bands,)),  # Adjust '10' if your band names are longer
        ("gsd", "float32"),
        ("pixels", "float32", pixel_shape),
        ("platform", "S20"),  # Adjust '20' if your platform names are longer
        ("point", "float32", (2,)),
        ("timestamp", "int64"),
        ("label", "int64"),
    ]
)

output_file = Path("forest_fire_train_data.h5")
with h5py.File(output_file, "w") as f:
    dataset = f.create_dataset(
        "data",
        shape=(num_samples,),
        dtype=dt,
        compression="gzip",
        compression_opts=1,
        chunks=(min(1, num_samples),),
    )
    # The platform field is the same for every record, so encode it once.
    platform = COLLECTION.encode("ascii", "ignore")
    for i, idx in enumerate(tqdm(train_idxs)):
        image = stack[idx]
        band_names = [band.encode("ascii", "ignore") for band in image.band.values]
        sample_gsd = image.gsd.values.item()
        pixels = image.values.astype("float32")
        # Acquisition time converted to an integer count of seconds.
        timestamp = image.time.values.astype("datetime64[s]").astype(np.int64).item()
        label = labels[idx].item()
        # Field order must match the order declared in `dt`.
        dataset[i] = (band_names, sample_gsd, pixels, platform, [lat, lon], timestamp, label)
Copy
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 49.55it/s]
Uploading Data to Bluesight
Copy
import requests
import os
from pprint import pprint
from requests_toolbelt.multipart.encoder import MultipartEncoder
def upload_file(file_path):
    """Upload a local file to the Bluesight ``/files`` endpoint.

    The file is sent as multipart form data under the ``file`` field.

    Args:
        file_path: Path of the file to upload.

    Returns:
        The decoded JSON response on HTTP 200; otherwise ``None``, after
        printing the status code and response body.
    """
    file_name = os.path.basename(file_path)
    # Open the file in a context manager so the handle is always closed,
    # even if the request raises. The original left it open.
    with open(file_path, "rb") as file_handle:
        encoder = MultipartEncoder(
            fields={"file": (file_name, file_handle, "application/octet-stream")}
        )
        response = requests.post(
            f"{API_URL}/files",
            data=encoder,
            headers={"Content-Type": encoder.content_type},
        )
    if response.status_code == 200:
        print("File uploaded successfully!")
        return response.json()
    else:
        print(f"Failed to upload file. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
Copy
%%time
# Upload the training file written earlier; `file_object` is None if the upload failed.
file_object = upload_file(output_file)
pprint(file_object, sort_dicts=False)
Copy
File uploaded successfully!
{'id': 'file-lw3zjxrg',
'bytes': 3669052,
'created_at': 1722451315,
'filename': 'forest_fire_train_data.h5'}
CPU times: user 16.5 ms, sys: 7.79 ms, total: 24.3 ms
Wall time: 6.66 s
Starting Training
Copy
def create_training_job(task, training_file, validation_file=None):
    """Submit a training job to the Bluesight ``/training/jobs`` endpoint.

    Args:
        task: Task name sent to the API (this guide uses ``"classification"``).
        training_file: Id of a previously uploaded training file.
        validation_file: Optional id of an uploaded validation file.

    Returns:
        The decoded JSON job description on HTTP 200; otherwise ``None``,
        after printing the status code and response body.
    """
    payload = {"task": task, "training_file": training_file, "validation_file": validation_file}
    response = requests.post(
        f"{API_URL}/training/jobs",
        json=payload,
        headers={"Content-Type": "application/json"},
    )
    if response.status_code == 200:
        print("Training job submitted successfully!")
        return response.json()
    else:
        # Fixed the typo "sumbit" in the user-facing message.
        print(f"Failed to submit training job. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
Copy
# Start a classification job on the uploaded file, referenced by its id.
training_job = create_training_job("classification", file_object["id"])
pprint(training_job, sort_dicts=False)
Copy
Training job submitted successfully!
{'task': 'classification',
'training_file': 'file-lw3zjxrg',
'validation_file': None,
'hyperparameters': None,
'id': 'trainingjob-5ms9ja6h',
'created_at': 1722451316,
'status': 'initializing',
'error': None,
'trained_model': None,
'finished_at': None}
Copy
def retrieve_training_job(job):
    """Fetch the current state of a training job.

    Args:
        job: Id of the training job to look up.

    Returns:
        The decoded JSON job description on HTTP 200; otherwise ``None``,
        after printing the status code and response body.
    """
    url = f"{API_URL}/training/jobs/{job}"
    response = requests.get(url)
    # Handle the failure case first so the success path stays unindented.
    if response.status_code != 200:
        print(f"Failed to retrieve training job. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
    return response.json()
Copy
%%time
import time
# Poll the job every 3 seconds until it reaches a terminal status.
status = None
while True:
    latest_job = retrieve_training_job(training_job["id"])
    if latest_job is None:
        # Don't overwrite `training_job` with None: that would lose the job id
        # and crash with an opaque TypeError on the status lookup below.
        raise RuntimeError(f"Could not retrieve training job {training_job['id']}")
    training_job = latest_job
    status = training_job["status"]
    print(status)
    if status in ("succeeded", "failed", "cancelled"):
        break  # finished: skip the pointless final sleep
    time.sleep(3)
pprint(training_job)
Copy
initializing
initializing
initializing
running
succeeded
{'created_at': 1722451316,
'error': None,
'finished_at': 1722451329,
'hyperparameters': None,
'id': 'trainingjob-5ms9ja6h',
'status': 'succeeded',
'task': 'classification',
'trained_model': 'model:classification-lljzufea',
'training_file': 'file-lw3zjxrg',
'validation_file': None}
CPU times: user 26.1 ms, sys: 4.69 ms, total: 30.8 ms
Wall time: 18.2 s
Running Inference
For inference we only need the images; no labels are required.
Copy
# Build one JSON-serializable record per test image; no label is included.
X_test = []
for image in stack[test_idxs]:
    # Acquisition time converted to an integer count of seconds.
    timestamp = image.time.values.astype("datetime64[s]").astype(np.int64).item()
    record = {
        "bands": image.band.values.tolist(),
        "gsd": image.gsd.values.item(),
        "pixels": image.values.tolist(),
        "platform": COLLECTION,
        "point": [lat, lon],
        "timestamp": timestamp,
    }
    X_test.append(record)
Copy
def run_trained_model_inference(model, images):
    """Run a trained model on a batch of images via the ``/inference`` endpoint.

    Args:
        model: Identifier of the trained model to use.
        images: List of image records to send in the request.

    Returns:
        The decoded JSON response on HTTP 200; otherwise ``None``, after
        printing the status code and response body.
    """
    response = requests.post(
        f"{API_URL}/inference",
        json={"model": model, "images": images},
        headers={"Content-Type": "application/json"},
    )
    # Handle the failure case first so the success path stays unindented.
    if response.status_code != 200:
        print(f"Failed to run inference. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
    return response.json()
Copy
%%time
# Classify the test images with the model produced by the training job.
y_pred = run_trained_model_inference(training_job["trained_model"], X_test)
pprint(y_pred)
Copy
{'labels': [0, 1, 1, 2, 2]}
CPU times: user 256 ms, sys: 23.9 ms, total: 279 ms
Wall time: 11.3 s
Evaluating results
Copy
# Ground-truth labels of the test images.
y_true = labels[test_idxs]
# Replace the raw response dict with an array of its "labels" entries.
y_pred = np.array(y_pred["labels"])
Copy
# Count the positions where the prediction equals the ground truth.
match = np.sum(y_true == y_pred)
print(f"Matched {match} out of {len(y_true)} correctly")
Copy
Matched 5 out of 5 correctly
Assistant
Responses are generated using AI and may contain mistakes.