From 71fa4b2e7cd391af37e5c599ef60e9ac8112fb7a Mon Sep 17 00:00:00 2001
From: Travis Fischer
Date: Wed, 5 Jun 2024 01:46:52 -0500
Subject: [PATCH] feat: add SearchAndCrawl

---
 examples/dexter/analyze.ts    |  30 ++++
 src/tools/search-and-crawl.ts | 135 ++++++++++++++++++++++++++++++++++
 src/url-utils.ts              |  79 +++++++++----------
 src/utils.ts                  |   9 ++-
 4 files changed, 205 insertions(+), 48 deletions(-)
 create mode 100644 examples/dexter/analyze.ts
 create mode 100644 src/tools/search-and-crawl.ts

diff --git a/examples/dexter/analyze.ts b/examples/dexter/analyze.ts
new file mode 100644
index 0000000..f8df18e
--- /dev/null
+++ b/examples/dexter/analyze.ts
@@ -0,0 +1,30 @@
+#!/usr/bin/env node
+import 'dotenv/config'
+
+import { ChatModel, createAIRunner } from '@dexaai/dexter'
+
+import { DiffbotClient, SerpAPIClient } from '../../src/index.js'
+import { createDexterFunctions } from '../../src/sdks/dexter.js'
+import { SearchAndCrawl } from '../../src/tools/search-and-crawl.js'
+
+async function main() {
+  const serpapi = new SerpAPIClient()
+  const diffbot = new DiffbotClient()
+  const searchAndCrawl = new SearchAndCrawl({ serpapi, diffbot })
+
+  const runner = createAIRunner({
+    chatModel: new ChatModel({
+      params: { model: 'gpt-4o', temperature: 0 }
+      // debug: true
+    }),
+    functions: createDexterFunctions(searchAndCrawl),
+    systemMessage:
+      'You are a McKinsey analyst who is an expert at writing executive summaries. Always cite your sources and respond using Markdown.'
+  })
+
+  const topic = 'the 2024 olympics'
+  const result = await runner(`Summarize the latest news on ${topic}`)
+  console.log(result)
+}
+
+await main()
diff --git a/src/tools/search-and-crawl.ts b/src/tools/search-and-crawl.ts
new file mode 100644
index 0000000..8c29e47
--- /dev/null
+++ b/src/tools/search-and-crawl.ts
@@ -0,0 +1,135 @@
+import pMap from 'p-map'
+import { z } from 'zod'
+
+import { aiFunction, AIFunctionsProvider } from '../fns.js'
+import { type diffbot, DiffbotClient } from '../services/diffbot-client.js'
+import { SerpAPIClient } from '../services/serpapi-client.js'
+import { isValidCrawlableUrl, normalizeUrl } from '../url-utils.js'
+import { omit, pick } from '../utils.js'
+
+export class SearchAndCrawl extends AIFunctionsProvider {
+  readonly serpapi: SerpAPIClient
+  readonly diffbot: DiffbotClient
+
+  constructor(opts: { serpapi?: SerpAPIClient; diffbot?: DiffbotClient } = {}) {
+    super()
+
+    this.serpapi = opts.serpapi ?? new SerpAPIClient()
+    this.diffbot = opts.diffbot ?? new DiffbotClient()
+  }
+
+  @aiFunction({
+    name: 'search_and_crawl',
+    description:
+      'Uses Google to search the web, crawls the results, and then summarizes the most relevant results.',
+    inputSchema: z.object({
+      query: z.string().describe('search query')
+    })
+  })
+  async searchAndCrawl({
+    query,
+    numSearchResults = 3,
+    maxCrawlDepth = 1,
+    maxListItems = 3
+  }: {
+    query: string
+    numSearchResults?: number
+    maxCrawlDepth?: number
+    maxListItems?: number
+  }) {
+    const crawledUrls = new Set<string>()
+
+    const crawlAndScrape = async (
+      url: string | undefined,
+      {
+        depth = 0
+      }: {
+        depth?: number
+      }
+    ): Promise<diffbot.ExtractAnalyzeResponse[]> => {
+      try {
+        if (!url) return []
+        if (!isValidCrawlableUrl(url)) return []
+        if (crawledUrls.has(url)) return []
+
+        const normalizedUrl = normalizeUrl(url)
+        if (!normalizedUrl) return []
+        if (crawledUrls.has(normalizedUrl)) return []
+
+        crawledUrls.add(url)
+        crawledUrls.add(normalizedUrl)
+
+        console.log('\n\n')
+        const scrapeResult = await this.diffbot.analyzeUrl({ url })
+        console.log(
+          `SearchAndCrawl depth ${depth} - "${url}"`,
+          pick(scrapeResult, 'type', 'title')
+        )
+
+        if (scrapeResult.type !== 'list') {
+          return [scrapeResult]
+        }
+
+        if (depth >= maxCrawlDepth) {
+          return [scrapeResult]
+        }
+
+        const object = scrapeResult.objects?.[0]
+        if (!object) return [scrapeResult]
+
+        const items = object.items
+          ?.filter((item) => item.link)
+          .slice(0, maxListItems)
+        if (!items?.length) return [scrapeResult]
+
+        const innerScrapeResults = (
+          await pMap(
+            items,
+            async (item) => {
+              const innerScrapeResult = await crawlAndScrape(item.link, {
+                depth: depth + 1
+              })
+              return innerScrapeResult
+            },
+            {
+              concurrency: 4
+            }
+          )
+        ).flat()
+
+        return innerScrapeResults
+      } catch (err) {
+        console.warn('crawlAndScrape error', url, err)
+        return []
+      }
+    }
+
+    const searchResponse = await this.serpapi.search({
+      q: query,
+      num: numSearchResults
+    })
+
+    console.log(`SearchAndCrawl search results "${query}"`, searchResponse)
+    const scrapeResults = (
+      await pMap(
+        (searchResponse.organic_results || []).slice(0, numSearchResults),
+        async (searchResult) => {
+          return crawlAndScrape(searchResult.link, {
+            depth: 0
+          })
+        },
+        {
+          concurrency: 5
+        }
+      )
+    ).flat()
+
+    const output = {
+      ...omit(searchResponse, 'organic_results'),
+      scrape_results: scrapeResults
+    }
+
+    console.log(`SearchAndCrawl response for query "${query}"`, output)
+    return output
+  }
+}
diff --git a/src/url-utils.ts b/src/url-utils.ts
index 3fab95a..aabb749 100644
--- a/src/url-utils.ts
+++ b/src/url-utils.ts
@@ -1,20 +1,16 @@
 import isRelativeUrlImpl from 'is-relative-url'
