MLIR 22.0.0git
StructuredOpsUtils.cpp
Go to the documentation of this file.
1//===- StructuredOpsUtils.cpp - Utilities used by structured ops ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/AffineMap.h"
11#include "mlir/IR/Builders.h"
13#include "mlir/IR/IRMapping.h"
14#include "llvm/ADT/StringSet.h"
15
16#include "mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc"
17
18using namespace mlir;
19
21 if (indexingMaps.size() != 3)
22 return false;
23
24 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
25 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
26 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
27
28 if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
29 map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
30 map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
31 return false;
32 }
33
34 // Extract dimensions for MxK * KxN -> MxN
35 AffineExpr m = map2.getResult(0);
36 AffineExpr n = map2.getResult(1);
37 AffineExpr k = map0.getResult(1);
38 auto *context = indexingMaps.getContext();
39 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
40 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
41 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
42 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
43 return indexingMaps == maps;
44}
45
47 if (indexingMaps.size() != 3)
48 return false;
49
50 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
51 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
52 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
53
54 if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
55 map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
56 map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
57 return false;
58 }
59
60 // Extract dimensions for KxM * NxK -> NxM
61 AffineExpr n = map2.getResult(0);
62 AffineExpr m = map2.getResult(1);
63 AffineExpr k = map0.getResult(0);
64 auto *context = indexingMaps.getContext();
65 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, context));
66 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, context));
67 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
68 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
69 return indexingMaps == maps;
70}
71
73 if (indexingMaps.size() != 3)
74 return false;
75
76 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
77 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
78 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
79
80 if (map0.getNumResults() != 3 || map1.getNumResults() != 3 ||
81 map2.getNumResults() != 3 || map0.getNumInputs() != 4 ||
82 map1.getNumInputs() != 4 || map2.getNumInputs() != 4) {
83 return false;
84 }
85
86 // Extract dimensions for BxMxK * BxKxN -> BxMxN
87 AffineExpr b = map2.getResult(0);
88 AffineExpr m = map2.getResult(1);
89 AffineExpr n = map2.getResult(2);
90 AffineExpr k = map0.getResult(2);
91 auto *context = indexingMaps.getContext();
92 auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, k}, context));
93 auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {b, k, n}, context));
94 auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, n}, context));
95 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
96 return indexingMaps == maps;
97}
98
99bool mlir::isVecmat(ArrayAttr indexingMaps) {
100 if (indexingMaps.size() != 3)
101 return false;
102 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
103 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
104 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
105
106 if (map0.getNumResults() != 1 || map1.getNumResults() != 2 ||
107 map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
108 map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
109 return false;
110 }
111
112 // Extract dimensions for K * KxN -> N
113 AffineExpr k = map0.getResult(0);
114 AffineExpr n = map2.getResult(0);
115 auto *context = indexingMaps.getContext();
116 auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
117 auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, context));
118 auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
119 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
120 return indexingMaps == maps;
121}
122
123bool mlir::isBatchVecmat(ArrayAttr indexingMaps) {
124 if (indexingMaps.size() != 3)
125 return false;
126 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
127 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
128 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
129
130 if (map0.getNumResults() != 2 || map1.getNumResults() != 3 ||
131 map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
132 map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
133 return false;
134 }
135
136 // Extract dimensions for B*K * B*K*N -> B*N
137 AffineExpr b = map0.getResult(0);
138 AffineExpr k = map0.getResult(1);
139 AffineExpr n = map2.getResult(1);
140 auto *context = indexingMaps.getContext();
141 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
142 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k, n}, context));
143 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
144 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
145 return indexingMaps == maps;
146}
147
148bool mlir::isMatvec(ArrayAttr indexingMaps) {
149 if (indexingMaps.size() != 3)
150 return false;
151 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
152 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
153 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
154
155 if (map0.getNumResults() != 2 || map1.getNumResults() != 1 ||
156 map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
157 map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
158 return false;
159 }
160
161 // Extract dimensions for N*K * K -> N
162 AffineExpr k = map1.getResult(0);
163 AffineExpr n = map2.getResult(0);
164 auto *context = indexingMaps.getContext();
165 auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, context));
166 auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
167 auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
168 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
169 return indexingMaps == maps;
170}
171
172bool mlir::isBatchMatvec(ArrayAttr indexingMaps) {
173 if (indexingMaps.size() != 3)
174 return false;
175 AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
176 AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
177 AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
178
179 if (map0.getNumResults() != 3 || map1.getNumResults() != 2 ||
180 map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
181 map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
182 return false;
183 }
184
185 // Extract dimensions for B*N*K * B*K -> B*N
186 AffineExpr b = map0.getResult(0);
187 AffineExpr k = map1.getResult(1);
188 AffineExpr n = map2.getResult(1);
189 auto *context = indexingMaps.getContext();
190 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, n, k}, context));
191 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
192 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
193 auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
194 return indexingMaps == maps;
195}
196
198 ValueRange newOperands) {
199 IRMapping bvm;
200 OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
201 op->getAttrs());
202 for (Region &r : op->getRegions()) {
203 Region *newRegion = state.addRegion();
204 b.cloneRegionBefore(r, *newRegion, newRegion->begin(), bvm);
205 }
206 return b.create(state);
207}
208
210 TypeRange newResultTypes,
211 ValueRange newOperands) {
212 OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
213 op->getAttrs());
214 for (size_t cnt = 0, e = op->getNumRegions(); cnt < e; ++cnt)
215 state.addRegion();
216 return b.create(state);
217}
218
221 llvm::StringSet<> elidedAttrsSet;
222 elidedAttrsSet.insert_range(elidedAttrs);
224 for (auto attr : op->getAttrs()) {
225 if (elidedAttrsSet.count(attr.getName()))
226 continue;
227 attrs.push_back(attr);
228 }
229 return attrs;
230}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
Base type for affine expression.
Definition AffineExpr.h:68
MLIRContext * getContext() const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
iterator begin()
Definition Region.h:55
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
Include the generated interface declarations.
bool isColumnMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a column major matmul.
bool isBatchMatvec(ArrayAttr indexingMaps)
Tests whether the given maps describe a batch matrix vector multiplication.
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
bool isMatvec(ArrayAttr indexingMaps)
Tests whether the given maps describe a matrix vector multiplication.
bool isBatchVecmat(ArrayAttr indexingMaps)
Tests whether the given maps describe a batch vector matrix multiplication.
bool isRowMajorBatchMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major batch matmul.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
bool isVecmat(ArrayAttr indexingMaps)
Tests whether the given maps describe a vector matrix multiplication.
bool isRowMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major matmul.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.