kopia lustrzana https://github.com/simonw/datasette
rodzic
663ac431fe
commit
1d64c9a8da
|
@ -118,3 +118,9 @@ ENV/
|
||||||
.DS_Store
|
.DS_Store
|
||||||
node_modules
|
node_modules
|
||||||
.*.swp
|
.*.swp
|
||||||
|
|
||||||
|
# In case someone compiled tests/ext.c for test_load_extensions, don't
|
||||||
|
# include it in source control.
|
||||||
|
tests/*.dylib
|
||||||
|
tests/*.so
|
||||||
|
tests/*.dll
|
|
@ -559,7 +559,13 @@ class Datasette:
|
||||||
if self.sqlite_extensions:
|
if self.sqlite_extensions:
|
||||||
conn.enable_load_extension(True)
|
conn.enable_load_extension(True)
|
||||||
for extension in self.sqlite_extensions:
|
for extension in self.sqlite_extensions:
|
||||||
conn.execute("SELECT load_extension(?)", [extension])
|
# "extension" is either a string path to the extension
|
||||||
|
# or a 2-item tuple that specifies which entrypoint to load.
|
||||||
|
if isinstance(extension, tuple):
|
||||||
|
path, entrypoint = extension
|
||||||
|
conn.execute("SELECT load_extension(?, ?)", [path, entrypoint])
|
||||||
|
else:
|
||||||
|
conn.execute("SELECT load_extension(?)", [extension])
|
||||||
if self.setting("cache_size_kb"):
|
if self.setting("cache_size_kb"):
|
||||||
conn.execute(f"PRAGMA cache_size=-{self.setting('cache_size_kb')}")
|
conn.execute(f"PRAGMA cache_size=-{self.setting('cache_size_kb')}")
|
||||||
# pylint: disable=no-member
|
# pylint: disable=no-member
|
||||||
|
|
|
@ -21,6 +21,7 @@ from .app import (
|
||||||
pm,
|
pm,
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
|
LoadExtension,
|
||||||
StartupError,
|
StartupError,
|
||||||
check_connection,
|
check_connection,
|
||||||
find_spatialite,
|
find_spatialite,
|
||||||
|
@ -128,9 +129,10 @@ def sqlite_extensions(fn):
|
||||||
return click.option(
|
return click.option(
|
||||||
"sqlite_extensions",
|
"sqlite_extensions",
|
||||||
"--load-extension",
|
"--load-extension",
|
||||||
|
type=LoadExtension(),
|
||||||
envvar="SQLITE_EXTENSIONS",
|
envvar="SQLITE_EXTENSIONS",
|
||||||
multiple=True,
|
multiple=True,
|
||||||
help="Path to a SQLite extension to load",
|
help="Path to a SQLite extension to load, and optional entrypoint",
|
||||||
)(fn)
|
)(fn)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -833,6 +833,17 @@ class StaticMount(click.ParamType):
|
||||||
self.fail(f"{value} is not a valid directory path", param, ctx)
|
self.fail(f"{value} is not a valid directory path", param, ctx)
|
||||||
return path, dirpath
|
return path, dirpath
|
||||||
|
|
||||||
|
# The --load-extension parameter can optionally include a specific entrypoint.
|
||||||
|
# This is done by appending ":entrypoint_name" after supplying the path to the extension
|
||||||
|
class LoadExtension(click.ParamType):
|
||||||
|
name = "path:entrypoint?"
|
||||||
|
|
||||||
|
def convert(self, value, param, ctx):
|
||||||
|
if ":" not in value:
|
||||||
|
return value
|
||||||
|
path, entrypoint = value.split(":", 1)
|
||||||
|
return path, entrypoint
|
||||||
|
|
||||||
|
|
||||||
def format_bytes(bytes):
|
def format_bytes(bytes):
|
||||||
current = float(bytes)
|
current = float(bytes)
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
/*
|
||||||
|
** This file implements a SQLite extension with multiple entrypoints.
|
||||||
|
**
|
||||||
|
** The default entrypoint, sqlite3_ext_init, has a single function "a".
|
||||||
|
** The 1st alternate entrypoint, sqlite3_ext_b_init, has a single function "b".
|
||||||
|
** The 2nd alternate entrypoint, sqlite3_ext_c_init, has a single function "c".
|
||||||
|
**
|
||||||
|
** Compiling instructions:
|
||||||
|
** https://www.sqlite.org/loadext.html#compiling_a_loadable_extension
|
||||||
|
**
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "sqlite3ext.h"
|
||||||
|
|
||||||
|
SQLITE_EXTENSION_INIT1
|
||||||
|
|
||||||
|
// SQL function that returns back the value supplied during sqlite3_create_function()
|
||||||
|
static void func(sqlite3_context *context, int argc, sqlite3_value **argv) {
|
||||||
|
sqlite3_result_text(context, (char *) sqlite3_user_data(context), -1, SQLITE_STATIC);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// The default entrypoint, since it matches the "ext.dylib"/"ext.so" name
|
||||||
|
#ifdef _WIN32
|
||||||
|
__declspec(dllexport)
|
||||||
|
#endif
|
||||||
|
int sqlite3_ext_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi) {
|
||||||
|
SQLITE_EXTENSION_INIT2(pApi);
|
||||||
|
return sqlite3_create_function(db, "a", 0, 0, "a", func, 0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Alternate entrypoint #1
|
||||||
|
#ifdef _WIN32
|
||||||
|
__declspec(dllexport)
|
||||||
|
#endif
|
||||||
|
int sqlite3_ext_b_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi) {
|
||||||
|
SQLITE_EXTENSION_INIT2(pApi);
|
||||||
|
return sqlite3_create_function(db, "b", 0, 0, "b", func, 0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Alternate entrypoint #2
|
||||||
|
#ifdef _WIN32
|
||||||
|
__declspec(dllexport)
|
||||||
|
#endif
|
||||||
|
int sqlite3_ext_c_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi) {
|
||||||
|
SQLITE_EXTENSION_INIT2(pApi);
|
||||||
|
return sqlite3_create_function(db, "c", 0, 0, "c", func, 0, 0);
|
||||||
|
}
|
|
@ -0,0 +1,65 @@
|
||||||
|
from datasette.app import Datasette
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# not necessarily a full path - the full compiled path looks like "ext.dylib"
|
||||||
|
# or another suffix, but sqlite will, under the hood, decide which file
|
||||||
|
# extension to use based on the operating system (apple=dylib, windows=dll etc)
|
||||||
|
# this resolves to "./ext", which is enough for SQLite to calculate the rest
|
||||||
|
COMPILED_EXTENSION_PATH = str(Path(__file__).parent / "ext")
|
||||||
|
|
||||||
|
# See if ext.c has been compiled, based off the different possible suffixes.
|
||||||
|
def has_compiled_ext():
|
||||||
|
for ext in ["dylib", "so", "dll"]:
|
||||||
|
path = Path(__file__).parent / f"ext.{ext}"
|
||||||
|
if path.is_file():
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c")
|
||||||
|
async def test_load_extension_default_entrypoint():
|
||||||
|
|
||||||
|
# The default entrypoint only loads a() and NOT b() or c(), so those
|
||||||
|
# should fail.
|
||||||
|
ds = Datasette(sqlite_extensions=[COMPILED_EXTENSION_PATH])
|
||||||
|
|
||||||
|
response = await ds.client.get("/_memory.json?sql=select+a()")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["rows"][0][0] == "a"
|
||||||
|
|
||||||
|
response = await ds.client.get("/_memory.json?sql=select+b()")
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert response.json()["error"] == "no such function: b"
|
||||||
|
|
||||||
|
response = await ds.client.get("/_memory.json?sql=select+c()")
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert response.json()["error"] == "no such function: c"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.skipif(not has_compiled_ext(), reason="Requires compiled ext.c")
|
||||||
|
async def test_load_extension_multiple_entrypoints():
|
||||||
|
|
||||||
|
# Load in the default entrypoint and the other 2 custom entrypoints, now
|
||||||
|
# all a(), b(), and c() should run successfully.
|
||||||
|
ds = Datasette(
|
||||||
|
sqlite_extensions=[
|
||||||
|
COMPILED_EXTENSION_PATH,
|
||||||
|
(COMPILED_EXTENSION_PATH, "sqlite3_ext_b_init"),
|
||||||
|
(COMPILED_EXTENSION_PATH, "sqlite3_ext_c_init"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await ds.client.get("/_memory.json?sql=select+a()")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["rows"][0][0] == "a"
|
||||||
|
|
||||||
|
response = await ds.client.get("/_memory.json?sql=select+b()")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["rows"][0][0] == "b"
|
||||||
|
|
||||||
|
response = await ds.client.get("/_memory.json?sql=select+c()")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["rows"][0][0] == "c"
|
Ładowanie…
Reference in New Issue