-import normalizeUrlImpl, { type Options } from 'normalize-url'
+import normalizeUrlImpl, {
+  type Options as NormalizeUrlOptions
+} from 'normalize-url'
 import QuickLRU from 'quick-lru'
 
 import { hashObject } from './utils.js'
 
 const protocolAllowList = new Set(['https:', 'http:'])
-const normalizedUrlCache = new QuickLRU<string, string>({
+const normalizedUrlCache = new QuickLRU<string, string | null>({
   maxSize: 4000
 })
 
-/**
- * Checks if a URL is crawlable.
- *
- * @param url - URL string to check
- * @returns whether the URL is crawlable
- */
 export function isValidCrawlableUrl(url: string): boolean {
   try {
     if (!url || isRelativeUrl(url)) {
@@ -43,42 +39,35 @@ export function isRelativeUrl(url: string): boolean {
   return isRelativeUrlImpl(url) && !url.startsWith('//')
 }
 
-/**
- * Normalizes a URL string.
- *
- * @param url - URL string to normalize
- * @param options - options for normalization.
- * @returns normalized URL string or null if an invalid URL was passed
- */
 export function normalizeUrl(
   url: string,
-  options?: Options
-): string | undefined {
-  let normalizedUrl: string | undefined
-  let cacheKey: string | undefined
+  options?: NormalizeUrlOptions
+): string | null {
+  let normalizedUrl: string | null | undefined
+
+  if (!url || isRelativeUrl(url)) {
+    return null
+  }
+
+  const opts = {
+    stripWWW: false,
+    defaultProtocol: 'https',
+    normalizeProtocol: true,
+    forceHttps: false,
+    stripHash: false,
+    stripTextFragment: true,
+    removeQueryParameters: [/^utm_\w+/i, 'ref', 'ref_src'],
+    removeTrailingSlash: true,
+    removeSingleSlash: true,
+    removeExplicitPort: true,
+    sortQueryParameters: true,
+    ...options
+  } as Required<NormalizeUrlOptions>
+
+  const optionsHash = hashObject(opts)
+  const cacheKey = `${url}-${optionsHash}`
 
   try {
-    if (!url || isRelativeUrl(url)) {
-      return
-    }
-
-    const opts = {
-      stripWWW: false,
-      defaultProtocol: 'https',
-      normalizeProtocol: true,
-      forceHttps: false,
-      stripHash: false,
-      stripTextFragment: true,
-      removeQueryParameters: [/^utm_\w+/i, 'ref', 'ref_src'],
-      removeTrailingSlash: true,
-      removeSingleSlash: true,
-      removeExplicitPort: true,
-      sortQueryParameters: true,
-      ...options
-    } as Required<Options>
-
-    const optionsHash = hashObject(opts)
-    cacheKey = `${url}-${optionsHash}`
     normalizedUrl = normalizedUrlCache.get(cacheKey)
 
     if (normalizedUrl !== undefined) {
@@ -86,14 +75,14 @@ export function normalizeUrl(
     }
 
     normalizedUrl = normalizeUrlImpl(url, opts)
+    if (!normalizedUrl) {
+      normalizedUrl = null
+    }
   } catch {
     // ignore invalid urls
-    normalizedUrl = undefined
-  }
-
-  if (cacheKey) {
-    normalizedUrlCache.set(cacheKey, normalizedUrl!)
+    normalizedUrl = null
   }
+  normalizedUrlCache.set(cacheKey, normalizedUrl!)
 
   return normalizedUrl
 }
diff --git a/src/utils.ts b/src/utils.ts
index 7d6d0d1..2f033e8 100644
--- a/src/utils.ts
+++ b/src/utils.ts
@@ -1,6 +1,6 @@
 import type { Jsonifiable } from 'type-fest'
 import dedent from 'dedent'
-import hashObjectImpl from 'hash-object'
+import hashObjectImpl, { type Options as HashObjectOptions } from 'hash-object'
 
 import type * as types from './types.js'
 
@@ -142,6 +142,9 @@ export function cleanStringForModel(text: string): string {
   return dedenter(text).trim()
 }
 
-export function hashObject(object: Record<string, any>): string {
-  return hashObjectImpl(object, { algorithm: 'sha256' })
+export function hashObject(
+  object: Record<string, any>,
+  options?: HashObjectOptions
+): string {
+  return hashObjectImpl(object, { algorithm: 'sha256', ...options })
 }