This guide shows how to train the Clay Foundation Model for a classification task, and how to run inference with the trained model, using our API.

Open In Colab
# Base URL of the Bluesight API; every request in this guide is sent relative to it.
API_URL = "https://api.bluesight.ai"
# Default headers for JSON requests.
HEADERS = {"Content-Type": "application/json"}

# STAC catalog endpoint and the collection that is searched for imagery.
STAC_API = "https://earth-search.aws.element84.com/v1"
COLLECTION = "sentinel-2-l2a"

# Point of interest as latitude / longitude.
lat = 37.30939
lon = -8.57207

# Date window for the imagery search, as "YYYY-MM-DD" strings.
start = "2018-07-01"
end = "2018-09-01"
!pip install pystac-client geopandas stackstac h5py requests-toolbelt

Downloading Data

This step is copied directly from the data preparation step of the Clay example notebook. You probably don’t need it if you have your own data; in that case, jump to the next section.

If you need more detail on what is going on here, see the original source.

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pystac_client
import stackstac
from rasterio.enums import Resampling
from shapely import Point

# Search the STAC catalog for scenes that cover the point of interest within
# the date window and report less than 80 % cloud cover.
catalog = pystac_client.Client.open(STAC_API)
search = catalog.search(
    collections=[COLLECTION],
    datetime=f"{start}/{end}",
    # A tiny box around the point, in (min lon, min lat, max lon, max lat) order.
    bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),
    max_items=100,
    query={"eo:cloud_cover": {"lt": 80}},
)

all_items = search.item_collection()

# Keep only the first scene returned for each calendar date. A dict preserves
# insertion order, so `items` and `dates` stay aligned and in search order.
first_item_by_date = {}
for item in all_items:
    first_item_by_date.setdefault(item.datetime.date(), item)

dates = list(first_item_by_date)
items = list(first_item_by_date.values())

print(f"Found {len(items)} items")

# Use the projection of the first scene for everything that follows.
epsg = items[0].properties["proj:epsg"]

# Reproject the point of interest from EPSG:4326 into that projection.
poi_wgs84 = gpd.GeoDataFrame(
    pd.DataFrame(),
    crs="EPSG:4326",
    geometry=[Point(lon, lat)],
)
poidf = poi_wgs84.to_crs(epsg)

coords = poidf.iloc[0].geometry.coords[0]

size = 256  # side length of the square window, in pixels
gsd = 10  # resolution passed to stackstac (presumably metres per pixel; confirm)

# Square window of `size * gsd` projected units centred on the point,
# ordered as (min x, min y, max x, max y).
half_extent = (size * gsd) // 2
center_x, center_y = coords[0], coords[1]
bounds = (
    center_x - half_extent,
    center_y - half_extent,
    center_x + half_extent,
    center_y + half_extent,
)

# Options for assembling the selected scenes into one array: the four bands
# used in this guide, clipped to `bounds` at `gsd` resolution.
stack_options = dict(
    bounds=bounds,
    snap_bounds=False,
    epsg=epsg,
    resolution=gsd,
    dtype="float64",
    rescale=False,
    fill_value=0,
    assets=["blue", "green", "red", "nir"],
    resampling=Resampling.nearest,
)

# Build the stack and load it into memory right away.
stack = stackstac.stack(items, **stack_options).compute()

# Plot every time step as an RGB image for visual inspection.
rgb = stack.sel(band=["red", "green", "blue"])
rgb.plot.imshow(row="time", rgb="band", vmin=0, vmax=2000, col_wrap=6)
plt.show()
Found 12 items

# One class label per time step in `stack`.
# NOTE(review): the meaning of classes 0 / 1 / 2 is not stated here; confirm
# against the original Clay example before reusing these labels.
labels = np.array([0, 1, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])

# Time steps used for training; every remaining time step is held out for testing.
train_idxs = [0, 1, 3, 4, 7, 8, 9]
test_idxs = [i for i in range(len(labels)) if i not in train_idxs]

Preparing Data

Now we need to convert the data into the format accepted by our API. Essentially, this is the raw pixel data from all bands, plus some metadata.

from pathlib import Path

