[civitai] implement retrieving 'version' metadata (#7432)

This commit is contained in:
Mike Fährmann
2025-05-03 20:39:35 +02:00
parent edc67983ed
commit 7a67348fc2
2 changed files with 27 additions and 8 deletions

View File

@@ -1843,10 +1843,10 @@ Type
Default Default
``false`` ``false``
Example Example
* ``"generation"`` * ``"generation,version"``
* ``["generation"]`` * ``["generation", "version"]``
Description Description
Extract additional ``generation`` metadata. Extract additional ``generation`` and ``version`` metadata.
Note: This requires 1 additional HTTP request per image or video. Note: This requires 1 additional HTTP request per image or video.

View File

@@ -10,6 +10,7 @@
from .common import Extractor, Message from .common import Extractor, Message
from .. import text, util, exception from .. import text, util, exception
from ..cache import memcache
import itertools import itertools
import time import time
@@ -49,10 +50,11 @@ class CivitaiExtractor(Extractor):
if isinstance(metadata, str): if isinstance(metadata, str):
metadata = metadata.split(",") metadata = metadata.split(",")
elif not isinstance(metadata, (list, tuple)): elif not isinstance(metadata, (list, tuple)):
metadata = ("generation",) metadata = ("generation", "version")
self._meta_generation = ("generation" in metadata) self._meta_generation = ("generation" in metadata)
self._meta_version = ("version" in metadata)
else: else:
self._meta_generation = False self._meta_generation = self._meta_version = False
def items(self): def items(self):
models = self.models() models = self.models()
@@ -77,9 +79,12 @@ class CivitaiExtractor(Extractor):
post["publishedAt"], "%Y-%m-%dT%H:%M:%S.%fZ") post["publishedAt"], "%Y-%m-%dT%H:%M:%S.%fZ")
data = { data = {
"post": post, "post": post,
"user": post["user"], "user": post.pop("user"),
} }
del post["user"] if self._meta_version:
data["version"] = version = self.api.model_version(
post["modelVersionId"]).copy()
data["model"] = version.pop("model")
yield Message.Directory, data yield Message.Directory, data
for file in self._image_results(images): for file in self._image_results(images):
@@ -94,6 +99,18 @@ class CivitaiExtractor(Extractor):
if self._meta_generation: if self._meta_generation:
image["generation"] = self.api.image_generationdata( image["generation"] = self.api.image_generationdata(
image["id"]) image["id"])
if self._meta_version:
if "modelVersionId" in image:
version_id = image["modelVersionId"]
else:
post = image["post"] = self.api.post(
image["postId"])
post.pop("user", None)
version_id = post["modelVersionId"]
image["version"] = version = self.api.model_version(
version_id).copy()
image["model2"] = version.pop("model")
image["date"] = text.parse_datetime( image["date"] = text.parse_datetime(
image["createdAt"], "%Y-%m-%dT%H:%M:%S.%fZ") image["createdAt"], "%Y-%m-%dT%H:%M:%S.%fZ")
text.nameext_from_url(url, image) text.nameext_from_url(url, image)
@@ -464,6 +481,7 @@ class CivitaiRestAPI():
endpoint = "/v1/models/{}".format(model_id) endpoint = "/v1/models/{}".format(model_id)
return self._call(endpoint) return self._call(endpoint)
@memcache(keyarg=1)
def model_version(self, model_version_id): def model_version(self, model_version_id):
endpoint = "/v1/model-versions/{}".format(model_version_id) endpoint = "/v1/model-versions/{}".format(model_version_id)
return self._call(endpoint) return self._call(endpoint)
@@ -504,7 +522,7 @@ class CivitaiTrpcAPI():
self.root = extractor.root + "/api/trpc/" self.root = extractor.root + "/api/trpc/"
self.headers = { self.headers = {
"content-type" : "application/json", "content-type" : "application/json",
"x-client-version": "5.0.542", "x-client-version": "5.0.701",
"x-client-date" : "", "x-client-date" : "",
"x-client" : "web", "x-client" : "web",
"x-fingerprint" : "undefined", "x-fingerprint" : "undefined",
@@ -576,6 +594,7 @@ class CivitaiTrpcAPI():
params = {"id": int(model_id)} params = {"id": int(model_id)}
return self._call(endpoint, params) return self._call(endpoint, params)
@memcache(keyarg=1)
def model_version(self, model_version_id): def model_version(self, model_version_id):
endpoint = "modelVersion.getById" endpoint = "modelVersion.getById"
params = {"id": int(model_version_id)} params = {"id": int(model_version_id)}