Open3D (C++ API)  0.17.0
ReduceSubarraysSumOpKernel.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// Copyright (c) 2018-2023 www.open3d.org
5// SPDX-License-Identifier: MIT
6// ----------------------------------------------------------------------------
7
8#pragma once
9
10#include "tensorflow/core/framework/op.h"
11#include "tensorflow/core/framework/op_kernel.h"
12#include "tensorflow/core/lib/core/errors.h"
13
15// namespace for code that is common for all kernels
16namespace reduce_subarrays_sum_opkernel {
17
18// Base class with common code for the OpKernel implementations
19class ReduceSubarraysSumOpKernel : public tensorflow::OpKernel {
20public:
21 explicit ReduceSubarraysSumOpKernel(
22 tensorflow::OpKernelConstruction* construction)
23 : OpKernel(construction) {}
24
25 void Compute(tensorflow::OpKernelContext* context) override {
26 using namespace tensorflow;
27 static_assert(sizeof(int64) == sizeof(int64_t),
28 "int64 type is not compatible");
29
30 const Tensor& values = context->input(0);
31 OP_REQUIRES(context, values.shape().dims() == 1,
32 errors::InvalidArgument("values must be a rank 1 tensor"));
33
34 const Tensor& row_splits = context->input(1);
35 OP_REQUIRES(
36 context, row_splits.shape().dims() == 1,
37 errors::InvalidArgument("row_splits must be a rank 1 tensor"));
38
39 // special treatment for empty values vector
40 if (values.shape().dim_size(0) == 0) {
41 Tensor* sums_tensor = 0;
42 OP_REQUIRES_OK(context, context->allocate_output(0, values.shape(),
43 &sums_tensor));
44 return;
45 }
46
47 Tensor* sums_tensor = 0;
48 TensorShape sums_shape({row_splits.shape().dim_size(0) - 1});
49 OP_REQUIRES_OK(context,
50 context->allocate_output(0, sums_shape, &sums_tensor));
51
52 Kernel(context, values, row_splits, *sums_tensor);
53 }
54
55 // Function with the device specific code
56 virtual void Kernel(tensorflow::OpKernelContext* context,
57 const tensorflow::Tensor& values,
58 const tensorflow::Tensor& row_splits,
59 tensorflow::Tensor& sums) = 0;
60
61private:
62};
63
64} // namespace reduce_subarrays_sum_opkernel
ImGuiContext * context
Definition: Window.cpp:76