Open3D (C++ API)  0.17.0
TrilinearDevoxelizeKernel.h
Go to the documentation of this file.
// ----------------------------------------------------------------------------
// - Open3D: www.open3d.org -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2023 www.open3d.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------

#pragma once

#include <cstdint>
#include <limits>

#include "../TensorFlowHelper.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
14
15class TrilinearDevoxelizeOpKernel : public tensorflow::OpKernel {
16public:
18 tensorflow::OpKernelConstruction* context)
19 : tensorflow::OpKernel(context) {
20 using namespace tensorflow;
21 OP_REQUIRES_OK(context, context->GetAttr("resolution", &r));
22 OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training));
23 OP_REQUIRES(context, r > 0,
24 errors::InvalidArgument(
25 "TrilinearDevoxelize expects positive resolution"));
26 }
27
28 void Compute(tensorflow::OpKernelContext* context) override {
29 using namespace tensorflow;
30 const Tensor& coords = context->input(0);
31 OP_REQUIRES(
32 context, coords.dims() == 3 && coords.shape().dim_size(1) == 3,
33 errors::InvalidArgument("TrilinearDevoxelize expects "
34 "(batch_size, 3, N) coordinate shape"));
35 const Tensor& feat = context->input(1);
36 OP_REQUIRES(context, feat.dims() == 5,
37 errors::InvalidArgument("TrilinearDevoxelize expects "
38 "5 dimensions for features"));
39
40 int batch_size = coords.shape().dim_size(0);
41 int num_points = coords.shape().dim_size(2);
42 int feat_dim = feat.shape().dim_size(1);
43
44 auto coords_flat = coords.flat<float>();
45 auto feat_flat = feat.flat<float>();
46
47 const float* inp_coords = &(coords_flat(0));
48 const float* inp_feat = &(feat_flat(0));
49
50 Tensor* out_tensor_0;
51 OP_REQUIRES_OK(context,
52 context->allocate_output(
53 0, TensorShape{batch_size, feat_dim, num_points},
54 &out_tensor_0));
55 Tensor* out_tensor_1;
56 OP_REQUIRES_OK(context,
57 context->allocate_output(
58 1, TensorShape{batch_size, 8, num_points},
59 &out_tensor_1));
60 Tensor* out_tensor_2;
61 OP_REQUIRES_OK(context,
62 context->allocate_output(
63 2, TensorShape{batch_size, 8, num_points},
64 &out_tensor_2));
65 auto flat_0 = out_tensor_0->flat<float>();
66 auto flat_1 = out_tensor_1->flat<int>();
67 auto flat_2 = out_tensor_2->flat<float>();
68
69 float* out_0 = &(flat_0(0));
70 int* out_1 = &(flat_1(0));
71 float* out_2 = &(flat_2(0));
72
73 if (is_training)
74 Kernel(context, batch_size, feat_dim, num_points, r, r * r,
75 r * r * r, true, inp_coords, inp_feat, out_1, out_2, out_0);
76 else
77 Kernel(context, batch_size, feat_dim, num_points, r, r * r,
78 r * r * r, false, inp_coords, inp_feat, out_1, out_2, out_0);
79 }
80
81 virtual void Kernel(tensorflow::OpKernelContext* context,
82 int b,
83 int c,
84 int n,
85 int r,
86 int r2,
87 int r3,
88 bool training,
89 const float* coords,
90 const float* feat,
91 int* inds,
92 float* wgts,
93 float* outs) = 0;
94
95protected:
96 int r;
98};
99
100class TrilinearDevoxelizeGradOpKernel : public tensorflow::OpKernel {
101public:
103 tensorflow::OpKernelConstruction* context)
104 : tensorflow::OpKernel(context) {
105 using namespace tensorflow;
106 OP_REQUIRES_OK(context, context->GetAttr("resolution", &r));
107 OP_REQUIRES(
108 context, r > 0,
109 errors::InvalidArgument(
110 "TrilinearDevoxelizeGrad expects positive resolution"));
111 }
112
113 void Compute(tensorflow::OpKernelContext* context) override {
114 using namespace tensorflow;
115 const Tensor& grad_y = context->input(0);
116 OP_REQUIRES(
117 context, grad_y.dims() == 3,
118 errors::InvalidArgument("TrilinearDevoxelizeGrad expects "
119 "(batch_size, C, N) gradient shape"));
120 const Tensor& inds = context->input(1);
121 OP_REQUIRES(
122 context, inds.dims() == 3 && inds.shape().dim_size(1) == 8,
123 errors::InvalidArgument("TrilinearDevoxelizeGrad expects "
124 "(batch_size, 8, N) indices shape"));
125 const Tensor& wgts = context->input(2);
126 OP_REQUIRES(
127 context, wgts.dims() == 3 && wgts.shape().dim_size(1) == 8,
128 errors::InvalidArgument("TrilinearDevoxelizeGrad expects "
129 "(batch_size, 8, N) weights shape"));
130
131 int batch_size = grad_y.shape().dim_size(0);
132 int num_points = grad_y.shape().dim_size(2);
133 int feat_dim = grad_y.shape().dim_size(1);
134
135 auto grad_y_flat = grad_y.flat<float>();
136 auto inds_flat = inds.flat<int>();
137 auto wgts_flat = wgts.flat<float>();
138
139 const float* inp_grad_y = &(grad_y_flat(0));
140 const int* inp_inds = &(inds_flat(0));
141 const float* inp_wgts = &(wgts_flat(0));
142
143 Tensor* out_tensor;
144 OP_REQUIRES_OK(context,
145 context->allocate_output(
146 0, TensorShape{batch_size, feat_dim, r, r, r},
147 &out_tensor));
148 auto flat_tensor = out_tensor->flat<float>();
149
150 float* out = &(flat_tensor(0));
151
152 Kernel(context, batch_size, feat_dim, num_points, r * r * r, inp_inds,
153 inp_wgts, inp_grad_y, out);
154 }
155
156 virtual void Kernel(tensorflow::OpKernelContext* context,
157 int b,
158 int c,
159 int n,
160 int r3,
161 const int* inds,
162 const float* wgts,
163 const float* grad_y,
164 float* grad_x) = 0;
165
166protected:
167 int r;
168};
ImGuiContext * context
Definition: Window.cpp:76
Definition: TrilinearDevoxelizeKernel.h:100
int r
Definition: TrilinearDevoxelizeKernel.h:167
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:113
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int r3, const int *inds, const float *wgts, const float *grad_y, float *grad_x)=0
TrilinearDevoxelizeGradOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:102
Definition: TrilinearDevoxelizeKernel.h:15
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int r, int r2, int r3, bool training, const float *coords, const float *feat, int *inds, float *wgts, float *outs)=0
TrilinearDevoxelizeOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:17
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:28
int r
Definition: TrilinearDevoxelizeKernel.h:96
bool is_training
Definition: TrilinearDevoxelizeKernel.h:97