开发者社区> 问答> 正文

如何将分类数据类型转换为numpy数组?

我试过这个代码。输出是:Categorical(probs: torch.Size([12])))。我想从输出中提取值并将其转换为numpy数组。谁能提出建议?

我知道我可以返回a的值(注释)。但是,还有什么解决方案吗?

class Actor(nn.Module):
def __init__(self, state_size, action_size):

super(Actor, self).__init__()
self.state_size = state_size
self.action_size = action_size
self.linear1 = nn.Linear(self.state_size, 128)
self.linear2 = nn.Linear(128, 256)
self.linear3 = nn.Linear(256, self.action_size)

def forward(self, state):

output = F.relu(self.linear1(state))
output = F.relu(self.linear2(output))
output = self.linear3(output)
distribution = Categorical(F.softmax(output, dim=-1))
# a=F.softmax(output, dim=-1)
# print(a.detach().numpy())
return distribution

output=Actor(state)
print(output)

展开
收起
一码平川MACHEL 2019-01-18 10:38:26 3216 0
1 条回答
写回答
取消 提交回答
  • prob = output.probs.detach().cpu().numpy()

    2019-07-17 23:25:51
    赞同 展开评论 打赏
问答分类:
问答标签:
问答地址:
问答排行榜
最热
最新

相关电子书

更多
低代码开发师(初级)实战教程 立即下载
冬季实战营第三期:MySQL数据库进阶实战 立即下载
阿里巴巴DevOps 最佳实践手册 立即下载