kopia lustrzana https://github.com/simonw/datasette
				
				
				
			
							rodzic
							
								
									0209a0a344
								
							
						
					
					
						commit
						7d0f668556
					
				| 
						 | 
				
			
			@ -116,6 +116,43 @@ async def favicon(request):
 | 
			
		|||
    return response.text("")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ConnectedDatabase:
 | 
			
		||||
    def __init__(self, path=None, is_mutable=False, is_memory=False):
 | 
			
		||||
        self.path = path
 | 
			
		||||
        self.is_mutable = is_mutable
 | 
			
		||||
        self.is_memory = is_memory
 | 
			
		||||
        self.hash = None
 | 
			
		||||
        self.size = None
 | 
			
		||||
        if not self.is_mutable:
 | 
			
		||||
            p = Path(path)
 | 
			
		||||
            self.hash = inspect_hash(p)
 | 
			
		||||
            self.size = p.stat().st_size
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def name(self):
 | 
			
		||||
        if self.is_memory:
 | 
			
		||||
            return ":memory:"
 | 
			
		||||
        else:
 | 
			
		||||
            return Path(self.path).stem
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        tags = []
 | 
			
		||||
        if self.is_mutable:
 | 
			
		||||
            tags.append("mutable")
 | 
			
		||||
        if self.is_memory:
 | 
			
		||||
            tags.append("memory")
 | 
			
		||||
        if self.hash:
 | 
			
		||||
            tags.append("hash={}".format(self.hash))
 | 
			
		||||
        if self.size is not None:
 | 
			
		||||
            tags.append("size={}".format(self.size))
 | 
			
		||||
        tags_str = ""
 | 
			
		||||
        if tags:
 | 
			
		||||
            tags_str = " ({})".format(", ".join(tags))
 | 
			
		||||
        return "<ConnectedDatabase: {}{}>".format(
 | 
			
		||||
            self.name, tags_str
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Datasette:
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
| 
						 | 
				
			
			@ -141,6 +178,18 @@ class Datasette:
 | 
			
		|||
            self.files = [MEMORY]
 | 
			
		||||
        elif memory:
 | 
			
		||||
            self.files = (MEMORY,) + self.files
 | 
			
		||||
        self.databases = {}
 | 
			
		||||
        for file in self.files:
 | 
			
		||||
            path = file
 | 
			
		||||
            is_memory = False
 | 
			
		||||
            if file is MEMORY:
 | 
			
		||||
                path = None
 | 
			
		||||
                is_memory = True
 | 
			
		||||
            db = ConnectedDatabase(path, is_mutable=path not in self.immutables, is_memory=is_memory)
 | 
			
		||||
            if db.name in self.databases:
 | 
			
		||||
                raise Exception("Multiple files with same stem: {}".format(db.name))
 | 
			
		||||
            self.databases[db.name] = db
 | 
			
		||||
        print(self.databases)
 | 
			
		||||
        self.cache_headers = cache_headers
 | 
			
		||||
        self.cors = cors
 | 
			
		||||
        self._inspect = inspect_data
 | 
			
		||||
| 
						 | 
				
			
			@ -419,17 +468,17 @@ class Datasette:
 | 
			
		|||
        def in_thread():
 | 
			
		||||
            conn = getattr(connections, db_name, None)
 | 
			
		||||
            if not conn:
 | 
			
		||||
                info = self.inspect()[db_name]
 | 
			
		||||
                if info["file"] == ":memory:":
 | 
			
		||||
                db = self.databases[db_name]
 | 
			
		||||
                if db.is_memory:
 | 
			
		||||
                    conn = sqlite3.connect(":memory:")
 | 
			
		||||
                else:
 | 
			
		||||
                    # mode=ro or immutable=1?
 | 
			
		||||
                    if info["file"] in self.immutables:
 | 
			
		||||
                        qs = "immutable=1"
 | 
			
		||||
                    else:
 | 
			
		||||
                    if db.is_mutable:
 | 
			
		||||
                        qs = "mode=ro"
 | 
			
		||||
                    else:
 | 
			
		||||
                        qs = "immutable=1"
 | 
			
		||||
                    conn = sqlite3.connect(
 | 
			
		||||
                        "file:{}?{}".format(info["file"], qs),
 | 
			
		||||
                        "file:{}?{}".format(db.path, qs),
 | 
			
		||||
                        uri=True,
 | 
			
		||||
                        check_same_thread=False,
 | 
			
		||||
                    )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -156,14 +156,13 @@ class BaseView(RenderMixin):
 | 
			
		|||
        return r
 | 
			
		||||
 | 
			
		||||
    async def resolve_db_name(self, request, db_name, **kwargs):
 | 
			
		||||
        databases = self.ds.inspect()
 | 
			
		||||
        hash = None
 | 
			
		||||
        name = None
 | 
			
		||||
        if "-" in db_name:
 | 
			
		||||
            # Might be name-and-hash, or might just be
 | 
			
		||||
            # a name with a hyphen in it
 | 
			
		||||
            name, hash = db_name.rsplit("-", 1)
 | 
			
		||||
            if name not in databases:
 | 
			
		||||
            if name not in self.ds.databases:
 | 
			
		||||
                # Try the whole name
 | 
			
		||||
                name = db_name
 | 
			
		||||
                hash = None
 | 
			
		||||
| 
						 | 
				
			
			@ -171,11 +170,13 @@ class BaseView(RenderMixin):
 | 
			
		|||
            name = db_name
 | 
			
		||||
        # Verify the hash
 | 
			
		||||
        try:
 | 
			
		||||
            info = databases[name]
 | 
			
		||||
            db = self.ds.databases[name]
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            raise NotFound("Database not found: {}".format(name))
 | 
			
		||||
 | 
			
		||||
        expected = info["hash"][:HASH_LENGTH]
 | 
			
		||||
        expected = "000"
 | 
			
		||||
        if db.hash is not None:
 | 
			
		||||
            expected = db.hash[:HASH_LENGTH]
 | 
			
		||||
        correct_hash_provided = (expected == hash)
 | 
			
		||||
 | 
			
		||||
        if not correct_hash_provided:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -29,6 +29,7 @@ def make_app_client(
 | 
			
		|||
    cors=False,
 | 
			
		||||
    config=None,
 | 
			
		||||
    filename="fixtures.db",
 | 
			
		||||
    is_immutable=False,
 | 
			
		||||
):
 | 
			
		||||
    with tempfile.TemporaryDirectory() as tmpdir:
 | 
			
		||||
        filepath = os.path.join(tmpdir, filename)
 | 
			
		||||
| 
						 | 
				
			
			@ -48,7 +49,8 @@ def make_app_client(
 | 
			
		|||
            }
 | 
			
		||||
        )
 | 
			
		||||
        ds = Datasette(
 | 
			
		||||
            [filepath],
 | 
			
		||||
            [] if is_immutable else [filepath],
 | 
			
		||||
            immutables=[filepath] if is_immutable else [],
 | 
			
		||||
            cors=cors,
 | 
			
		||||
            metadata=METADATA,
 | 
			
		||||
            plugins_dir=plugins_dir,
 | 
			
		||||
| 
						 | 
				
			
			@ -76,8 +78,8 @@ def app_client_no_files():
 | 
			
		|||
@pytest.fixture(scope="session")
 | 
			
		||||
def app_client_with_hash():
 | 
			
		||||
    yield from make_app_client(config={
 | 
			
		||||
        'hash_urls': True
 | 
			
		||||
    })
 | 
			
		||||
        'hash_urls': True,
 | 
			
		||||
    }, is_immutable=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture(scope='session')
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1317,10 +1317,10 @@ def test_ttl_parameter(app_client, path, expected_cache_control):
 | 
			
		|||
    ("/fixtures/facetable.json?_hash=1", "/fixtures-HASH/facetable.json"),
 | 
			
		||||
    ("/fixtures/facetable.json?city_id=1&_hash=1", "/fixtures-HASH/facetable.json?city_id=1"),
 | 
			
		||||
])
 | 
			
		||||
def test_hash_parameter(app_client, path, expected_redirect):
 | 
			
		||||
def test_hash_parameter(app_client_with_hash, path, expected_redirect):
 | 
			
		||||
    # First get the current hash for the fixtures database
 | 
			
		||||
    current_hash = app_client.get("/-/inspect.json").json["fixtures"]["hash"][:7]
 | 
			
		||||
    response = app_client.get(path, allow_redirects=False)
 | 
			
		||||
    current_hash = app_client_with_hash.get("/-/inspect.json").json["fixtures"]["hash"][:7]
 | 
			
		||||
    response = app_client_with_hash.get(path, allow_redirects=False)
 | 
			
		||||
    assert response.status == 302
 | 
			
		||||
    location = response.headers["Location"]
 | 
			
		||||
    assert expected_redirect.replace("HASH", current_hash) == location
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Ładowanie…
	
		Reference in New Issue