import h5py
import numpy as np
from tqdm import tqdm

# Size the record from the training split and from the first image in the stack.
num_samples = len(train_idxs)
first_image = stack[0]
num_bands = len(first_image.band.values)
pixel_shape = first_image.values.shape

# Structured dtype describing one training record as it is stored in the HDF5 file.
dt = np.dtype(
    [
        ("bands", "S10", (num_bands,)),  # Adjust '10' if your band names are longer
        ("gsd", "float32"),
        ("pixels", "float32", pixel_shape),
        ("platform", "S20"),  # Adjust '20' if your platform names are longer
        ("point", "float32", (2,)),
        ("timestamp", "int64"),
        ("label", "int64"),
    ]
)


# Write one structured record per training image into a gzip-compressed HDF5 dataset.
output_file = Path("forest_fire_train_data.h5")
platform_name = COLLECTION.encode("ascii", "ignore")

with h5py.File(output_file, "w") as h5_file:
    dataset = h5_file.create_dataset(
        "data",
        shape=(num_samples,),
        dtype=dt,
        compression="gzip",
        compression_opts=1,
        chunks=(min(1, num_samples),),
    )

    for row, idx in enumerate(tqdm(train_idxs)):
        image = stack[idx]

        band_names = [name.encode("ascii", "ignore") for name in image.band.values]
        # Acquisition time as whole seconds since the Unix epoch.
        unix_seconds = image.time.values.astype("datetime64[s]").astype(np.int64).item()

        # Field order must match the order of the fields in `dt`.
        dataset[row] = (
            band_names,
            image.gsd.values.item(),
            image.values.astype("float32"),
            platform_name,
            [lat, lon],
            unix_seconds,
            labels[idx].item(),
        )
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 49.55it/s]

Uploading Data to Bluesight

import requests
import os
from pprint import pprint
from requests_toolbelt.multipart.encoder import MultipartEncoder

