Add new entrypoint option to --load-extensions. (#1789)

Thanks, @asg017
pull/1793/head
Alex Garcia 2022-08-23 11:34:30 -07:00 zatwierdzone przez GitHub
rodzic 663ac431fe
commit 1d64c9a8da
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 4AEE18F83AFDEB23
6 zmienionych plików z 140 dodań i 2 usunięć

6
.gitignore vendored
Wyświetl plik

@ -118,3 +118,9 @@ ENV/
.DS_Store
node_modules
.*.swp
# In case someone compiled tests/ext.c for test_load_extensions, don't
# include it in source control.
tests/*.dylib
tests/*.so
tests/*.dll

Wyświetl plik

@ -559,7 +559,13 @@ class Datasette:
if self.sqlite_extensions:
conn.enable_load_extension(True)
for extension in self.sqlite_extensions:
conn.execute("SELECT load_extension(?)", [extension])
# "extension" is either a string path to the extension
# or a 2-item tuple that specifies which entrypoint to load.
if isinstance(extension, tuple):
path, entrypoint = extension
conn.execute("SELECT load_extension(?, ?)", [path, entrypoint])
else:
conn.execute("SELECT load_extension(?)", [extension])
if self.setting("cache_size_kb"):
conn.execute(f"PRAGMA cache_size=-{self.setting('cache_size_kb')}")
# pylint: disable=no-member

Wyświetl plik

@ -21,6 +21,7 @@ from .app import (
pm,
)
from .utils import (
LoadExtension,
StartupError,
check_connection,
find_spatialite,
@ -128,9 +129,10 @@ def sqlite_extensions(fn):
return click.option(
"sqlite_extensions",
"--load-extension",
type=LoadExtension(),
envvar="SQLITE_EXTENSIONS",
multiple=True,
help="Path to a SQLite extension to load",
help="Path to a SQLite extension to load, and optional entrypoint",
)(fn)

Wyświetl plik

@ -833,6 +833,17 @@ class StaticMount(click.ParamType):
self.fail(f"{value} is not a valid directory path", param, ctx)
return path, dirpath
# The --load-extension parameter can optionally include a specific entrypoint.
# This is done by appending ":entrypoint_name" after supplying the path to the extension
class LoadExtension(click.ParamType):
name = "path:entrypoint?"
def convert(self, value, param, ctx):
if ":" not in value:
return value
path, entrypoint = value.split(":", 1)
return path, entrypoint
def format_bytes(bytes):
current = float(bytes)

48
tests/ext.c 100644
Wyświetl plik

@ -0,0 +1,48 @@
/*
** This file implements a SQLite extension with multiple entrypoints.
**
** The default entrypoint, sqlite3_ext_init, has a single function "a".
** The 1st alternate entrypoint, sqlite3_ext_b_init, has a single function "b".
** The 2nd alternate entrypoint, sqlite3_ext_c_init, has a single function "c".
**
** Compiling instructions:
** https://www.sqlite.org/loadext.html#compiling_a_loadable_extension
**
*/
#include "sqlite3ext.h"
SQLITE_EXTENSION_INIT1
// SQL function that returns back the value supplied during sqlite3_create_function()
static void func(sqlite3_context *context, int argc, sqlite3_value **argv) {
sqlite3_result_text(context, (char *) sqlite3_user_data(context), -1, SQLITE_STATIC);
}
// The default entrypoint, since it matches the "ext.dylib"/"ext.so" name
#ifdef _WIN32
__declspec(dllexport)
#endif
int sqlite3_ext_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi) {
SQLITE_EXTENSION_INIT2(pApi);
return sqlite3_create_function(db, "a", 0, 0, "a", func, 0, 0);
}
// Alternate entrypoint #1
#ifdef _WIN32
__declspec(dllexport)
#endif
int sqlite3_ext_b_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi) {
SQLITE_EXTENSION_INIT2(pApi);
return sqlite3_create_function(db, "b", 0, 0, "b", func, 0, 0);
}
// Alternate entrypoint #2
#ifdef _WIN32
__declspec(dllexport)
#endif
int sqlite3_ext_c_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi) {
SQLITE_EXTENSION_INIT2(pApi);
return sqlite3_create_function(db, "c", 0, 0, "c", func, 0, 0);
}

Wyświetl plik

@ -0,0 +1,65 @@
from datasette.app import Datasette
import pytest
from pathlib import Path
# not necessarily a full path - the full compiled path looks like "ext.dylib"
# or another suffix, but sqlite will, under the hood, decide which file
# extension to use based on the operating system (apple=dylib, windows=dll etc)
# this resolves to "./ext", which is enough for SQLite to calculate the rest
COMPILED_EXTENSION_PATH = str(Path(__file__).parent / "ext")
# See if ext.c has been compiled, based off the different possible suffixes.
def has_compiled_ext():
for ext in ["dylib", "so", "dll"]:
path = Path(__file__).parent / f"ext.{ext}"
if path.is_file():
return True
return False
@pytest.mark.asyncio
@pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c")
async def test_load_extension_default_entrypoint():
# The default entrypoint only loads a() and NOT b() or c(), so those
# should fail.
ds = Datasette(sqlite_extensions=[COMPILED_EXTENSION_PATH])
response = await ds.client.get("/_memory.json?sql=select+a()")
assert response.status_code == 200
assert response.json()["rows"][0][0] == "a"
response = await ds.client.get("/_memory.json?sql=select+b()")
assert response.status_code == 400
assert response.json()["error"] == "no such function: b"
response = await ds.client.get("/_memory.json?sql=select+c()")
assert response.status_code == 400
assert response.json()["error"] == "no such function: c"
@pytest.mark.asyncio
@pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c")
async def test_load_extension_multiple_entrypoints():
# Load in the default entrypoint and the other 2 custom entrypoints, now
# all a(), b(), and c() should run successfully.
ds = Datasette(
sqlite_extensions=[
COMPILED_EXTENSION_PATH,
(COMPILED_EXTENSION_PATH, "sqlite3_ext_b_init"),
(COMPILED_EXTENSION_PATH, "sqlite3_ext_c_init"),
]
)
response = await ds.client.get("/_memory.json?sql=select+a()")
assert response.status_code == 200
assert response.json()["rows"][0][0] == "a"
response = await ds.client.get("/_memory.json?sql=select+b()")
assert response.status_code == 200
assert response.json()["rows"][0][0] == "b"
response = await ds.client.get("/_memory.json?sql=select+c()")
assert response.status_code == 200
assert response.json()["rows"][0][0] == "c"