Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions image.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def load_data(img_path,train = True):
img = Image.open(img_path).convert('RGB')
gt_file = h5py.File(gt_path)
target = np.asarray(gt_file['density'])
if False:
if train:
crop_size = (img.size[0]/2,img.size[1]/2)
if random.randint(0,9)<= -1:

Expand Down Expand Up @@ -40,4 +40,4 @@ def load_data(img_path,train = True):
target = cv2.resize(target,(target.shape[1]/8,target.shape[0]/8),interpolation = cv2.INTER_CUBIC)*64


return img,target
return img,target
25 changes: 11 additions & 14 deletions val.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
"from scipy.ndimage.filters import gaussian_filter \n",
"import scipy\n",
"import json\n",
"import torchvision.transforms.functional as F\n",
"from torchvision import datasets, transforms\n",
"from matplotlib import cm as CM\n",
"from image import *\n",
"from model import CSRNet\n",
"import torch\n",
"\n",
"%matplotlib inline"
]
},
Expand All @@ -34,11 +35,11 @@
},
"outputs": [],
"source": [
"from torchvision import datasets, transforms\n",
"transform=transforms.Compose([\n",
" transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
" std=[0.229, 0.224, 0.225]),\n",
" ])"
"transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
" std=[0.229, 0.224, 0.225])\n",
"])"
]
},
{
Expand Down Expand Up @@ -326,17 +327,13 @@
"source": [
"mae = 0\n",
"for i in xrange(len(img_paths)):\n",
" img = 255.0 * F.to_tensor(Image.open(img_paths[i]).convert('RGB'))\n",
"\n",
" img[0,:,:]=img[0,:,:]-92.8207477031\n",
" img[1,:,:]=img[1,:,:]-95.2757037428\n",
" img[2,:,:]=img[2,:,:]-104.877445883\n",
" img = img.cuda()\n",
" #img = transform(Image.open(img_paths[i]).convert('RGB')).cuda()\n",
" img = transform(Image.open(img_paths[i]).convert('RGB')).cuda()\n",
" gt_file = h5py.File(img_paths[i].replace('.jpg','.h5').replace('images','ground_truth'),'r')\n",
" groundtruth = np.asarray(gt_file['density'])\n",
" \n",
" output = model(img.unsqueeze(0))\n",
" mae += abs(output.detach().cpu().sum().numpy()-np.sum(groundtruth))\n",
" \n",
" print i,mae\n",
"print mae/len(img_paths)"
]
Expand All @@ -358,7 +355,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
"version": "2.7.15"
}
},
"nbformat": 4,
Expand Down