MLIR 22.0.0git
DialectLinalg.cpp
Go to the documentation of this file.
1//===- DialectLinalg.cpp - Pybind module for Linalg dialect API support --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir-c/IR.h"
13
14namespace nb = nanobind;
16
17static std::optional<MlirLinalgContractionDimensions>
18InferContractionDimensions(MlirOperation op) {
21
22 // Detect "empty" result. This occurs when `op` is not a contraction op,
23 // or when `linalg::inferContractionDims` fails.
24 if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
25 mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
26 return std::nullopt;
27 }
28 return dims;
29}
30
31static std::optional<MlirLinalgConvolutionDimensions>
32InferConvolutionDimensions(MlirOperation op) {
35
36 // Detect "empty" result. This occurs when `op` is not a convolution op,
37 // or when `linalg::inferConvolutionDims` fails.
38 if (mlirAttributeIsNull(dims.batch) &&
39 mlirAttributeIsNull(dims.outputImage) &&
40 mlirAttributeIsNull(dims.outputChannel) &&
41 mlirAttributeIsNull(dims.filterLoop) &&
42 mlirAttributeIsNull(dims.inputChannel) &&
43 mlirAttributeIsNull(dims.depth) && mlirAttributeIsNull(dims.strides) &&
44 mlirAttributeIsNull(dims.dilations)) {
45 return std::nullopt;
46 }
47
48 return dims;
49}
50
51static void populateDialectLinalgSubmodule(nb::module_ m) {
52 m.def(
53 "fill_builtin_region",
54 [](MlirOperation op) { mlirLinalgFillBuiltinNamedOpRegion(op); },
55 nb::arg("op"),
56 "Fill the region for `op`, which is assumed to be a builtin named Linalg "
57 "op.");
58
59 m.def("isa_contraction_op", &mlirLinalgIsAContractionOp,
60 "Checks if the given operation is a Linalg contraction operation.",
61 nb::arg("op"));
62
63 nb::class_<MlirLinalgContractionDimensions>(m, "ContractionDimensions")
64 .def_prop_ro("batch",
65 [](const MlirLinalgContractionDimensions &self) {
66 return self.batch;
67 })
68 .def_prop_ro(
69 "m",
70 [](const MlirLinalgContractionDimensions &self) { return self.m; })
71 .def_prop_ro(
72 "n",
73 [](const MlirLinalgContractionDimensions &self) { return self.n; })
74 .def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) {
75 return self.k;
76 });
77
78 m.def("infer_contraction_dimensions", &InferContractionDimensions,
79 "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
80 "op.",
81 nb::arg("op"));
82
83 m.def(
84 "infer_contraction_dimensions_from_maps",
85 [](std::vector<MlirAffineMap> indexingMaps)
86 -> std::optional<MlirLinalgContractionDimensions> {
87 if (indexingMaps.empty())
88 return std::nullopt;
89
92 indexingMaps.size());
93
94 // Detect "empty" result from invalid input or failed inference.
95 if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
96 mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
97 return std::nullopt;
98 }
99 return dims;
100 },
101 "Infers contraction dimensions (batch/m/n/k) from a list of affine "
102 "maps.",
103 nb::arg("indexing_maps"));
104
105 m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
106 "Checks if the given operation is a Linalg convolution operation.",
107 nb::arg("op"));
108
109 nb::class_<MlirLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
110 .def_prop_ro("batch",
111 [](const MlirLinalgConvolutionDimensions &self) {
112 return self.batch;
113 })
114 .def_prop_ro("output_image",
115 [](const MlirLinalgConvolutionDimensions &self) {
116 return self.outputImage;
117 })
118 .def_prop_ro("output_channel",
119 [](const MlirLinalgConvolutionDimensions &self) {
120 return self.outputChannel;
121 })
122 .def_prop_ro("filter_loop",
123 [](const MlirLinalgConvolutionDimensions &self) {
124 return self.filterLoop;
125 })
126 .def_prop_ro("input_channel",
127 [](const MlirLinalgConvolutionDimensions &self) {
128 return self.inputChannel;
129 })
130 .def_prop_ro("depth",
131 [](const MlirLinalgConvolutionDimensions &self) {
132 return self.depth;
133 })
134 .def_prop_ro("strides",
135 [](const MlirLinalgConvolutionDimensions &self) {
136 return self.strides;
137 })
138 .def_prop_ro("dilations",
139 [](const MlirLinalgConvolutionDimensions &self) {
140 return self.dilations;
141 });
142
143 m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
144 "Infers convolution dimensions", nb::arg("op"));
145
146 m.def(
147 "get_indexing_maps",
148 [](MlirOperation op) -> std::optional<MlirAttribute> {
149 MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op);
150 if (mlirAttributeIsNull(attr))
151 return std::nullopt;
152 return attr;
153 },
154 "Returns the indexing_maps attribute for a linalg op.");
155}
156
157NB_MODULE(_mlirDialectsLinalg, m) {
158 m.doc() = "MLIR Linalg dialect.";
159
161}
static std::optional< MlirLinalgContractionDimensions > InferContractionDimensions(MlirOperation op)
static void populateDialectLinalgSubmodule(nb::module_ m)
NB_MODULE(_mlirDialectsLinalg, m)
static std::optional< MlirLinalgConvolutionDimensions > InferConvolutionDimensions(MlirOperation op)
MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions mlirLinalgInferConvolutionDimensions(MlirOperation op)
Definition Linalg.cpp:119
MLIR_CAPI_EXPORTED void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp)
Apply the special region builder for the builtin named Linalg op.
Definition Linalg.cpp:19
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps, size_t numMaps)
Definition Linalg.cpp:79
MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op)
Definition Linalg.cpp:45
MLIR_CAPI_EXPORTED MlirAttribute mlirLinalgGetIndexingMapsAttribute(MlirOperation op)
Definition Linalg.cpp:156
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op)
Definition Linalg.cpp:110
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions mlirLinalgInferContractionDimensions(MlirOperation op)
Definition Linalg.cpp:52
MlirAttribute outputChannel
Definition Linalg.h:47