diff --git a/README.md b/README.md index 2b26870..3402313 100644 --- a/README.md +++ b/README.md @@ -1 +1,156 @@ -# pyhuntressphish \ No newline at end of file +# pyhuntress - An API library for Huntress SIEM and Huntress Managed SAT, written in Python + +pyHuntress is a full-featured, type annotated API client written in Python for the Huntress APIs. + +This library has been developed with the intention of making the Huntress APIs simple and accessible to non-coders while allowing experienced coders to utilize all features the API has to offer without the boilerplate. + +pyHuntress currently supports both Huntress SIEM and Huntress Managed SAT products. + +Features: +========= +- **100% API Coverage.** All endpoints and response models. +- **Non-coder friendly.** 100% annotated for full IDE auto-completion. Clients handle requests and authentication - just plug the right details in and go! +- **Fully annotated.** This library has a strong focus on type safety and type hinting. Models are declared and parsed using [Pydantic](https://github.com/pydantic/pydantic) + +pyHuntress is currently in **development**. + +Known Issues: +============= +- As this project is still a WIP, documentation or code commentary may not always align. +- Huntress Managed SAT is not built +- Pagination does not work + +Road Map: +============= +- Add Huntress Managed SAT Report + +How-to: +============= +- [Install](#install) +- [Initializing the API Clients](#initializing-the-api-clients) + - [Huntress Managed SAT](#huntress-managed-sat) + - [Huntress SIEM](#huntress-siem) +- [Working with Endpoints](#working-with-endpoints) + - [Get many](#get-many) + - [Get one](#get-one) + - [Get with params](#get-with-params) +- [Pagination](#pagination) +- [Contributing](#contributing) +- [Supporting the project](#supporting-the-project) + +# Install +Open a terminal and run ```pip install pyhuntress``` + +# Initializing the API Clients + +### Huntress Managed SAT +```python +from pyhuntress import HuntressSATAPIClient + +# init client +sat_api_client = HuntressSATAPIClient( + mycurricula.com, + # your api public key, + # your api private key, +) +``` + +### Huntress SIEM +```python +from pyhuntress import HuntressSIEMAPIClient + +# init client +siem_api_client = HuntressSIEMAPIClient( + # huntress siem url + # your api public key, + # your api private key, +) +``` + + +# Working with Endpoints +Endpoints are 1:1 to what's available for both the Huntress Managed SAT and Huntress SIEM. + +For more information, check out the following resources: +- [Huntress Managed SAT REST API Docs](https://support.meetgradient.com/huntress-managed-sat) +- [Huntress SIEM REST API Docs](https://api.huntress.io/docs) + +### Get many +```python +### Managed SAT ### + +# sends GET request to /company/companies endpoint +companies = manage_api_client.company.companies.get() + +### SIEM ### + +# sends GET request to /agents endpoint +agents = siem_api_client.agents.get() +``` + +### Get one +```python +### Managed SAT ### + +# sends GET request to /company/companies/{id} endpoint +company = sat_api_client.company.companies.id(250).get() + +### SIEM ### + +# sends GET request to /agents/{id} endpoint +agent = siem_api_client.agents.id(250).get() +``` + +### Get with params +```python +### Managed SAT ### + +# sends GET request to /company/companies with a conditions query string +conditional_company = sat_api_client.company.companies.get(params={ + 'conditions': 'company/id=250' +}) + +### SIEM ### +# sends GET request to /agents endpoint with a condition query string +conditional_agent = siem_api_client.clients.get(params={ + 'platform': 'windows' +}) +``` + +# Pagination +The Huntress SIEM API paginates data for performance reasons through the ```page``` and ```limit``` query parameters. ```limit``` is limited to a maximum of 500. + +To make working with paginated data easy, Endpoints that implement a GET response with an array also supply a ```paginated()``` method. Under the hood this wraps a GET request, but does a lot of neat stuff to make working with pages easier. + +Working with pagination +```python +# initialize a PaginatedResponse instance for /agents, starting on page 1 with a pageSize of 100 +paginated_agents = siem_api_client.agents.paginated(1,100) + +# access the data from the current page using the .data field +page_one_data = paginated_agents.data + +# if there's a next page, retrieve the next page worth of data +paginated_agents.get_next_page() + +# if there's a previous page, retrieve the previous page worth of data +paginated_agents.get_previous_page() + +# iterate over all companies on the current page +for agent in paginated_agents: + # ... do things ... + +# iterate over all companies in all pages +# this works by yielding every item on the page, then fetching the next page and continuing until there's no data left +for agent in paginated_agents.all(): + # ... do things ... +``` + +# Contributing +Contributions to the project are welcome. If you find any issues or have suggestions for improvement, please feel free to open an issue or submit a pull request. + +# Supporting the project +:heart: + +# Inspiration and Stolen Code +The premise behind this came from the [pyConnectWise](https://github.com/HealthITAU/pyconnectwise) package and I stole **most** of the code and adapted it to the Huntress API endpoints. \ No newline at end of file diff --git a/src/pyhuntress/__init__.py b/src/pyhuntress/__init__.py new file mode 100644 index 0000000..3619c51 --- /dev/null +++ b/src/pyhuntress/__init__.py @@ -0,0 +1,5 @@ +from pyhuntress.clients.managedsat_client import HuntressSATAPIClient +from pyhuntress.clients.siem_client import HuntressSIEMAPIClient + +__all__ = ["HuntressSATAPIClient", "HuntressSIEMAPIClient"] +__version__ = "0.6.1" diff --git a/src/pyhuntress/clients/__init__.py b/src/pyhuntress/clients/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyhuntress/clients/huntress_client.py b/src/pyhuntress/clients/huntress_client.py new file mode 100644 index 0000000..329f537 --- /dev/null +++ b/src/pyhuntress/clients/huntress_client.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import contextlib +import json +import warnings +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, cast + +import requests +from requests import Response +from requests.exceptions import Timeout + +from pyhuntress.config import Config +from pyhuntress.exceptions import ( + AuthenticationFailedException, + ConflictException, + MalformedRequestException, + MethodNotAllowedException, + NotFoundException, + ObjectExistsError, + PermissionsFailedException, + ServerError, +) + +if TYPE_CHECKING: + from pyhuntress.types import RequestData, RequestMethod, RequestParams + + +class HuntressClient(ABC): + config: Config = Config() + + @abstractmethod + def _get_headers(self) -> dict[str, str]: + pass + + @abstractmethod + def _get_url(self) -> str: + pass + + def _make_request( # noqa: C901 + self, + method: RequestMethod, + url: str, + data: RequestData | None = None, + params: RequestParams | None = None, + headers: dict[str, str] | None = None, + retry_count: int = 0, + stream: bool = False, # noqa: FBT001, FBT002 + ) -> Response: + """ + Make an API request using the specified method, endpoint, data, and parameters. + This function isn't intended for use outside of this class. + Please use the available CRUD methods as intended. + + Args: + method (str): The HTTP method to use for the request (e.g., GET, POST, PUT, etc.). + endpoint (str, optional): The endpoint to make the request to. + data (dict, optional): The request data to send. + params (dict, optional): The query parameters to include in the request. + + Returns: + The Response object (see requests.Response). + + Raises: + Exception: If the request returns a status code >= 400. + """ + + if not headers: + headers = self._get_headers() + + # I don't like having to cast the params to a dict, but it's the only way I can get mypy to stop complaining about the type. + # TypedDicts aren't compatible with the dict type and this is the best way I can think of to handle this. + if data: + response = requests.request( + method, + url, + headers=headers, + json=data, + params=cast(dict[str, Any], params or {}), + stream=stream, + ) + else: + response = requests.request( + method, + url, + headers=headers, + params=cast(dict[str, Any], params or {}), + stream=stream, + ) + if not response.ok: + with contextlib.suppress(json.JSONDecodeError): + details: dict = response.json() + if response.status_code == 400: # noqa: SIM102 (Expecting to handle other codes in the future) + if details.get("code") == "InvalidObject": + errors = details.get("errors", []) + if len(errors) > 1: + warnings.warn( + "Found multiple errors - we may be masking some important error details. Please submit a Github issue with response.status_code and response.content so we can improve this error handling.", + stacklevel=1, + ) + for error in errors: + if error.get("code") == "ObjectExists": + error.pop("code") # Don't need code in message + raise ObjectExistsError(response, extra_message=json.dumps(error, indent=4)) + + if response.status_code == 400: + raise MalformedRequestException(response) + if response.status_code == 401: + raise AuthenticationFailedException(response) + if response.status_code == 403: + raise PermissionsFailedException(response) + if response.status_code == 404: + raise NotFoundException(response) + if response.status_code == 405: + raise MethodNotAllowedException(response) + if response.status_code == 409: + raise ConflictException(response) + if response.status_code == 500: + # if timeout is mentioned anywhere in the response then we'll retry. + # Ideally we'd return immediately on any non-timeout errors (since + # retries won't help much there), but err towards classifying too much + # as retries instead of too little. + if "timeout" in (response.text + response.reason).lower(): + if retry_count < self.config.max_retries: + retry_count += 1 + return self._make_request(method, url, data, params, headers, retry_count) + raise Timeout(response=response) + raise ServerError(response) + + return response diff --git a/src/pyhuntress/clients/managedsat_client.py b/src/pyhuntress/clients/managedsat_client.py new file mode 100644 index 0000000..c6594c6 --- /dev/null +++ b/src/pyhuntress/clients/managedsat_client.py @@ -0,0 +1,175 @@ +import base64 +import typing + +from pyhuntress.clients.huntress_client import HuntressClient +from pyhuntress.config import Config + +if typing.TYPE_CHECKING: + from pyhuntress.endpoints.managedsat.CompanyEndpoint import CompanyEndpoint + from pyhuntress.endpoints.managedsat.ConfigurationsEndpoint import ConfigurationsEndpoint + from pyhuntress.endpoints.managedsat.ExpenseEndpoint import ExpenseEndpoint + from pyhuntress.endpoints.managedsat.FinanceEndpoint import FinanceEndpoint + from pyhuntress.endpoints.managedsat.MarketingEndpoint import MarketingEndpoint + from pyhuntress.endpoints.managedsat.ProcurementEndpoint import ProcurementEndpoint + from pyhuntress.endpoints.managedsat.ProjectEndpoint import ProjectEndpoint + from pyhuntress.endpoints.managedsat.SalesEndpoint import SalesEndpoint + from pyhuntress.endpoints.managedsat.ScheduleEndpoint import ScheduleEndpoint + from pyhuntress.endpoints.managedsat.ServiceEndpoint import ServiceEndpoint + from pyhuntress.endpoints.managedsat.SystemEndpoint import SystemEndpoint + from pyhuntress.endpoints.managedsat.TimeEndpoint import TimeEndpoint + + +class ManagedSATCodebaseError(Exception): + def __init__(self) -> None: + super().__init__("Could not retrieve codebase from API.") + + +class HuntressSATAPIClient(HuntressClient): + """ + Huntress Managed SAT API client. Handles the connection to the Huntress Managed SAT API + and the configuration of all the available endpoints. + """ + + def __init__( + self, + managedsat_url: str, + public_key: str, + private_key: str, + ) -> None: + """ + Initializes the client with the given credentials and optionally a specific codebase. + If no codebase is given, it tries to get it from the API. + + Parameters: + managedsat_url (str): URL of the Huntress Managed SAT instance. + public_key (str): Your Huntress Managed SAT API Public key. + private_key (str): Your Huntress Managed SAT API Private key. + """ + self.managedsat_url: str = managedsat_url + self.public_key: str = public_key + self.private_key: str = private_key + + # Initializing endpoints + @property + def company(self) -> "CompanyEndpoint": + from pyhuntress.endpoints.managedsat.CompanyEndpoint import CompanyEndpoint + + return CompanyEndpoint(self) + + @property + def configurations(self) -> "ConfigurationsEndpoint": + from pyhuntress.endpoints.managedsat.ConfigurationsEndpoint import ConfigurationsEndpoint + + return ConfigurationsEndpoint(self) + + @property + def expense(self) -> "ExpenseEndpoint": + from pyhuntress.endpoints.managedsat.ExpenseEndpoint import ExpenseEndpoint + + return ExpenseEndpoint(self) + + @property + def finance(self) -> "FinanceEndpoint": + from pyhuntress.endpoints.managedsat.FinanceEndpoint import FinanceEndpoint + + return FinanceEndpoint(self) + + @property + def marketing(self) -> "MarketingEndpoint": + from pyhuntress.endpoints.managedsat.MarketingEndpoint import MarketingEndpoint + + return MarketingEndpoint(self) + + @property + def procurement(self) -> "ProcurementEndpoint": + from pyhuntress.endpoints.managedsat.ProcurementEndpoint import ProcurementEndpoint + + return ProcurementEndpoint(self) + + @property + def project(self) -> "ProjectEndpoint": + from pyhuntress.endpoints.managedsat.ProjectEndpoint import ProjectEndpoint + + return ProjectEndpoint(self) + + @property + def sales(self) -> "SalesEndpoint": + from pyhuntress.endpoints.managedsat.SalesEndpoint import SalesEndpoint + + return SalesEndpoint(self) + + @property + def schedule(self) -> "ScheduleEndpoint": + from pyhuntress.endpoints.managedsat.ScheduleEndpoint import ScheduleEndpoint + + return ScheduleEndpoint(self) + + @property + def service(self) -> "ServiceEndpoint": + from pyhuntress.endpoints.managedsat.ServiceEndpoint import ServiceEndpoint + + return ServiceEndpoint(self) + + @property + def system(self) -> "SystemEndpoint": + from pyhuntress.endpoints.managedsat.SystemEndpoint import SystemEndpoint + + return SystemEndpoint(self) + + @property + def time(self) -> "TimeEndpoint": + from pyhuntress.endpoints.managedsat.TimeEndpoint import TimeEndpoint + + return TimeEndpoint(self) + + def _get_url(self) -> str: + """ + Generates and returns the URL for the Huntress Managed SAT API endpoints based on the company url and codebase. + + Returns: + str: API URL. + """ + return f"https://{self.managedsat_url}/{self.codebase.strip('/')}/apis/3.0" + + def _try_get_codebase_from_api(self, managedsat_url: str, company_name: str, headers: dict[str, str]) -> str: + """ + Tries to retrieve the codebase from the API using the provided company url, company name and headers. + + Parameters: + company_url (str): URL of the company. + company_name (str): Name of the company. + headers (dict[str, str]): Headers to be sent in the request. + + Returns: + str: Codebase string or None if an error occurs. + """ + url = f"https://{managedsat_url}/login/companyinfo/{company_name}" + response = self._make_request("GET", url, headers=headers) + return response.json().get("Codebase") + + def _get_auth_string(self) -> str: + """ + Creates and returns the base64 encoded authorization string required for API requests. + + Returns: + str: Base64 encoded authorization string. + """ + return "Basic " + base64.b64encode( + bytes( + f"{self.company_name}+{self.public_key}:{self.private_key}", + encoding="utf8", + ) + ).decode("ascii") + + def _get_headers(self) -> dict[str, str]: + """ + Generates and returns the headers required for making API requests. + + Returns: + dict[str, str]: Dictionary of headers including Content-Type, Client ID, and Authorization. + """ + return { + "Content-Type": "application/json", + "clientId": self.client_id, + "Authorization": self._get_auth_string(), + } diff --git a/src/pyhuntress/clients/siem_client.py b/src/pyhuntress/clients/siem_client.py new file mode 100644 index 0000000..e82232b --- /dev/null +++ b/src/pyhuntress/clients/siem_client.py @@ -0,0 +1,127 @@ +import typing +from datetime import datetime, timezone +import base64 + +from pyhuntress.clients.huntress_client import HuntressClient +from pyhuntress.config import Config + +if typing.TYPE_CHECKING: + from pyhuntress.endpoints.siem.AccountEndpoint import AccountEndpoint + from pyhuntress.endpoints.siem.ActorEndpoint import ActorEndpoint + from pyhuntress.endpoints.siem.AgentsEndpoint import AgentsEndpoint + from pyhuntress.endpoints.siem.BillingreportsEndpoint import BillingreportsEndpoint + from pyhuntress.endpoints.siem.IncidentreportsEndpoint import IncidentreportsEndpoint + from pyhuntress.endpoints.siem.OrganizationsEndpoint import OrganizationsEndpoint + from pyhuntress.endpoints.siem.ReportsEndpoint import ReportsEndpoint + from pyhuntress.endpoints.siem.SignalsEndpoint import SignalsEndpoint + + +class HuntressSIEMAPIClient(HuntressClient): + """ + Huntress SIEM API client. Handles the connection to the Huntress SIEM API + and the configuration of all the available endpoints. + """ + + def __init__( + self, + siem_url: str, + publickey: str, + privatekey: str, + ) -> None: + """ + Initializes the client with the given credentials. + + Parameters: + siem_url (str): URL of your Huntress SIEM instance. + publickey (str): Your Huntress SIEM API public key. + privatekey (str): Your Huntress SIEM API private key. + """ + self.siem_url: str = siem_url + self.publickey: str = publickey + self.privatekey: str = privatekey + self.token_expiry_time: datetime = datetime.now(tz=timezone.utc) + + # Grab first access token + self.base64_auth: str = self._get_auth_key() + + # Initializing endpoints + @property + def account(self) -> "AccountEndpoint": + from pyhuntress.endpoints.siem.AccountEndpoint import AccountEndpoint + + return AccountEndpoint(self) + + @property + def actor(self) -> "ActorEndpoint": + from pyhuntress.endpoints.siem.ActorEndpoint import ActorEndpoint + + return ActorEndpoint(self) + + @property + def agents(self) -> "AgentsEndpoint": + from pyhuntress.endpoints.siem.AgentsEndpoint import AgentsEndpoint + + return AgentsEndpoint(self) + + @property + def billing_reports(self) -> "BillingreportsEndpoint": + from pyhuntress.endpoints.siem.BillingreportsEndpoint import BillingreportsEndpoint + + return BillingreportsEndpoint(self) + + @property + def incident_reports(self) -> "IncidentreportsEndpoint": + from pyhuntress.endpoints.siem.IncidentreportsEndpoint import IncidentreportsEndpoint + + return IncidentreportsEndpoint(self) + + @property + def organizations(self) -> "OrganizationsEndpoint": + from pyhuntress.endpoints.siem.OrganizationsEndpoint import OrganizationsEndpoint + + return OrganizationsEndpoint(self) + + @property + def reports(self) -> "ReportsEndpoint": + from pyhuntress.endpoints.siem.ReportsEndpoint import ReportsEndpoint + + return ReportsEndpoint(self) + + @property + def signals(self) -> "SignalsEndpoint": + from pyhuntress.endpoints.siem.SignalsEndpoint import SignalsEndpoint + + return SignalsEndpoint(self) + + def _get_url(self) -> str: + """ + Generates and returns the URL for the Huntress SIEM API endpoints based on the company url and codebase. + Logs in an obtains an access token. + Returns: + str: API URL. + """ + return f"https://{self.siem_url}/v1" + + def _get_auth_key(self) -> str: + """ + Creates a base64 encoded authentication string to the Huntress SIEM API to obtain an access token. + """ + # Format: base64encode(api_key:api_secret) + + auth_str = f"{self.publickey}:{self.privatekey}" + auth_bytes = auth_str.encode('ascii') + base64_auth = base64.b64encode(auth_bytes).decode('ascii') + + return base64_auth + + def _get_headers(self) -> dict[str, str]: + """ + Generates and returns the headers required for making API requests. The access token is refreshed if necessary before returning. + + Returns: + dict[str, str]: Dictionary of headers including Content-Type, Client ID, and Authorization. + """ + return { + "Content-Type": "application/json", + "Authorization": f"Basic {self.base64_auth}", + } diff --git a/src/pyhuntress/config.py b/src/pyhuntress/config.py new file mode 100644 index 0000000..bd1f4d7 --- /dev/null +++ b/src/pyhuntress/config.py @@ -0,0 +1,9 @@ +class Config: + def __init__(self, max_retries=3) -> None: # noqa: ANN001 + """ + Initializes a new instance of the Config class. + + Args: + max_retries (int): The maximum number of retries for a retryable HTTP operation (500) (default = 3) + """ + self.max_retries = max_retries diff --git a/src/pyhuntress/endpoints/__init__.py b/src/pyhuntress/endpoints/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyhuntress/endpoints/base/__init__.py b/src/pyhuntress/endpoints/base/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyhuntress/endpoints/base/huntress_endpoint.py b/src/pyhuntress/endpoints/base/huntress_endpoint.py new file mode 100644 index 0000000..8e0310e --- /dev/null +++ b/src/pyhuntress/endpoints/base/huntress_endpoint.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, TypeVar + +if TYPE_CHECKING: + from pydantic import BaseModel + from requests import Response + + from pyhuntress.clients.huntress_client import HuntressClient + from pyhuntress.types import ( + RequestData, + RequestMethod, + RequestParams, + ) + +TChildEndpoint = TypeVar("TChildEndpoint", bound="HuntressEndpoint") +TModel = TypeVar("TModel", bound="BaseModel") + + +class HuntressEndpoint: + """ + HuntressEndpoint is a base class for all Huntress API endpoint classes. + It provides a generic implementation for interacting with the Huntress API, + handling requests, parsing responses into model instances, and managing pagination. + + HuntressEndpoint makes use of a generic type variable TModel, which represents + the expected HuntressModel type for the endpoint. This allows for type-safe + handling of model instances throughout the class. + + Each derived class should specify the HuntressModel type it will be working with + when inheriting from HuntressEndpoint. For example: + class CompanyEndpoint(HuntressEndpoint[CompanyModel]). + + HuntressEndpoint provides methods for making API requests and handles pagination + using the PaginatedResponse class. By default, most CRUD methods raise a + NotImplementedError, which should be overridden in derived classes to provide + endpoint-specific implementations. + + HuntressEndpoint also supports handling nested endpoints, which are referred to as + child endpoints. Child endpoints can be registered and accessed through their parent + endpoint, allowing for easy navigation through related resources in the API. + + Args: + client: The HuntressAPIClient instance. + endpoint_url (str): The base URL for the specific endpoint. + parent_endpoint (HuntressEndpoint, optional): The parent endpoint, if applicable. + + Attributes: + client (HuntressAPIClient): The HuntressAPIClient instance. + endpoint_url (str): The base URL for the specific endpoint. + _parent_endpoint (HuntressEndpoint): The parent endpoint, if applicable. + model_parser (ModelParser): An instance of the ModelParser class used for parsing API responses. + _model (Type[TModel]): The model class for the endpoint. + _id (int): The ID of the current resource, if applicable. + _child_endpoints (List[HuntressEndpoint]): A list of registered child endpoints. + + Generic Type: + TModel: The model class for the endpoint. + """ + + def __init__( + self, + client: HuntressClient, + endpoint_url: str, + parent_endpoint: HuntressEndpoint | None = None, + ) -> None: + """ + Initialize a HuntressEndpoint instance with the client and endpoint base. + + Args: + client: The HuntressAPIClient instance. + endpoint_base (str): The base URL for the specific endpoint. + """ + self.client = client + self.endpoint_base = endpoint_url + self._parent_endpoint = parent_endpoint + self._id = None + self._child_endpoints: list[HuntressEndpoint] = [] + + def _register_child_endpoint(self, child_endpoint: TChildEndpoint) -> TChildEndpoint: + """ + Register a child endpoint to the current endpoint. + + Args: + child_endpoint (HuntressEndpoint): The child endpoint instance. + + Returns: + HuntressEndpoint: The registered child endpoint. + """ + self._child_endpoints.append(child_endpoint) + return child_endpoint + + def _url_join(self, *args) -> str: # noqa: ANN002 + """ + Join URL parts into a single URL string. + + Args: + *args: The URL parts to join. + + Returns: + str: The joined URL string. + """ + url_parts = [str(arg).strip("/") for arg in args] + return "/".join(url_parts) + + def _get_replaced_url(self) -> str: + if self._id is None: + return self.endpoint_base + return self.endpoint_base.replace("{id}", str(self._id)) + + def _make_request( + self, + method: RequestMethod, + endpoint: HuntressEndpoint | None = None, + data: RequestData | None = None, + params: RequestParams | None = None, + headers: dict[str, str] | None = None, + stream: bool = False, # noqa: FBT001, FBT002 + ) -> Response: + """ + Make an API request using the specified method, endpoint, data, and parameters. + This function isn't intended for use outside of this class. + Please use the available CRUD methods as intended. + + Args: + method (str): The HTTP method to use for the request (e.g., GET, POST, PUT, etc.). + endpoint (str, optional): The endpoint to make the request to. + data (dict, optional): The request data to send. + params (dict, optional): The query parameters to include in the request. + + Returns: + The Response object (see requests.Response). + + Raises: + Exception: If the request returns a status code >= 400. + """ + url = self._get_endpoint_url() + if endpoint: + url = self._url_join(url, endpoint) + + return self.client._make_request(method, url, data, params, headers, stream) + + def _build_url(self, other_endpoint: HuntressEndpoint) -> str: + if other_endpoint._parent_endpoint is not None: + parent_url = self._build_url(other_endpoint._parent_endpoint) + if other_endpoint._parent_endpoint._id is not None: + return self._url_join( + parent_url, + other_endpoint._get_replaced_url(), + ) + else: # noqa: RET505 + return self._url_join(parent_url, other_endpoint._get_replaced_url()) + else: + return self._url_join(self.client._get_url(), other_endpoint._get_replaced_url()) + + def _get_endpoint_url(self) -> str: + return self._build_url(self) + + def _parse_many(self, model_type: type[TModel], data: list[dict[str, Any]]) -> list[TModel]: + return [model_type.model_validate(d) for d in data] + + def _parse_one(self, model_type: type[TModel], data: dict[str, Any]) -> TModel: + return model_type.model_validate(data) diff --git a/src/pyhuntress/endpoints/managedsat/__init__.py b/src/pyhuntress/endpoints/managedsat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyhuntress/endpoints/siem/AccountEndpoint.py b/src/pyhuntress/endpoints/siem/AccountEndpoint.py new file mode 100644 index 0000000..4719717 --- /dev/null +++ b/src/pyhuntress/endpoints/siem/AccountEndpoint.py @@ -0,0 +1,37 @@ +from pyhuntress.endpoints.base.huntress_endpoint import HuntressEndpoint +from pyhuntress.interfaces import ( + IGettable, +) +from pyhuntress.models.siem import SIEMAccount +from pyhuntress.types import ( + JSON, + HuntressSIEMRequestParams, +) + + +class AccountEndpoint( + HuntressEndpoint, + IGettable[SIEMAccount, HuntressSIEMRequestParams], +): + def __init__(self, client, parent_endpoint=None) -> None: + HuntressEndpoint.__init__(self, client, "account", parent_endpoint=parent_endpoint) + IGettable.__init__(self, SIEMAccount) + + def get( + self, + data: JSON | None = None, + params: HuntressSIEMRequestParams | None = None, + ) -> SIEMAccount: + """ + Performs a GET request against the /account endpoint. + + Parameters: + data (dict[str, Any]): The data to send in the request body. + params (dict[str, int | str]): The parameters to send in the request query string. + Returns: + SIEMAuthInformation: The parsed response data. + """ + return self._parse_one( + SIEMAccount, + super()._make_request("GET", data=data, params=params).json().get('account', {}), + ) diff --git a/src/pyhuntress/endpoints/siem/ActorEndpoint.py b/src/pyhuntress/endpoints/siem/ActorEndpoint.py new file mode 100644 index 0000000..6984533 --- /dev/null +++ b/src/pyhuntress/endpoints/siem/ActorEndpoint.py @@ -0,0 +1,37 @@ +from pyhuntress.endpoints.base.huntress_endpoint import HuntressEndpoint +from pyhuntress.interfaces import ( + IGettable, +) +from pyhuntress.models.siem import SIEMActorResponse +from pyhuntress.types import ( + JSON, + HuntressSIEMRequestParams, +) + + +class ActorEndpoint( + HuntressEndpoint, + IGettable[SIEMActorResponse, HuntressSIEMRequestParams], +): + def __init__(self, client, parent_endpoint=None) -> None: + HuntressEndpoint.__init__(self, client, "actor", parent_endpoint=parent_endpoint) + IGettable.__init__(self, SIEMActorResponse) + + def get( + self, + data: JSON | None = None, + params: HuntressSIEMRequestParams | None = None, + ) -> SIEMActorResponse: + """ + Performs a GET request against the /Actor endpoint. + + Parameters: + data (dict[str, Any]): The data to send in the request body. + params (dict[str, int | str]): The parameters to send in the request query string. + Returns: + SIEMAuthInformation: The parsed response data. + """ + return self._parse_one( + SIEMActorResponse, + super()._make_request("GET", data=data, params=params).json(), + ) diff --git a/src/pyhuntress/endpoints/siem/AgentsEndpoint.py b/src/pyhuntress/endpoints/siem/AgentsEndpoint.py new file mode 100644 index 0000000..da7fcab --- /dev/null +++ b/src/pyhuntress/endpoints/siem/AgentsEndpoint.py @@ -0,0 +1,84 @@ +from pyhuntress.endpoints.base.huntress_endpoint import HuntressEndpoint +from pyhuntress.interfaces import ( + IGettable, + IPaginateable, +) +from pyhuntress.models.siem import SIEMAgentsResponse, SIEMAgents +from pyhuntress.responses.paginated_response import PaginatedResponse +from pyhuntress.types import ( + JSON, + HuntressSIEMRequestParams, +) + + +class AgentsEndpoint( + HuntressEndpoint, + IGettable[SIEMAgents, HuntressSIEMRequestParams], + IPaginateable[SIEMAgents, HuntressSIEMRequestParams], +): + def __init__(self, client, parent_endpoint=None) -> None: + HuntressEndpoint.__init__(self, client, "agents", parent_endpoint=parent_endpoint) + IGettable.__init__(self, SIEMAgentsResponse) + IPaginateable.__init__(self, SIEMAgentsResponse) + + def id(self, id: int) -> HuntressEndpoint: + """ + Sets the ID for this endpoint and returns an initialized HuntressEndpoint object to move down the chain. + + Parameters: + id (int): The ID to set. + Returns: + HuntressEndpoint: The initialized HuntressEndpoint object. + """ + child = HuntressEndpoint(self.client, parent_endpoint=self) + child._id = id + return child + + def paginated( + self, + page: int, + limit: int, + params: HuntressSIEMRequestParams | None = None, + ) -> PaginatedResponse[SIEMAgents]: + """ + Performs a GET request against the /agents endpoint and returns an initialized PaginatedResponse object. + + Parameters: + page (int): The page number to request. + limit (int): The number of results to return per page. + params (dict[str, int | str]): The parameters to send in the request query string. + Returns: + PaginatedResponse[SIEMAgentsResponse]: The initialized PaginatedResponse object. + """ + if params: + params["page"] = page + params["pageSize"] = limit + else: + params = {"page": page, "pageSize": limit} + return PaginatedResponse( + super()._make_request("GET", params=params), + SIEMAgents, + self, + page, + limit, + params, + ) + + def get( + self, + data: JSON | None = None, + params: HuntressSIEMRequestParams | None = None, + ) -> SIEMAgents: + """ + Performs a GET request against the /agents endpoint. + + Parameters: + data (dict[str, Any]): The data to send in the request body. + params (dict[str, int | str]): The parameters to send in the request query string. + Returns: + SIEMAuthInformation: The parsed response data. + """ + return self._parse_many( + SIEMAgents, + super()._make_request("GET", data=data, params=params).json().get('agents', {}), + ) diff --git a/src/pyhuntress/endpoints/siem/BillingreportsEndpoint.py b/src/pyhuntress/endpoints/siem/BillingreportsEndpoint.py new file mode 100644 index 0000000..0003d97 --- /dev/null +++ b/src/pyhuntress/endpoints/siem/BillingreportsEndpoint.py @@ -0,0 +1,37 @@ +from pyhuntress.endpoints.base.huntress_endpoint import HuntressEndpoint +from pyhuntress.interfaces import ( + IGettable, +) +from pyhuntress.models.siem import SIEMBillingReports +from pyhuntress.types import ( + JSON, + HuntressSIEMRequestParams, +) + + +class BillingreportsEndpoint( + HuntressEndpoint, + IGettable[SIEMBillingReports, HuntressSIEMRequestParams], +): + def __init__(self, client, parent_endpoint=None) -> None: + HuntressEndpoint.__init__(self, client, "billing_reports", parent_endpoint=parent_endpoint) + IGettable.__init__(self, SIEMBillingReports) + + def get( + self, + data: JSON | None = None, + params: HuntressSIEMRequestParams | None = None, + ) -> SIEMBillingReports: + """ + Performs a GET request against the /Billing_reports endpoint. + + Parameters: + data (dict[str, Any]): The data to send in the request body. + params (dict[str, int | str]): The parameters to send in the request query string. + Returns: + SIEMAuthInformation: The parsed response data. + """ + return self._parse_many( + SIEMBillingReports, + super()._make_request("GET", data=data, params=params).json().get('billing_reports', {}), + ) diff --git a/src/pyhuntress/endpoints/siem/IncidentreportsEndpoint.py b/src/pyhuntress/endpoints/siem/IncidentreportsEndpoint.py new file mode 100644 index 0000000..13cfb1b --- /dev/null +++ b/src/pyhuntress/endpoints/siem/IncidentreportsEndpoint.py @@ -0,0 +1,37 @@ +from pyhuntress.endpoints.base.huntress_endpoint import HuntressEndpoint +from pyhuntress.interfaces import ( + IGettable, +) +from pyhuntress.models.siem import SIEMIncidentReports +from pyhuntress.types import ( + JSON, + HuntressSIEMRequestParams, +) + + +class IncidentreportsEndpoint( + HuntressEndpoint, + IGettable[SIEMIncidentReports, HuntressSIEMRequestParams], +): + def __init__(self, client, parent_endpoint=None) -> None: + HuntressEndpoint.__init__(self, client, "incident_reports", parent_endpoint=parent_endpoint) + IGettable.__init__(self, SIEMIncidentReports) + + def get( + self, + data: JSON | None = None, + params: HuntressSIEMRequestParams | None = None, + ) -> SIEMIncidentReports: + """ + Performs a GET request against the /Incident_reports endpoint. + + Parameters: + data (dict[str, Any]): The data to send in the request body. + params (dict[str, int | str]): The parameters to send in the request query string. + Returns: + SIEMAuthInformation: The parsed response data. + """ + return self._parse_many( + SIEMIncidentReports, + super()._make_request("GET", data=data, params=params).json().get('incident_reports', {}), + ) diff --git a/src/pyhuntress/endpoints/siem/OrganizationsEndpoint.py b/src/pyhuntress/endpoints/siem/OrganizationsEndpoint.py new file mode 100644 index 0000000..6eb75e6 --- /dev/null +++ b/src/pyhuntress/endpoints/siem/OrganizationsEndpoint.py @@ -0,0 +1,37 @@ +from pyhuntress.endpoints.base.huntress_endpoint import HuntressEndpoint +from pyhuntress.interfaces import ( + IGettable, +) +from pyhuntress.models.siem import SIEMOrganizations +from pyhuntress.types import ( + JSON, + HuntressSIEMRequestParams, +) + + +class OrganizationsEndpoint( + HuntressEndpoint, + IGettable[SIEMOrganizations, HuntressSIEMRequestParams], +): + def __init__(self, client, parent_endpoint=None) -> None: + HuntressEndpoint.__init__(self, client, "organizations", parent_endpoint=parent_endpoint) + IGettable.__init__(self, SIEMOrganizations) + + def get( + self, + data: JSON | None = None, + params: HuntressSIEMRequestParams | None = None, + ) -> SIEMOrganizations: + """ + Performs a GET request against the /Organizations endpoint. + + Parameters: + data (dict[str, Any]): The data to send in the request body. + params (dict[str, int | str]): The parameters to send in the request query string. + Returns: + SIEMAuthInformation: The parsed response data. + """ + return self._parse_many( + SIEMOrganizations, + super()._make_request("GET", data=data, params=params).json().get('organizations', {}), + ) diff --git a/src/pyhuntress/endpoints/siem/ReportsEndpoint.py b/src/pyhuntress/endpoints/siem/ReportsEndpoint.py new file mode 100644 index 0000000..350d2b0 --- /dev/null +++ b/src/pyhuntress/endpoints/siem/ReportsEndpoint.py @@ -0,0 +1,37 @@ +from pyhuntress.endpoints.base.huntress_endpoint import HuntressEndpoint +from pyhuntress.interfaces import ( + IGettable, +) +from pyhuntress.models.siem import SIEMReports +from pyhuntress.types import ( + JSON, + HuntressSIEMRequestParams, +) + + +class ReportsEndpoint( + HuntressEndpoint, + IGettable[SIEMReports, HuntressSIEMRequestParams], +): + def __init__(self, client, parent_endpoint=None) -> None: + HuntressEndpoint.__init__(self, client, "reports", parent_endpoint=parent_endpoint) + IGettable.__init__(self, SIEMReports) + + def get( + self, + data: JSON | None = None, + params: HuntressSIEMRequestParams | None = None, + ) -> SIEMReports: + """ + Performs a GET request against the /Reports endpoint. + + Parameters: + data (dict[str, Any]): The data to send in the request body. + params (dict[str, int | str]): The parameters to send in the request query string. + Returns: + SIEMAuthInformation: The parsed response data. + """ + return self._parse_many( + SIEMReports, + super()._make_request("GET", data=data, params=params).json().get('reports', {}), + ) diff --git a/src/pyhuntress/endpoints/siem/SignalsEndpoint.py b/src/pyhuntress/endpoints/siem/SignalsEndpoint.py new file mode 100644 index 0000000..dfc6b71 --- /dev/null +++ b/src/pyhuntress/endpoints/siem/SignalsEndpoint.py @@ -0,0 +1,37 @@ +from pyhuntress.endpoints.base.huntress_endpoint import HuntressEndpoint +from pyhuntress.interfaces import ( + IGettable, +) +from pyhuntress.models.siem import SIEMSignals +from pyhuntress.types import ( + JSON, + HuntressSIEMRequestParams, +) + + +class SignalsEndpoint( + HuntressEndpoint, + IGettable[SIEMSignals, HuntressSIEMRequestParams], +): + def __init__(self, client, parent_endpoint=None) -> None: + HuntressEndpoint.__init__(self, client, "signals", parent_endpoint=parent_endpoint) + IGettable.__init__(self, SIEMSignals) + + def get( + self, + data: JSON | None = None, + params: HuntressSIEMRequestParams | None = None, + ) -> SIEMSignals: + """ + Performs a GET request against the /Signals endpoint. + + Parameters: + data (dict[str, Any]): The data to send in the request body. + params (dict[str, int | str]): The parameters to send in the request query string. + Returns: + SIEMAuthInformation: The parsed response data. + """ + return self._parse_many( + SIEMSignals, + super()._make_request("GET", data=data, params=params).json().get('signals', {}), + ) diff --git a/src/pyhuntress/endpoints/siem/__init__.py b/src/pyhuntress/endpoints/siem/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyhuntress/exceptions.py b/src/pyhuntress/exceptions.py new file mode 100644 index 0000000..a0d5836 --- /dev/null +++ b/src/pyhuntress/exceptions.py @@ -0,0 +1,85 @@ +import json +from typing import ClassVar +from urllib.parse import urlsplit, urlunsplit + +from requests import JSONDecodeError, Response + + +class HuntressException(Exception): # noqa: N818 + _code_explanation: ClassVar[str] = "" # Ex: for 404 "Not Found" + _error_suggestion: ClassVar[str] = "" # Ex: for 404 "Check the URL you are using is correct" + + def __init__(self, req_response: Response, *, extra_message: str = "") -> None: + self.response = req_response + self.extra_message = extra_message + super().__init__(self.message()) + + def _get_sanitized_url(self) -> str: + """ + Simplify URL down to method, hostname, and path. + """ + url_components = urlsplit(self.response.url) + return urlunsplit( + ( + url_components.scheme, + url_components.hostname, + url_components.path, + "", + "", + ) + ) + + def details(self) -> str: + try: + # If response was json, then format it nicely + return json.dumps(self.response.json(), indent=4) + except JSONDecodeError: + return self.response.text + + def message(self) -> str: + return ( + f"A HTTP {self.response.status_code} ({self._code_explanation}) error has occurred while requesting" + f" {self._get_sanitized_url()}.\n{self.response.reason}\n{self._error_suggestion}\n{self.extra_message}" + ).strip() # Remove extra whitespace (Ex: if extra_message == "") + + +class MalformedRequestException(HuntressException): + _code_explanation = "Bad Request" + _error_suggestion = ( + "The request could not be understood by the server due to malformed syntax. Please check modify your request" + " before retrying." + ) + + +class AuthenticationFailedException(HuntressException): + _code_explanation = "Unauthorized" + _error_suggestion = "Please check your credentials are correct before retrying." + + +class PermissionsFailedException(HuntressException): + _code_explanation = "Forbidden" + _error_suggestion = "You may be attempting to access a resource you do not have the appropriate permissions for." + + +class NotFoundException(HuntressException): + _code_explanation = "Not Found" + _error_suggestion = "You may be attempting to access a resource that has been moved or deleted." + + +class MethodNotAllowedException(HuntressException): + _code_explanation = "Method Not Allowed" + _error_suggestion = "This resource does not support the HTTP method you are trying to use." + + +class ConflictException(HuntressException): + _code_explanation = "Conflict" + _error_suggestion = "This resource is possibly in use or conflicts with another record." + + +class ServerError(HuntressException): + _code_explanation = "Internal Server Error" + + +class ObjectExistsError(HuntressException): + _code_explanation = "Object Exists" + _error_suggestion = "This resource already exists." diff --git a/src/pyhuntress/interfaces.py b/src/pyhuntress/interfaces.py new file mode 100644 index 0000000..9dfbf54 --- /dev/null +++ b/src/pyhuntress/interfaces.py @@ -0,0 +1,103 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Generic, TypeVar + +from pyhuntress.responses.paginated_response import PaginatedResponse +from pyhuntress.types import ( + JSON, + HuntressSIEMRequestParams, + HuntressSATRequestParams, + PatchRequestData, +) + +if TYPE_CHECKING: + from pydantic import BaseModel + +TModel = TypeVar("TModel", bound="BaseModel") +TRequestParams = TypeVar( + "TRequestParams", + bound=HuntressSATRequestParams | HuntressSIEMRequestParams, +) + + +class IMethodBase(ABC, Generic[TModel, TRequestParams]): + def __init__(self, model: TModel) -> None: + self.model = model + + +class IPaginateable(IMethodBase, Generic[TModel, TRequestParams]): + def __init__(self, model: TModel) -> None: + super().__init__(model) + + @abstractmethod + def paginated( + self, + page: int, + page_size: int, + params: TRequestParams | None = None, + ) -> PaginatedResponse[TModel]: + pass + + +class IGettable(IMethodBase, Generic[TModel, TRequestParams]): + def __init__(self, model: TModel) -> None: + super().__init__(model) + + @abstractmethod + def get( + self, + data: JSON | None = None, + params: TRequestParams | None = None, + ) -> TModel: + pass + + +class IPostable(IMethodBase, Generic[TModel, TRequestParams]): + def __init__(self, model: TModel) -> None: + super().__init__(model) + + @abstractmethod + def post( + self, + data: JSON | None = None, + params: TRequestParams | None = None, + ) -> TModel: + pass + + +class IPatchable(IMethodBase, Generic[TModel, TRequestParams]): + def __init__(self, model: TModel) -> None: + super().__init__(model) + + @abstractmethod + def patch( + self, + data: PatchRequestData, + params: TRequestParams | None = None, + ) -> TModel: + pass + + +class IPuttable(IMethodBase, Generic[TModel, TRequestParams]): + def __init__(self, model: TModel) -> None: + super().__init__(model) + + @abstractmethod + def put( + self, + data: JSON | None = None, + params: TRequestParams | None = None, + ) -> TModel: + pass + + +class IDeleteable(IMethodBase, Generic[TRequestParams]): + def __init__(self, model: TModel) -> None: + super().__init__(model) + + @abstractmethod + def delete( + self, + data: JSON | None = None, + params: TRequestParams | None = None, + ) -> None: + pass diff --git a/src/pyhuntress/models/__init__.py b/src/pyhuntress/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyhuntress/models/base/__init__.py b/src/pyhuntress/models/base/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyhuntress/models/base/huntress_model.py b/src/pyhuntress/models/base/huntress_model.py new file mode 100644 index 0000000..4a335ae --- /dev/null +++ b/src/pyhuntress/models/base/huntress_model.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import inspect +from types import UnionType +from typing import Union, get_args, get_origin + +from pydantic import BaseModel, ConfigDict + +from pyhuntress.utils.naming import to_camel_case + + +class HuntressModel(BaseModel): + model_config = ConfigDict( + alias_generator=to_camel_case, + populate_by_name=True, + use_enum_values=True, + protected_namespaces=(), + ) + + @classmethod + def _get_field_names(cls) -> list[str]: + field_names = [] + for v in cls.__fields__.values(): + was_model = False + for arg in get_args(v.annotation): + if inspect.isclass(arg) and issubclass(arg, HuntressModel): + was_model = True + field_names.extend([f"{v.alias}/{sub}" for sub in arg._get_field_names()]) + + if not was_model: + field_names.append(v.alias) + + return field_names + + @classmethod + def _get_field_names_and_types(cls) -> dict[str, str]: # noqa: C901 + field_names_and_types = {} + for v in cls.__fields__.values(): + was_model = False + field_type = "None" + if get_origin(v.annotation) is UnionType or get_origin(v.annotation) is Union: + for arg in get_args(v.annotation): + if inspect.isclass(arg) and issubclass(arg, HuntressModel): + was_model = True + for sk, sv in arg._get_field_names_and_types().items(): + field_names_and_types[f"{v.alias}/{sk}"] = sv + elif arg is not None and arg.__name__ != "NoneType": + field_type = arg.__name__ + else: + if inspect.isclass(v.annotation) and issubclass(v.annotation, HuntressModel): + was_model = True + for sk, sv in v.annotation._get_field_names_and_types().items(): + field_names_and_types[f"{v.alias}/{sk}"] = sv + elif v.annotation is not None and v.annotation.__name__ != "NoneType": + field_type = v.annotation.__name__ + + if not was_model: + field_names_and_types[v.alias] = field_type + + return field_names_and_types diff --git a/src/pyhuntress/models/base/message_model.py b/src/pyhuntress/models/base/message_model.py new file mode 100644 index 0000000..5bd7a02 --- /dev/null +++ b/src/pyhuntress/models/base/message_model.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class GenericMessageModel(BaseModel): + message: str diff --git a/src/pyhuntress/models/managedsat/__init__.py b/src/pyhuntress/models/managedsat/__init__.py new file mode 100644 index 0000000..3f27158 --- /dev/null +++ b/src/pyhuntress/models/managedsat/__init__.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from datetime import date, datetime +from typing import Annotated, Any, Literal +from uuid import UUID +from pydantic import Field + +from pyhuntress.models.base.huntress_model import HuntressModel + + +#class AccountingBatch(HuntressModel): +# info: Annotated[dict[str, str] | None, Field(alias="_info")] = None +# batch_identifier: Annotated[str | None, Field(alias="batchIdentifier")] = None +# closed_flag: Annotated[bool | None, Field(alias="closedFlag")] = None +# export_expenses_flag: Annotated[bool | None, Field(alias="exportExpensesFlag")] = None +# export_invoices_flag: Annotated[bool | None, Field(alias="exportInvoicesFlag")] = None +# export_products_flag: Annotated[bool | None, Field(alias="exportProductsFlag")] = None +# id: int | None = None diff --git a/src/pyhuntress/models/siem/__init__.py b/src/pyhuntress/models/siem/__init__.py new file mode 100644 index 0000000..1e0e775 --- /dev/null +++ b/src/pyhuntress/models/siem/__init__.py @@ -0,0 +1,288 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any, Literal +from uuid import UUID + +from pydantic import Field + +from pyhuntress.models.base.huntress_model import HuntressModel + +class SIEMPagination(HuntressModel): + current_page: int | None = Field(default=None, alias="CurrentPage") + current_page_count: int | None = Field(default=None, alias="CurrentPageCount") + limit: int | None = Field(default=None, alias="Limit") + total_count: int | None = Field(default=None, alias="TotalCount") + next_page: int | None = Field(default=None, alias="NextPage") + next_page_url: str | None = Field(default=None, alias="NextPageURL") + next_page_token: str | None = Field(default=None, alias="NextPageToken") + +class SIEMAgents(HuntressModel): + id: int | None = Field(default=None, alias="Id") + version: str | None = Field(default=None, alias="Version") + arch: str | None = Field(default=None, alias="Arch") + win_build_number: str | None = Field(default=None, alias="WinBuildNumber") + domain_name: str | None = Field(default=None, alias="DomainName") + created_at: datetime | None = Field(default=None, alias="CreateAt") + hostname: str | None = Field(default=None, alias="Hostname") + ipv4_address: str | None = Field(default=None, alias="IPv4Address") + external_ip: str | None = Field(default=None, alias="ExternalIP") + mac_addresses: list | None = Field(default=None, alias="MacAddresses") + updated_at: datetime | None = Field(default=None, alias="IPv4Address") + last_survey_at: datetime | None = Field(default=None, alias="LastSurveyAt") + last_callback_at: datetime | None = Field(default=None, alias="LastCallbackAt") + account_id: int | None = Field(default=None, alias="AccountID") + organization_id: int | None = Field(default=None, alias="OrganizationID") + platform: Literal[ + "windows", + "darwin", + "linux", + ] | None = Field(default=None, alias="Platform") + os: str | None = Field(default=None, alias="OS") + service_pack_major: int | None = Field(default=None, alias="ServicePackMajor") + service_pack_minor: int | None = Field(default=None, alias="ServicePackMinor") + tags: list | None = Field(default=None, alias="Tags") + os_major: int | None = Field(default=None, alias="OSMajor") + os_minor: int | None = Field(default=None, alias="OSMinor") + os_patch: int | None = Field(default=None, alias="OSPatch") + version_number: int | None = Field(default=None, alias="VersionNumber") + edr_version: str | None = Field(default=None, alias="EDRVersion") + os_build_version: str | None = Field(default=None, alias="OSBuildVersion") + serial_number: str | None = Field(default=None, alias="SerialNumber") + defender_status: str | None = Field(default=None, alias="DefenderStatus") + defender_substatus: str | None = Field(default=None, alias="DefenderSubstatus") + defender_policy_status: str | None = Field(default=None, alias="DefenderPolicyStatus") + firewall_status: str | None = Field(default=None, alias="FirewallStatus") + +class SIEMAgentsResponse(HuntressModel): + agents: dict[str, Any] | None = Field(default=None, alias="Agents") + pagination: dict[str, Any] | None = Field(default=None, alias="Pagination") + +class SIEMAccount(HuntressModel): + id: int | None = Field(default=None, alias="Id") + name: str | None = Field(default=None, alias="Name") + subdomain: str | None = Field(default=None, alias="Subdomain") + status: str | None = Field(default=None, alias="Status") + +class SIEMActorResponse(HuntressModel): + account: dict[str, Any] | None = Field(default=None, alias="Account") + user: str | None = Field(default=None, alias="User") + +class SIEMBillingReports(HuntressModel): + id: int | None = Field(default=None, alias="Id") + plan: str | None = Field(default=None, alias="Plan") + quantity: int | None = Field(default=None, alias="Quantity") + amount: int | None = Field(default=None, alias="Amount") + currency_type: str | None = Field(default=None, alias="CurrencyType") + receipt: str | None = Field(default=None, alias="Receipt") + status: Literal[ + "open", + "paid", + "failed", + "partial_refund", + "full_refund", + "draft", + "voided", + ] | None = Field(default=None, alias="Status") + created_at: datetime | None = Field(default=None, alias="CreatedAt") + updated_at: datetime | None = Field(default=None, alias="UpdatedAt") + +class SIEMBillingReportsResponse(HuntressModel): + billing_reports: dict[str, Any] | None = Field(default=None, alias="BillingReports") + +class SIEMIncidentReportsResponse(HuntressModel): + incident_reports: dict[str, Any] | None = Field(default=None, alias="IncidentReports") + pagination: dict[str, Any] | None = Field(default=None, alias="Pagination") + +class SIEMIncidentReports(HuntressModel): + id: int | None = Field(default=None, alias="Id") + status: Literal[ + "sent", + "closed", + "dismissed", + "auto_remediating", + "deleting", + ] | None = Field(default=None, alias="Status") + summary: str | None = Field(default=None, alias="Summary") + body: str | None = Field(default=None, alias="Body") + updated_at: datetime | None = Field(default=None, alias="UpdatedAt") + agent_id: int | None = Field(default=None, alias="AgentId") + platform: Literal[ + "windows", + "darwin", + "microsoft_365", + "google", + "linux", + "other", + ] | None = Field(default=None, alias="Platform") + status_updated_at: datetime | None = Field(default=None, alias="StatusUpdatedAt") + organization_id: int | None = Field(default=None, alias="OrganizationId") + sent_at: datetime | None = Field(default=None, alias="SentAt") + account_id: int | None = Field(default=None, alias="AccountId") + subject: str | None = Field(default=None, alias="Subject") + remediations: list[dict[str, Any]] | None = Field(default=None, alias="Remediations") + severity: Literal[ + "low", + "high", + "critical", + ] | None = Field(default=None, alias="Severity") + closed_at: datetime | None = Field(default=None, alias="ClosedAt") + indicator_types: list | None = Field(default=None, alias="IndicatorTypes") + indicator_counts: dict[str, Any] | None = Field(default=None, alias="IndicatorCounts") + + +class SIEMRemediations(HuntressModel): + id: int | None = Field(default=None, alias="Id") + type: str | None = Field(default=None, alias="Type") + status: str | None = Field(default=None, alias="Status") + details: dict[str, Any] | None = Field(default=None, alias="Details") + completable_by_task_response: bool | None = Field(default=None, alias="CompletedByTaskResponse") + completable_manually: bool | None = Field(default=None, alias="CompletedManually") + display_action: str | None = Field(default=None, alias="DisplayAction") + approved_at: datetime | None = Field(default=None, alias="ApprovedAt") + approved_by: dict[str, Any] | None = Field(default=None, alias="ApprovedBy") + completed_at: datetime | None = Field(default=None, alias="CompletedAt") + +class SIEMRemediationsDetails(HuntressModel): + rule_id: int | None = Field(default=None, alias="RuleId") + rule_name: str | None = Field(default=None, alias="RuleName") + completed_at: datetime | None = Field(default=None, alias="CompletedAt") + forward_from: str | None = Field(default=None, alias="ForwardFrom") + remediation: str | None = Field(default=None, alias="remediation") + +class SIEMRemediationsApprovedBy(HuntressModel): + id: int | None = Field(default=None, alias="Id") + email: str | None = Field(default=None, alias="Email") + first_name: str | None = Field(default=None, alias="FirstName") + last_name: str | None = Field(default=None, alias="LastName") + +class SIEMIndicatorCounts(HuntressModel): + footholds: int | None = Field(default=None, alias="Footholds") + mde_detections: int | None = Field(default=None, alias="MDEDetections") + monitored_files: int | None = Field(default=None, alias="MonitoredFiles") + siem_detections: int | None = Field(default=None, alias="SIEMDetections") + managed_identity: int | None = Field(default=None, alias="ManagedIdentity") + process_detections: int | None = Field(default=None, alias="ProcessDetections") + ransomware_canaries: int | None = Field(default=None, alias="RansomwareCanaries") + antivirus_detections: int | None = Field(default=None, alias="AntivirusDetections") + +class SIEMOrganizationsResponse(HuntressModel): + organizations: dict[str, Any] | None = Field(default=None, alias="Organizations") + pagination: dict[str, Any] | None = Field(default=None, alias="Pagination") + +class SIEMOrganizations(HuntressModel): + id: int | None = Field(default=None, alias="Id") + name: str | None = Field(default=None, alias="Name") + created_at: datetime | None = Field(default=None, alias="CreatedAt") + updated_at: datetime | None = Field(default=None, alias="UpdatedAt") + account_id: int | None = Field(default=None, alias="AccountId") + key: str | None = Field(default=None, alias="Key") + notify_emails: list | None = Field(default=None, alias="NotifyEmails") + microsoft_365_tenant_id: str | None = Field(default=None, alias="Microsoft365TenantId") + incident_reports_count: int | None = Field(default=None, alias="IncidentsReportsCount") + agents_count: int | None = Field(default=None, alias="AgentsCount") + microsoft_365_users_count: int | None = Field(default=None, alias="Microsoft365UsersCount") + sat_learner_count: int | None = Field(default=None, alias="SATLearnerCount") + logs_sources_count: int | None = Field(default=None, alias="LogsSourcesCount") + +class SIEMReportsResponse(HuntressModel): + reports: dict[str, Any] | None = Field(default=None, alias="Organizations") + pagination: dict[str, Any] | None = Field(default=None, alias="Pagination") + +class SIEMReports(HuntressModel): + id: int | None = Field(default=None, alias="Id") + type: Literal[ + "monthly_summary", + "quarterly_summary", + "yearly_summary", + ] | None = Field(default=None, alias="Type") + period: str | None = Field(default=None, alias="Period") + organization_id: int | None = Field(default=None, alias="OrganizationId") + created_at: datetime | None = Field(default=None, alias="CreatedAt") + updated_at: datetime | None = Field(default=None, alias="UpdatedAt") + url: str | None = Field(default=None, alias="Type") + events_analyzed: int | None = Field(default=None, alias="EventsAnalyzed") + total_entities: int | None = Field(default=None, alias="TotalEntities") + signals_detected: int | None = Field(default=None, alias="SignalsDetected") + signals_investigated: int | None = Field(default=None, alias="SignalsInvestigated") + itdr_entities: int | None = Field(default=None, alias="ITDREntities") + itdr_events: int | None = Field(default=None, alias="ITDREvents") + siem_total_logs: int | None = Field(default=None, alias="SIEMTotalLogs") + siem_ingested_logs: int | None = Field(default=None, alias="SIEMIngestedLogs") + autorun_events: int | None = Field(default=None, alias="AutorunEvents") + autorun_signals_detected: int | None = Field(default=None, alias="AutorunSignalsDetected") + investigations_completed: int | None = Field(default=None, alias="InvestigationsCompleted") + autorun_signals_reviewed: int | None = Field(default=None, alias="AutorunSignalsReviewed") + incidents_reported: int | None = Field(default=None, alias="IncidentsReported") + itdr_incidents_reported: int | None = Field(default=None, alias="ITDRIncidentsReported") + siem_incidents_reported: int | None = Field(default=None, alias="SIEMIncidentsReported") + incidents_resolved: int | None = Field(default=None, alias="IncidentsResolved") + incident_severity_counts: int | None = Field(default=None, alias="IncidentSeverityCounts") + incident_product_counts: int | None = Field(default=None, alias="IncidentProductCounts") + incident_indicator_counts: int | None = Field(default=None, alias="IncidentIndicatorCounts") + top_incident_av_threats: list | None = Field(default=None, alias="TopIncidentAVThreats") + top_incident_hosts: list | None = Field(default=None, alias="TopIncidentHosts") + potential_threat_indicators: list | None = Field(default=None, alias="PotentialThreatIndicators") + agents_count: int | None = Field(default=None, alias="AgentsCount") + deployed_canaries_count: int | None = Field(default=None, alias="DeployedCanariesCount") + protected_profiles_count: int | None = Field(default=None, alias="ProtectedProfilesCount") + windows_agent_count: int | None = Field(default=None, alias="WindowsAgentCount") + macos_agent_count: int | None = Field(default=None, alias="MacOSAgentCount") + servers_agent_count: int | None = Field(default=None, alias="ServersAgentCount") + analyst_note: str | None = Field(default=None, alias="AnalystNote") + global_threats_note: str | None = Field(default=None, alias="GlobalThreatsNote") + ransomware_note: str | None = Field(default=None, alias="RansomwareNote") + incident_log: str | None = Field(default=None, alias="IncidentLog") + total_mav_detection_count: int | None = Field(default=None, alias="TotalMAVDetectionCount") + blocked_malware_count: int | None = Field(default=None, alias="BlockedMalwareCount") + investigated_mav_detection_count: int | None = Field(default=None, alias="InvestigatedMAVDetectionCount") + mav_incident_report_count: int | None = Field(default=None, alias="MAVIncidentReportCount") + autoruns_reviewed: int | None = Field(default=None, alias="AutorunsReviewed") + host_processes_analyzed: int | None = Field(default=None, alias="HostProcessesAnalyzed") + process_detections: int | None = Field(default=None, alias="ProcessDetections") + process_detections_reviewed: int | None = Field(default=None, alias="ProcessDetectionsReviewed") + process_detections_reported: int | None = Field(default=None, alias="ProcessDetectionsReported") + itdr_signals: int | None = Field(default=None, alias="ITDRSignals") + siem_signals: int | None = Field(default=None, alias="SIEMSignals") + itdr_investigations_completed: int | None = Field(default=None, alias="ITDRInvestigationsCompleted") + macos_agents: str | None = Field(default=None, alias="MacOSAgents") + windows_agents: str | None = Field(default=None, alias="WindowsAgents") + only_macos_agents: str | None = Field(default=None, alias="OnlyMacOSAgents") + antivirus_exclusions_count: int | None = Field(default=None, alias="AntivirusExclusionsCount") + new_exclusions_count: int | None = Field(default=None, alias="NewExclusionsCount") + allowed_exclusions_count: int | None = Field(default=None, alias="AllowedExclusionsCount") + risky_exclusions_removed_count: int | None = Field(default=None, alias="RiskyExclusionsRemovedCount") + +class SIEMSignalsResponse(HuntressModel): + signals: dict[str, Any] | None = Field(default=None, alias="Organizations") + pagination: dict[str, Any] | None = Field(default=None, alias="Pagination") + +class SIEMSignals(HuntressModel): + created_at: datetime | None = Field(default=None, alias="CreatedAt") + id: int | None = Field(default=None, alias="Id") + status: str | None = Field(default=None, alias="Status") + updated_at: datetime | None = Field(default=None, alias="UpdatedAt") + details: dict[str, Any] | None = Field(default=None, alias="Details") + entity: dict[str, Any] | None = Field(default=None, alias="Entity") + investigated_at: datetime | None = Field(default=None, alias="InvestigatedAt") + investigation_context: str | None = Field(default=None, alias="InvestigationContext") + name: str | None = Field(default=None, alias="Name") + organization: dict[str, Any] | None = Field(default=None, alias="Organization") + type: str | None = Field(default=None, alias="Type") + +class SIEMSignalsDetails(HuntressModel): + identity: str | None = Field(default=None, alias="Identity") + application: str | None = Field(default=None, alias="Application") + detected_at: datetime | None = Field(default=None, alias="DetectedAt") + +class SIEMSignalsEntity(HuntressModel): + id: int | None = Field(default=None, alias="Id") + name: str | None = Field(default=None, alias="Name") + type: Literal[ + "user_entity", + "source", + "mailbox", + "service_principal", + "agent", + "identity", + ] | None = Field(default=None, alias="Type") diff --git a/src/pyhuntress/py.typed b/src/pyhuntress/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/pyhuntress/responses/__init__.py b/src/pyhuntress/responses/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyhuntress/responses/paginated_response.py b/src/pyhuntress/responses/paginated_response.py new file mode 100644 index 0000000..d6a3e01 --- /dev/null +++ b/src/pyhuntress/responses/paginated_response.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + +from pyhuntress.utils.helpers import parse_link_headers + +if TYPE_CHECKING: + from collections.abc import Iterable + + from pydantic import BaseModel + from requests import Response + + from pyhuntress.types import RequestParams + + +TModel = TypeVar("TModel", bound="BaseModel") + +if TYPE_CHECKING: + from pyhuntress.interfaces import IPaginateable + + +class PaginatedResponse(Generic[TModel]): + """ + PaginatedResponse is a wrapper class for handling paginated responses from the + Huntress API. It provides methods for navigating through the pages of the response + and accessing the data contained within each page. + + The class is designed to work with HuntressEndpoint and its derived classes to + parse the API response into model instances. It also supports iteration, allowing + the user to loop through the items within the paginated response. + + PaginatedResponse uses a generic type variable TModel, which represents the + expected model type for the response data. This allows for type-safe handling + of model instances throughout the class. + """ + + def __init__( + self, + response: Response, + response_model: type[TModel], + endpoint: IPaginateable, + page: int, + page_size: int, + params: RequestParams | None = None, + ) -> None: + """ + PaginatedResponse is a wrapper class for handling paginated responses from the + Huntress API. It provides methods for navigating through the pages of the response + and accessing the data contained within each page. + + The class is designed to work with HuntressEndpoint and its derived classes to + parse the API response into model instances. It also supports iteration, allowing + the user to loop through the items within the paginated response. + + PaginatedResponse uses a generic type variable TModel, which represents the + expected model type for the response data. This allows for type-safe handling + of model instances throughout the class. + """ + self._initialize(response, response_model, endpoint, page, page_size, params) + + def _initialize( # noqa: ANN202 + self, + response: Response, + response_model: type[TModel], + endpoint: IPaginateable, + page: int, + page_size: int, + params: RequestParams | None = None, + ): + """ + Initialize the instance variables using the provided response, endpoint, and page size. + + Args: + response: The raw response object from the API. + endpoint (HuntressEndpoint[TModel]): The endpoint associated with the response. + page_size (int): The number of items per page. + """ + self.response = response + self.response_model = response_model + self.endpoint = endpoint + self.page_size = page_size + # The following for SIEM is in the response body, not the headers + self.parsed_pagination_response = None #parse_link_headers(response.headers) + self.params = params + if self.parsed_pagination_response is not None: + # Huntress SIEM API gives us a handy response to parse for Pagination + self.has_next_page: bool = self.parsed_link_headers.get("has_next_page", False) + self.has_prev_page: bool = self.parsed_link_headers.get("has_prev_page", False) + self.first_page: int = self.parsed_link_headers.get("first_page", None) + self.prev_page: int = self.parsed_link_headers.get("prev_page", None) + self.next_page: int = self.parsed_link_headers.get("next_page", None) + self.last_page: int = self.parsed_link_headers.get("last_page", None) + else: + # Huntress Managed SAT might, haven't worked on this yet + self.has_next_page: bool = True + self.has_prev_page: bool = page > 1 + self.first_page: int = 1 + self.prev_page = page - 1 if page > 1 else 1 + self.next_page = page + 1 + self.last_page = 999999 + self.data: list[TModel] = [response_model.model_validate(d) for d in response.json()] + self.has_data = self.data and len(self.data) > 0 + self.index = 0 + + def get_next_page(self) -> PaginatedResponse[TModel]: + """ + Fetch the next page of the paginated response. + + Returns: + PaginatedResponse[TModel]: The updated PaginatedResponse instance + with the data from the next page or None if there is no next page. + """ + if not self.has_next_page or not self.next_page: + self.has_data = False + return self + + next_response = self.endpoint.paginated(self.next_page, self.page_size, self.params) + self._initialize( + next_response.response, + next_response.response_model, + next_response.endpoint, + self.next_page, + next_response.page_size, + self.params, + ) + return self + + def get_previous_page(self) -> PaginatedResponse[TModel]: + """ + Fetch the next page of the paginated response. + + Returns: + PaginatedResponse[TModel]: The updated PaginatedResponse instance + with the data from the next page or None if there is no next page. + """ + if not self.has_prev_page or not self.prev_page: + self.has_data = False + return self + + prev_response = self.endpoint.paginated(self.prev_page, self.page_size, self.params) + self._initialize( + prev_response.response, + prev_response.response_model, + prev_response.endpoint, + self.prev_page, + prev_response.page_size, + self.params, + ) + return self + + def all(self) -> Iterable[TModel]: # noqa: A003 + """ + Iterate through all items in the paginated response, across all pages. + + Yields: + TModel: An instance of the model class for each item in the paginated response. + """ + while self.has_data: + yield from self.data + self.get_next_page() + + def __iter__(self): # noqa: ANN204 + """ + Implement the iterator protocol for the PaginatedResponse class. + + Returns: + PaginatedResponse[TModel]: The current instance of the PaginatedResponse. + """ + return self + + def __dict__(self): # noqa: ANN204 + """ + Implement the iterator protocol for the PaginatedResponse class. + + Returns: + PaginatedResponse[TModel]: The current instance of the PaginatedResponse. + """ + return self.data + + def __next__(self): # noqa: ANN204 + """ + Implement the iterator protocol by getting the next item in the data. + + Returns: + TModel: The next item in the data. + + Raises: + StopIteration: If there are no more items in the data. + """ + if self.index < len(self.data): + result = self.data[self.index] + self.index += 1 + return result + else: # noqa: RET505 + raise StopIteration diff --git a/src/pyhuntress/types.py b/src/pyhuntress/types.py new file mode 100644 index 0000000..70773a8 --- /dev/null +++ b/src/pyhuntress/types.py @@ -0,0 +1,53 @@ +from typing import Literal, TypeAlias + +from typing_extensions import NotRequired, TypedDict +from datetime import datetime + +Literals: TypeAlias = str | int | float | bool +JSON: TypeAlias = dict[str, "JSON"] | list["JSON"] | Literals | None + + +class Patch(TypedDict): + op: Literal["add"] | Literal["replace"] | Literal["remove"] + path: str + value: JSON + + +class HuntressSATRequestParams(TypedDict): + conditions: NotRequired[str] + childConditions: NotRequired[str] + customFieldConditions: NotRequired[str] + orderBy: NotRequired[str] + page: NotRequired[int] + pageSize: NotRequired[int] + fields: NotRequired[str] + columns: NotRequired[str] + + +class HuntressSIEMRequestParams(TypedDict): + created_at_min: NotRequired[datetime] + created_at_max: NotRequired[datetime] + updated_at_min: NotRequired[datetime] + updated_at_min: NotRequired[datetime] + customFieldConditions: NotRequired[str] + page_token: NotRequired[str] + page: NotRequired[int] + limit: NotRequired[int] + organization_id: NotRequired[int] + platform: NotRequired[str] + status: NotRequired[str] + indicator_type: NotRequired[str] + severity: NotRequired[str] + platform: NotRequired[str] + agent_id: NotRequired[str] + type: NotRequired[str] + entity_id: NotRequired[int] + types: NotRequired[str] + statuses: NotRequired[str] + + +GenericRequestParams: TypeAlias = dict[str, Literals] +RequestParams: TypeAlias = HuntressSATRequestParams | HuntressSIEMRequestParams | GenericRequestParams +PatchRequestData: TypeAlias = list[Patch] +RequestData: TypeAlias = JSON | PatchRequestData +RequestMethod: TypeAlias = Literal["GET", "POST", "PUT", "PATCH", "DELETE"] diff --git a/src/pyhuntress/utils/__init__.py b/src/pyhuntress/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyhuntress/utils/experimental/__init__.py b/src/pyhuntress/utils/experimental/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pyhuntress/utils/experimental/condition.py b/src/pyhuntress/utils/experimental/condition.py new file mode 100644 index 0000000..87b4f76 --- /dev/null +++ b/src/pyhuntress/utils/experimental/condition.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +import inspect +import re +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from pyhuntress.utils.naming import to_camel_case + +if TYPE_CHECKING: + from collections.abc import Callable + +T = TypeVar("T") + + +class ValueType(Enum): + STR = 1 + INT = 2 + DATETIME = 3 + + +class Condition(Generic[T]): + def __init__(self: Condition[T]) -> None: + self._condition_string: str = "" + self._field = "" + + def field(self: Condition[T], selector: Callable[[type[T]], Any]) -> Condition[T]: + field = "" + + frame = inspect.currentframe() + try: + context = inspect.getframeinfo(frame.f_back).code_context + caller_lines = "".join([line.strip() for line in context]) + m = re.search(r"field\s*\(([^)]+)\)", caller_lines) + if m: + caller_lines = m.group(1) + + field = to_camel_case("/".join(caller_lines.replace("(", "").replace(")", "").split(".")[1:])) + + finally: + del frame + + self._condition_string += field + return self + + def equals(self: Condition[T], value: Any) -> Condition[T]: # noqa: ANN401 + self._condition_string += " = " + self.__add_typed_value_to_string(value, type(value)) + return self + + def not_equals(self: Condition[T], value: Any) -> Condition[T]: # noqa: ANN401 + self._condition_string += " = " + self.__add_typed_value_to_string(value, type(value)) + return self + + def less_than(self: Condition[T], value: Any) -> Condition[T]: # noqa: ANN401 + self._condition_string += " < " + self.__add_typed_value_to_string(value, type(value)) + return self + + def less_than_or_equals( + self: Condition[T], + value: Any, # noqa: ANN401 + ) -> Condition[T]: + self._condition_string += " <= " + self.__add_typed_value_to_string(value, type(value)) + return self + + def greater_than(self: Condition[T], value: Any) -> Condition[T]: # noqa: ANN401 + self._condition_string += " > " + self.__add_typed_value_to_string(value, type(value)) + return self + + def greater_than_or_equals( + self: Condition[T], + value: Any, # noqa: ANN401 + ) -> Condition[T]: + self._condition_string += " >= " + self.__add_typed_value_to_string(value, type(value)) + return self + + def contains(self: Condition[T], value: Any) -> Condition[T]: # noqa: ANN401 + self._condition_string += " CONTAINS " + self.__add_typed_value_to_string(value, type(value)) + return self + + def like(self: Condition[T], value: Any) -> Condition[T]: # noqa: ANN401 + self._condition_string += " LIKE " + self.__add_typed_value_to_string(value, type(value)) + return self + + def in_(self: Condition[T], value: Any) -> Condition[T]: # noqa: ANN401 + self._condition_string += " IN " + self.__add_typed_value_to_string(value, type(value)) + return self + + def not_(self: Condition[T], value: Any) -> Condition[T]: # noqa: ANN401 + self._condition_string += " NOT " + self.__add_typed_value_to_string(value, type(value)) + return self + + def __add_typed_value_to_string( # noqa: ANN202 + self: Condition[T], + value: Any, # noqa: ANN401 + type: type, # noqa: A002 + ): + if type is str: + self._condition_string += f'"{value}"' + elif type is int: # noqa: SIM114 + self._condition_string += str(value) + elif type is bool: + self._condition_string += str(value) + elif type is datetime: + self._condition_string += f"[{value}]" + else: + self._condition_string += f'"{value}"' + + def and_(self: Condition[T], selector: Callable[[type[T]], Any] | None = None) -> Condition[T]: + self._condition_string += " AND " + + if selector is not None: + field = "" + frame = inspect.currentframe() + try: + context = inspect.getframeinfo(frame.f_back).code_context + caller_lines = "".join([line.strip() for line in context]) + m = re.search(r"and_\s*\(([^)]+)\)", caller_lines) + if m: + caller_lines = m.group(1) + + field = "/".join(caller_lines.replace("(", "").replace(")", "").split(".")[1:]) + + finally: + del frame + + self._condition_string += field + return self + + def or_(self: Condition[T], selector: Callable[[type[T]], Any] | None = None) -> Condition[T]: + self._condition_string += " OR " + + if selector is not None: + field = "" + frame = inspect.currentframe() + try: + context = inspect.getframeinfo(frame.f_back).code_context + caller_lines = "".join([line.strip() for line in context]) + m = re.search(r"or_\s*\(([^)]+)\)", caller_lines) + if m: + caller_lines = m.group(1) + + field = "/".join(caller_lines.replace("(", "").replace(")", "").split(".")[1:]) + + finally: + del frame + + self._condition_string += field + return self + + def wrap(self: Condition[T], condition: Callable[[Condition[T]], Condition[T]]) -> Condition[T]: + self._condition_string += f"({condition(Condition[T]())})" + return self + + def __str__(self: Condition[T]) -> str: + return self._condition_string.strip() diff --git a/src/pyhuntress/utils/experimental/patch_maker.py b/src/pyhuntress/utils/experimental/patch_maker.py new file mode 100644 index 0000000..4ec96fd --- /dev/null +++ b/src/pyhuntress/utils/experimental/patch_maker.py @@ -0,0 +1,37 @@ +import json +from enum import Enum +from typing import Any + + +class Patch: + class PatchOp(Enum): + """ + PatchOperation is an enumeration of the different patch operations supported + by the Huntress API. These operations are ADD, REPLACE, and REMOVE. + """ + + ADD = 1 + REPLACE = 2 + REMOVE = 3 + + def __init__(self, op: PatchOp, path: str, value: Any) -> None: # noqa: ANN401 + self.op = op.name.lower() + self.path = path + self.value = value + + def __repr__(self) -> str: + """ + Return a string representation of the model as a formatted JSON string. + + Returns: + str: A formatted JSON string representation of the model. + """ + return json.dumps(self.__dict__, default=str, indent=2) + + +class PatchGroup: + def __init__(self, *patches: Patch) -> None: + self.patches = list(patches) + + def __repr__(self) -> str: + return str(self.patches) diff --git a/src/pyhuntress/utils/helpers.py b/src/pyhuntress/utils/helpers.py new file mode 100644 index 0000000..2c49a49 --- /dev/null +++ b/src/pyhuntress/utils/helpers.py @@ -0,0 +1,101 @@ +import re +from datetime import datetime +from typing import Any + +from requests.structures import CaseInsensitiveDict + + +def cw_format_datetime(dt: datetime) -> str: + """Format a datetime object as a string in ISO 8601 format. This is the format that Huntress uses. + + Args: + dt (datetime): The datetime object to be formatted. + + Returns: + str: The formatted datetime string in the format "YYYY-MM-DDTHH:MM:SSZ". + + Example: + from datetime import datetime + + dt = datetime(2022, 1, 1, 12, 0, 0) + formatted_dt = cw_format_datetime(dt) + print(formatted_dt) # Output: "2022-01-01T12:00:00Z" + """ + return dt.strftime("%Y-%m-%dT%H:%M:%SZ") + + +def parse_link_headers( # noqa: C901 + headers: CaseInsensitiveDict, +) -> dict[str, Any] | None: + """ + Parses link headers to extract pagination information. + + Arguments: + - headers: A dictionary containing the headers of an HTTP response. The value associated with the "Link" key should be a string representing the link headers. + + Returns: + - A dictionary containing the extracted pagination information. The keys in the dictionary include: + - "first_page": An optional integer representing the number of the first page. + - "prev_page": An optional integer representing the number of the previous page. + - "next_page": An optional integer representing the number of the next page. + - "last_page": An optional integer representing the number of the last page. + - "has_next_page": A boolean indicating whether there is a next page. + - "has_prev_page": A boolean indicating whether there is a previous page. + + If the "Link" header is not present in the headers dictionary, None is returned. + + Example Usage: + headers = { + "Link": '; rel="first", ; rel="next"' + } + pagination_info = parse_link_headers(headers) + print(pagination_info) + # Output: {'first_page': 1, 'next_page': 2, 'has_next_page': True} + """ + if headers.get("Link") is None: + return None + links = headers["Link"].split(",") + has_next_page: bool = False + has_prev_page: bool = False + first_page: int | None = None + prev_page: int | None = None + next_page: int | None = None + last_page: int | None = None + + for link in links: + match = re.search(r'page=(\d+)>; rel="(.*?)"', link) + if match: + page_number = int(match.group(1)) + rel_value = match.group(2) + if rel_value == "first": + first_page = page_number + elif rel_value == "prev": + prev_page = page_number + has_prev_page = True + elif rel_value == "next": + next_page = page_number + has_next_page = True + elif rel_value == "last": + last_page = page_number + + result = {} + + if first_page is not None: + result["first_page"] = first_page + + if prev_page is not None: + result["prev_page"] = prev_page + + if next_page is not None: + result["next_page"] = next_page + + if last_page is not None: + result["last_page"] = last_page + + if has_next_page: + result["has_next_page"] = has_next_page + + if has_prev_page: + result["has_prev_page"] = has_prev_page + + return result diff --git a/src/pyhuntress/utils/naming.py b/src/pyhuntress/utils/naming.py new file mode 100644 index 0000000..e0d1758 --- /dev/null +++ b/src/pyhuntress/utils/naming.py @@ -0,0 +1,23 @@ +from keyword import iskeyword + + +def to_snake_case(string: str) -> str: + return ("_" if string.startswith("_") else "") + "".join( + ["_" + i.lower() if i.isupper() else i for i in string.lstrip("_")] + ).lstrip("_") + + +def to_camel_case(string: str) -> str: + string_split = string.split("_") + return string_split[0] + "".join(word.capitalize() for word in string_split[1:]) + + +def to_title_case_preserve_case(string: str) -> str: + return string[:1].upper() + string[1:] + + +def ensure_not_reserved(string: str) -> str: + if iskeyword(string): + return string + "_" + else: # noqa: RET505 + return string diff --git a/src/scratchpad.py b/src/scratchpad.py new file mode 100644 index 0000000..2ce35ed --- /dev/null +++ b/src/scratchpad.py @@ -0,0 +1,43 @@ +import os +from pyhuntress import HuntressSIEMAPIClient +from dotenv import load_dotenv + +load_dotenv() + +siem_url = os.getenv('siem_url') +publickey = os.getenv('publickey') +privatekey = os.getenv('privatekey') + +# init client +siem_api_client = HuntressSIEMAPIClient( + siem_url, + publickey, + privatekey, +) + +#account = siem_api_client.account.get() +#print(account) + +#actor = siem_api_client.actor.get() +#print(actor) + +#agents = siem_api_client.agents.get() +#print(agents) + +billingreports = siem_api_client.billing_reports.get() +print(billingreports) + +incidentreports = siem_api_client.incident_reports.get() +print(incidentreports) + +organizations = siem_api_client.organizations.get() +print(organizations) + +reports = siem_api_client.reports.get() +print(reports) + +signals = siem_api_client.signals.get() +print(signals) + +#paginated_agents = siem_api_client.agents.paginated(1, 10) +#print(paginated_agents) \ No newline at end of file