Handle dictionary in confgw
Browse files- configuration_aimv2.py +8 -3
configuration_aimv2.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from typing import Any, Dict, Optional
|
| 2 |
|
| 3 |
from transformers.configuration_utils import PretrainedConfig
|
| 4 |
|
|
@@ -148,8 +148,8 @@ class AIMv2Config(PretrainedConfig):
|
|
| 148 |
|
| 149 |
def __init__(
|
| 150 |
self,
|
| 151 |
-
vision_config: Optional[AIMv2VisionConfig] = None,
|
| 152 |
-
text_config: Optional[AIMv2TextConfig] = None,
|
| 153 |
projection_dim: int = 768,
|
| 154 |
init_temperature: float = 0.07,
|
| 155 |
max_logit_scale: float = 100.0,
|
|
@@ -158,8 +158,13 @@ class AIMv2Config(PretrainedConfig):
|
|
| 158 |
super().__init__(**kwargs)
|
| 159 |
if vision_config is None:
|
| 160 |
vision_config = AIMv2VisionConfig()
|
|
|
|
|
|
|
|
|
|
| 161 |
if text_config is None:
|
| 162 |
text_config = AIMv2TextConfig()
|
|
|
|
|
|
|
| 163 |
|
| 164 |
self.vision_config = vision_config
|
| 165 |
self.text_config = text_config
|
|
|
|
| 1 |
+
from typing import Any, Dict, Optional, Union
|
| 2 |
|
| 3 |
from transformers.configuration_utils import PretrainedConfig
|
| 4 |
|
|
|
|
| 148 |
|
| 149 |
def __init__(
|
| 150 |
self,
|
| 151 |
+
vision_config: Optional[Union[AIMv2VisionConfig, Dict[str, Any]]] = None,
|
| 152 |
+
text_config: Optional[Union[AIMv2TextConfig, Dict[str, Any]]] = None,
|
| 153 |
projection_dim: int = 768,
|
| 154 |
init_temperature: float = 0.07,
|
| 155 |
max_logit_scale: float = 100.0,
|
|
|
|
| 158 |
super().__init__(**kwargs)
|
| 159 |
if vision_config is None:
|
| 160 |
vision_config = AIMv2VisionConfig()
|
| 161 |
+
elif isinstance(vision_config, dict):
|
| 162 |
+
vision_config = AIMv2VisionConfig(**vision_config)
|
| 163 |
+
|
| 164 |
if text_config is None:
|
| 165 |
text_config = AIMv2TextConfig()
|
| 166 |
+
elif isinstance(text_config, dict):
|
| 167 |
+
text_config = AIMv2TextConfig(**text_config)
|
| 168 |
|
| 169 |
self.vision_config = vision_config
|
| 170 |
self.text_config = text_config
|