diff --git a/src/main/java/top/guoziyang/springframework/factory/AbstractBeanFactory.java b/src/main/java/top/guoziyang/springframework/factory/AbstractBeanFactory.java index aa54786..4989b49 100644 --- a/src/main/java/top/guoziyang/springframework/factory/AbstractBeanFactory.java +++ b/src/main/java/top/guoziyang/springframework/factory/AbstractBeanFactory.java @@ -2,6 +2,7 @@ import top.guoziyang.springframework.entity.BeanDefinition; +import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -9,19 +10,18 @@ public abstract class AbstractBeanFactory implements BeanFactory { ConcurrentHashMap beanDefinitionMap = new ConcurrentHashMap<>(); + ThreadLocal> earylyBean = new ThreadLocal>(); + @Override public Object getBean(String name) throws Exception { BeanDefinition beanDefinition = beanDefinitionMap.get(name); if(beanDefinition == null) { throw new RuntimeException("Unable to find the bean of this name, please check!"); } - if(!beanDefinition.isSingleton() || beanDefinition.getBean() == null) { - return doCreateBean(beanDefinition); - } else { - return doCreateBean(beanDefinition); - } + return getBeanFromBeanDefinition(beanDefinition); } + @Override public Object getBean(Class clazz) throws Exception { BeanDefinition beanDefinition = null; @@ -34,11 +34,36 @@ public Object getBean(Class clazz) throws Exception { if(beanDefinition == null) { throw new RuntimeException("Unable to find the bean of this class, please check!"); } - if(!beanDefinition.isSingleton() || beanDefinition.getBean() == null) { - return doCreateBean(beanDefinition); - } else { - return beanDefinition.getBean(); + return getBeanFromBeanDefinition(beanDefinition); + } + + private Object getBeanFromBeanDefinition(BeanDefinition beanDefinition) throws Exception { + if(beanDefinition.isSingleton()){ + if(beanDefinition.getBean()!=null){ + return beanDefinition.getBean(); + }else{ + /** + * bug场景 :当多个线程使用同一个BeanFactory,针对同一个单例的beanDefinition 调用getBean + * 如果没有锁,会创建多个对象 + */ + synchronized (this){ + if(beanDefinition.getBean()==null){ + return doCreateBean(beanDefinition); + }else{ + return beanDefinition.getBean(); + } + } + } + }else{ + //不是单例 先从earlyBean中找,如果没有就新创建 + HashMap earlyBeanMap = earylyBean.get(); + if(earlyBeanMap!=null && earlyBeanMap.containsKey(beanDefinition.getBeanClassName())){ + return earlyBeanMap.get(beanDefinition.getBeanClassName()); + }else{ + return doCreateBean(beanDefinition); + } } + } @Override diff --git a/src/main/java/top/guoziyang/springframework/factory/AutowiredCapableBeanFactory.java b/src/main/java/top/guoziyang/springframework/factory/AutowiredCapableBeanFactory.java index 056ed10..e247127 100644 --- a/src/main/java/top/guoziyang/springframework/factory/AutowiredCapableBeanFactory.java +++ b/src/main/java/top/guoziyang/springframework/factory/AutowiredCapableBeanFactory.java @@ -5,6 +5,7 @@ import top.guoziyang.springframework.entity.PropertyValue; import java.lang.reflect.Field; +import java.util.HashMap; public class AutowiredCapableBeanFactory extends AbstractBeanFactory { @@ -15,6 +16,8 @@ Object doCreateBean(BeanDefinition beanDefinition) throws Exception { } Object bean = beanDefinition.getBeanClass().newInstance(); if(beanDefinition.isSingleton()) { + //如果是单例,就算没有完成属性赋值,也可以存起来 + //这样可以直接避免出现循环依赖导致的死循环问题 beanDefinition.setBean(bean); } applyPropertyValues(bean, beanDefinition); @@ -36,25 +39,16 @@ void applyPropertyValues(Object bean, BeanDefinition beanDefinition) throws Exce // 优先按照自定义名称匹配 BeanDefinition refDefinition = beanDefinitionMap.get(beanReference.getName()); if(refDefinition != null) { - if(!refDefinition.isSingleton() || refDefinition.getBean() == null) { - value = doCreateBean(refDefinition); - } else { - value = refDefinition.getBean(); - } + value = createBeanFromBeanDefinition(bean, beanDefinition, beanReference, refDefinition); } else { // 按照类型匹配,返回第一个匹配的 Class clazz = Class.forName(beanReference.getName()); for(BeanDefinition definition : beanDefinitionMap.values()) { if(clazz.isAssignableFrom(definition.getBeanClass())) { - if(!definition.isSingleton() || definition.getBean() == null) { - value = doCreateBean(definition); - } else { - value = definition.getBean(); - } + value = createBeanFromBeanDefinition(bean, beanDefinition, beanReference, definition); } } } - } if(value == null) { throw new RuntimeException("无法注入"); @@ -62,6 +56,32 @@ void applyPropertyValues(Object bean, BeanDefinition beanDefinition) throws Exce field.setAccessible(true); field.set(bean, value); } + //如果自己在earlyBean里,就删除 + if(earylyBean.get()!=null && earylyBean.get().containsKey(beanDefinition.getBeanClassName())){ + earylyBean.get().remove(beanDefinition.getBeanClassName()); + } + } + + private Object createBeanFromBeanDefinition(Object bean, BeanDefinition beanDefinition, BeanReference beanReference, + BeanDefinition refDefinition) throws Exception { + if(refDefinition.isSingleton()){ + //单例就直接拿 + if(refDefinition.getBean()!=null){ + return refDefinition.getBean(); + }else{ + return doCreateBean(refDefinition); + } + }else{ + //先把自己放入earlyBean + if(earylyBean.get() == null){ + earylyBean.set(new HashMap<>()); + } + if(!earylyBean.get().containsKey(beanDefinition.getBeanClassName())){ + earylyBean.get().put(beanDefinition.getBeanClassName(), bean); + } + //再尝试获取所需的Bean + return getBean(beanReference.getName()); + } }