MLIR  21.0.0git
DialectLinalg.cpp
Go to the documentation of this file.
1 //===- DialectLinalg.cpp - Pybind module for Linalg dialect API support --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir-c/IR.h"
13 
14 namespace nb = nanobind;
15 using namespace mlir::python::nanobind_adaptors;
16 
17 static std::optional<MlirLinalgContractionDimensions>
18 InferContractionDimensions(MlirOperation op) {
21 
22  // Detect "empty" result. This occurs when `op` is not a contraction op,
23  // or when `linalg::inferContractionDims` fails.
24  if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
25  mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
26  return std::nullopt;
27  }
28  return dims;
29 }
30 
31 static std::optional<MlirLinalgConvolutionDimensions>
32 InferConvolutionDimensions(MlirOperation op) {
35 
36  // Detect "empty" result. This occurs when `op` is not a convolution op,
37  // or when `linalg::inferConvolutionDims` fails.
38  if (mlirAttributeIsNull(dims.batch) &&
45  return std::nullopt;
46  }
47 
48  return dims;
49 }
50 
51 static void populateDialectLinalgSubmodule(nb::module_ m) {
52  m.def(
53  "fill_builtin_region",
54  [](MlirOperation op) { mlirLinalgFillBuiltinNamedOpRegion(op); },
55  nb::arg("op"),
56  "Fill the region for `op`, which is assumed to be a builtin named Linalg "
57  "op.");
58 
59  m.def("isa_contraction_op", &mlirLinalgIsAContractionOp,
60  "Checks if the given operation is a Linalg contraction operation.",
61  nb::arg("op"));
62 
63  nb::class_<MlirLinalgContractionDimensions>(m, "ContractionDimensions")
64  .def_prop_ro("batch",
65  [](const MlirLinalgContractionDimensions &self) {
66  return self.batch;
67  })
68  .def_prop_ro(
69  "m",
70  [](const MlirLinalgContractionDimensions &self) { return self.m; })
71  .def_prop_ro(
72  "n",
73  [](const MlirLinalgContractionDimensions &self) { return self.n; })
74  .def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) {
75  return self.k;
76  });
77 
78  m.def("infer_contraction_dimensions", &InferContractionDimensions,
79  "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
80  "op.",
81  nb::arg("op"));
82 
83  m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
84  "Checks if the given operation is a Linalg convolution operation.",
85  nb::arg("op"));
86 
87  nb::class_<MlirLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
88  .def_prop_ro("batch",
89  [](const MlirLinalgConvolutionDimensions &self) {
90  return self.batch;
91  })
92  .def_prop_ro("output_image",
93  [](const MlirLinalgConvolutionDimensions &self) {
94  return self.outputImage;
95  })
96  .def_prop_ro("output_channel",
97  [](const MlirLinalgConvolutionDimensions &self) {
98  return self.outputChannel;
99  })
100  .def_prop_ro("filter_loop",
101  [](const MlirLinalgConvolutionDimensions &self) {
102  return self.filterLoop;
103  })
104  .def_prop_ro("input_channel",
105  [](const MlirLinalgConvolutionDimensions &self) {
106  return self.inputChannel;
107  })
108  .def_prop_ro("depth",
109  [](const MlirLinalgConvolutionDimensions &self) {
110  return self.depth;
111  })
112  .def_prop_ro("strides",
113  [](const MlirLinalgConvolutionDimensions &self) {
114  return self.strides;
115  })
116  .def_prop_ro("dilations",
117  [](const MlirLinalgConvolutionDimensions &self) {
118  return self.dilations;
119  });
120 
121  m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
122  "Infers convolution dimensions", nb::arg("op"));
123 
124  m.def(
125  "get_indexing_maps",
126  [](MlirOperation op) -> std::optional<MlirAttribute> {
127  MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op);
128  if (mlirAttributeIsNull(attr))
129  return std::nullopt;
130  return attr;
131  },
132  "Returns the indexing_maps attribute for a linalg op.");
133 }
134 
135 NB_MODULE(_mlirDialectsLinalg, m) {
136  m.doc() = "MLIR Linalg dialect.";
137 
139 }
static std::optional< MlirLinalgContractionDimensions > InferContractionDimensions(MlirOperation op)
static void populateDialectLinalgSubmodule(nb::module_ m)
NB_MODULE(_mlirDialectsLinalg, m)
static std::optional< MlirLinalgConvolutionDimensions > InferConvolutionDimensions(MlirOperation op)
MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions mlirLinalgInferConvolutionDimensions(MlirOperation op)
Definition: Linalg.cpp:87
MLIR_CAPI_EXPORTED void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp)
Apply the special region builder for the builtin named Linalg op.
Definition: Linalg.cpp:18
MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op)
Definition: Linalg.cpp:44
MLIR_CAPI_EXPORTED MlirAttribute mlirLinalgGetIndexingMapsAttribute(MlirOperation op)
Definition: Linalg.cpp:124
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op)
Definition: Linalg.cpp:78
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions mlirLinalgInferContractionDimensions(MlirOperation op)
Definition: Linalg.cpp:51
static bool mlirAttributeIsNull(MlirAttribute attr)
Checks whether an attribute is null.
Definition: IR.h:1145
MlirAttribute outputChannel
Definition: Linalg.h:42
MlirAttribute filterLoop
Definition: Linalg.h:43
MlirAttribute outputImage
Definition: Linalg.h:41
MlirAttribute inputChannel
Definition: Linalg.h:44