def upload_file(file_path):
    """Upload a local file to the API's ``/files`` endpoint.

    The file is streamed as a multipart form field named ``file``.

    Args:
        file_path: Path of the file to upload.

    Returns:
        The parsed JSON response on HTTP 200, otherwise ``None`` (the status
        code and response body are printed in that case).
    """
    file_name = os.path.basename(file_path)

    # Open the file in a context manager so the handle is closed once the
    # request has finished, including when the upload raises.
    with open(file_path, 'rb') as file_handle:
        encoder = MultipartEncoder(
            fields={'file': (file_name, file_handle, 'application/octet-stream')}
        )

        response = requests.post(
            f"{API_URL}/files",
            data=encoder,
            headers={'Content-Type': encoder.content_type}
        )

    if response.status_code == 200:
        print("File uploaded successfully!")
        return response.json()
    else:
        print(f"Failed to upload file. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
%%time
# Upload the training file. `file_object` is the parsed JSON response, or None if the upload failed.
file_object = upload_file(output_file)
pprint(file_object, sort_dicts=False)
File uploaded successfully!
{'id': 'file-lw3zjxrg',
 'bytes': 3669052,
 'created_at': 1722451315,
 'filename': 'forest_fire_train_data.h5'}
CPU times: user 16.5 ms, sys: 7.79 ms, total: 24.3 ms
Wall time: 6.66 s

Starting Training

def create_training_job(task, training_file, validation_file=None):
    """Submit a training job to the API's ``/training/jobs`` endpoint.

    Args:
        task: Name of the task to train for, e.g. ``"classification"``.
        training_file: Identifier of the training file, e.g. the ``id``
            returned when the file was uploaded.
        validation_file: Optional identifier of a validation file. It is
            sent as ``null`` when omitted.

    Returns:
        The parsed JSON response on HTTP 200, otherwise ``None`` (the status
        code and response body are printed in that case).
    """
    payload = {"task": task, "training_file": training_file, "validation_file": validation_file}

    response = requests.post(
        f"{API_URL}/training/jobs",
        json=payload,
        headers={'Content-Type': "application/json"}
    )

    if response.status_code == 200:
        print("Training job submitted successfully!")
        return response.json()
    else:
        print(f"Failed to submit training job. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None
# Submit a classification training job for the uploaded file; no validation file is passed.
training_job = create_training_job("classification", file_object["id"])
pprint(training_job, sort_dicts=False)
Training job submitted successfully!
{'task': 'classification',
 'training_file': 'file-lw3zjxrg',
 'validation_file': None,
 'hyperparameters': None,
 'id': 'trainingjob-5ms9ja6h',
 'created_at': 1722451316,
 'status': 'initializing',
 'error': None,
 'trained_model': None,
 'finished_at': None}
def retrieve_training_job(job):
    """Fetch the current state of a training job.

    Args:
        job: Identifier of the training job.

    Returns:
        The parsed JSON response on HTTP 200, otherwise ``None`` (the status
        code and response body are printed in that case).
    """
    job_url = f"{API_URL}/training/jobs/{job}"
    response = requests.get(job_url)

    # Report the failure and bail out early; the success path follows.
    if response.status_code != 200:
        print(f"Failed to retrieve training job. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None

    return response.json()
%%time
import time

# Poll the job every 3 seconds until it reaches a terminal status.
TERMINAL_STATUSES = ("succeeded", "failed", "cancelled")

# Keep the id separately so a failed poll cannot overwrite the job record.
job_id = training_job["id"]
status = None
while True:
    latest_job = retrieve_training_job(job_id)
    if latest_job is None:
        # retrieve_training_job has already printed the HTTP status and body.
        raise RuntimeError(f"Could not retrieve training job {job_id}")

    training_job = latest_job
    status = training_job["status"]
    print(status)

    # Stop immediately on a terminal status instead of sleeping once more.
    if status in TERMINAL_STATUSES:
        break
    time.sleep(3)
pprint(training_job)
initializing
initializing
initializing
running
succeeded
{'created_at': 1722451316,
 'error': None,
 'finished_at': 1722451329,
 'hyperparameters': None,
 'id': 'trainingjob-5ms9ja6h',
 'status': 'succeeded',
 'task': 'classification',
 'trained_model': 'model:classification-lljzufea',
 'training_file': 'file-lw3zjxrg',
 'validation_file': None}
CPU times: user 26.1 ms, sys: 4.69 ms, total: 30.8 ms
Wall time: 18.2 s

Running Inference

For inference we only need the images and their metadata; no labels are sent.

# One JSON-serialisable record per held-out image. The fields mirror the
# training records, minus the label.
X_test = [
    {
        "bands": image.band.values.tolist(),
        "gsd": image.gsd.values.item(),
        "pixels": image.values.tolist(),
        "platform": COLLECTION,
        "point": [lat, lon],
        "timestamp": image.time.values.astype('datetime64[s]').astype(np.int64).item(),
    }
    for image in stack[test_idxs]
]
def run_trained_model_inference(model, images):
    """Run a trained model on a batch of images via the ``/inference`` endpoint.

    Args:
        model: Identifier of the trained model.
        images: List of JSON-serialisable image records.

    Returns:
        The parsed JSON response on HTTP 200, otherwise ``None`` (the status
        code and response body are printed in that case).
    """
    response = requests.post(
        f"{API_URL}/inference",
        json={"model": model, "images": images},
        headers={'Content-Type': "application/json"}
    )

    # Report the failure and bail out early; the success path follows.
    if response.status_code != 200:
        print(f"Failed to run inference. Status code: {response.status_code}")
        print(f"Response: {response.text}")
        return None

    return response.json()
%%time
# Run the trained model on the held-out images; the result is None if the request fails.
y_pred = run_trained_model_inference(training_job["trained_model"], X_test)
pprint(y_pred)
{'labels': [0, 1, 1, 2, 2]}
CPU times: user 256 ms, sys: 23.9 ms, total: 279 ms
Wall time: 11.3 s

Evaluating results

# Compare the predicted labels with the ground truth of the held-out images.
y_true = labels[test_idxs]
y_pred = np.array(y_pred["labels"])

# Element-wise mask of correct predictions; its sum is the number of hits.
is_correct = y_true == y_pred
match = np.sum(is_correct)
print(f"Matched {match} out of {len(y_true)} correctly")
Matched 5 out of 5 correctly