Add Transients/Factories to Dependency Injection

This commit is contained in:
Christopher F
2016-12-14 16:12:02 -05:00
parent a1addd4016
commit 0f334d24a0
2 changed files with 24 additions and 6 deletions

View File

@@ -5,21 +5,29 @@ namespace Discord.Commands
{ {
public class DependencyMap : IDependencyMap public class DependencyMap : IDependencyMap
{ {
private Dictionary<Type, object> map; private Dictionary<Type, Func<object>> map;
public static DependencyMap Empty => new DependencyMap(); public static DependencyMap Empty => new DependencyMap();
public DependencyMap() public DependencyMap()
{ {
map = new Dictionary<Type, object>(); map = new Dictionary<Type, Func<object>>();
} }
public void Add<T>(T obj) public void Add<T>(T obj) where T : class
=> AddFactory(() => obj);
public void AddTransient<T>() where T : class, new()
=> AddFactory(() => new T());
public void AddTransient<TKey, TImpl>() where TKey : class
where TImpl : class, TKey, new()
=> AddFactory<TKey>(() => new TImpl());
public void AddFactory<T>(Func<T> factory) where T : class
{ {
var t = typeof(T); var t = typeof(T);
if (map.ContainsKey(t)) if (map.ContainsKey(t))
throw new InvalidOperationException($"The dependency map already contains \"{t.FullName}\""); throw new InvalidOperationException($"The dependency map already contains \"{t.FullName}\"");
map.Add(t, obj); map.Add(t, factory);
} }
public T Get<T>() public T Get<T>()
@@ -51,7 +59,14 @@ namespace Discord.Commands
} }
public bool TryGet(Type t, out object result) public bool TryGet(Type t, out object result)
{ {
return map.TryGetValue(t, out result); Func<object> func;
if (map.TryGetValue(t, out func))
{
result = func();
return true;
}
result = null;
return false;
} }
} }
} }

View File

@@ -4,7 +4,10 @@ namespace Discord.Commands
{ {
public interface IDependencyMap public interface IDependencyMap
{ {
void Add<T>(T obj); void Add<T>(T obj) where T : class;
void AddTransient<T>() where T : class, new();
void AddTransient<TKey, TImpl>() where TKey: class where TImpl : class, TKey, new();
void AddFactory<T>(Func<T> factory) where T : class;
T Get<T>(); T Get<T>();
bool TryGet<T>(out T result); bool TryGet<T>(out T result);