MLIR  18.0.0git
StructuredOpsUtils.cpp
Go to the documentation of this file.
1 //===- StructuredOpsUtils.cpp - Utilities used by structured ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/AffineMap.h"
11 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/IRMapping.h"
14 #include "llvm/ADT/StringSet.h"
15 
16 #include "mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc"
17 
18 using namespace mlir;
19 
20 bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) {
21  if (indexingMaps.size() != 3)
22  return false;
23 
24  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
25  AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
26  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
27 
28  if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
29  map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
30  map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
31  return false;
32  }
33 
34  // Extract dimensions for MxK * KxN -> MxN
35  AffineExpr m = map2.getResult(0);
36  AffineExpr n = map2.getResult(1);
37  AffineExpr k = map0.getResult(1);
38  auto *context = indexingMaps.getContext();
39  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
40  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
41  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
42  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
43  return indexingMaps == maps;
44 }
45 
46 bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) {
47  if (indexingMaps.size() != 3)
48  return false;
49 
50  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
51  AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
52  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
53 
54  if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
55  map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
56  map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
57  return false;
58  }
59 
60  // Extract dimensions for KxM * NxK -> NxM
61  AffineExpr n = map2.getResult(0);
62  AffineExpr m = map2.getResult(1);
63  AffineExpr k = map0.getResult(0);
64  auto *context = indexingMaps.getContext();
65  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, context));
66  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, context));
67  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
68  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
69  return indexingMaps == maps;
70 }
71 
72 bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
73  if (indexingMaps.size() != 3)
74  return false;
75 
76  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
77  AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
78  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
79 
80  if (map0.getNumResults() != 3 || map1.getNumResults() != 3 ||
81  map2.getNumResults() != 3 || map0.getNumInputs() != 4 ||
82  map1.getNumInputs() != 4 || map2.getNumInputs() != 4) {
83  return false;
84  }
85 
86  // Extract dimensions for BxMxK * BxKxN -> BxMxN
87  AffineExpr b = map2.getResult(0);
88  AffineExpr m = map2.getResult(1);
89  AffineExpr n = map2.getResult(2);
90  AffineExpr k = map0.getResult(2);
91  auto *context = indexingMaps.getContext();
92  auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, k}, context));
93  auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {b, k, n}, context));
94  auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, n}, context));
95  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
96  return indexingMaps == maps;
97 }
98 
99 bool mlir::isVecmat(ArrayAttr indexingMaps) {
100  if (indexingMaps.size() != 3)
101  return false;
102  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
103  AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
104  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
105 
106  if (map0.getNumResults() != 1 || map1.getNumResults() != 2 ||
107  map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
108  map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
109  return false;
110  }
111 
112  // Extract dimensions for K * KxN -> N
113  AffineExpr k = map0.getResult(0);
114  AffineExpr n = map2.getResult(0);
115  auto *context = indexingMaps.getContext();
116  auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
117  auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, context));
118  auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
119  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
120  return indexingMaps == maps;
121 }
122 
123 bool mlir::isBatchVecmat(ArrayAttr indexingMaps) {
124  if (indexingMaps.size() != 3)
125  return false;
126  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
127  AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
128  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
129 
130  if (map0.getNumResults() != 2 || map1.getNumResults() != 3 ||
131  map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
132  map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
133  return false;
134  }
135 
136  // Extract dimensions for B*K * B*K*N -> B*N
137  AffineExpr b = map0.getResult(0);
138  AffineExpr k = map0.getResult(1);
139  AffineExpr n = map2.getResult(1);
140  auto *context = indexingMaps.getContext();
141  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
142  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k, n}, context));
143  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
144  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
145  return indexingMaps == maps;
146 }
147 
148 bool mlir::isMatvec(ArrayAttr indexingMaps) {
149  if (indexingMaps.size() != 3)
150  return false;
151  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
152  AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
153  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
154 
155  if (map0.getNumResults() != 2 || map1.getNumResults() != 1 ||
156  map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
157  map1.getNumInputs() != 2 || map2.getNumInputs() != 2) {
158  return false;
159  }
160 
161  // Extract dimensions for N*K * K -> N
162  AffineExpr k = map1.getResult(0);
163  AffineExpr n = map2.getResult(0);
164  auto *context = indexingMaps.getContext();
165  auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, context));
166  auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context));
167  auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context));
168  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
169  return indexingMaps == maps;
170 }
171 
172 bool mlir::isBatchMatvec(ArrayAttr indexingMaps) {
173  if (indexingMaps.size() != 3)
174  return false;
175  AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
176  AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
177  AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
178 
179  if (map0.getNumResults() != 3 || map1.getNumResults() != 2 ||
180  map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
181  map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
182  return false;
183  }
184 
185  // Extract dimensions for B*N*K * B*K -> B*N
186  AffineExpr b = map0.getResult(0);
187  AffineExpr k = map1.getResult(1);
188  AffineExpr n = map2.getResult(1);
189  auto *context = indexingMaps.getContext();
190  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, n, k}, context));
191  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context));
192  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context));
193  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
194  return indexingMaps == maps;
195 }
196 
198  ValueRange newOperands) {
199  IRMapping bvm;
200  OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
201  op->getAttrs());
202  for (Region &r : op->getRegions())
203  r.cloneInto(state.addRegion(), bvm);
204  return b.create(state);
205 }
206 
208  TypeRange newResultTypes,
209  ValueRange newOperands) {
210  OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
211  op->getAttrs());
212  for (size_t cnt = 0, e = op->getNumRegions(); cnt < e; ++cnt)
213  state.addRegion();
214  return b.create(state);
215 }
216 
219  llvm::StringSet<> elidedAttrsSet;
220  elidedAttrsSet.insert(elidedAttrs.begin(), elidedAttrs.end());
222  for (auto attr : op->getAttrs()) {
223  if (elidedAttrsSet.count(attr.getName()))
224  continue;
225  attrs.push_back(attr);
226  }
227  return attrs;
228 }
Base type for affine expression.
Definition: AffineExpr.h:68
MLIRContext * getContext() const
Definition: AffineExpr.cpp:25
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumResults() const
Definition: AffineMap.cpp:382
unsigned getNumInputs() const
Definition: AffineMap.cpp:383
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:391
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:652
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:486
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:655
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
Include the generated interface declarations.
bool isColumnMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a column major matmul.
bool isBatchMatvec(ArrayAttr indexingMaps)
Tests whether the given maps describe a batch matrix vector multiplication.
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
bool isMatvec(ArrayAttr indexingMaps)
Tests whether the given maps describe a matrix vector multiplication.
bool isBatchVecmat(ArrayAttr indexingMaps)
Tests whether the given maps describe a batch vector matrix multiplication.
bool isRowMajorBatchMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major batch matmul.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
bool isVecmat(ArrayAttr indexingMaps)
Tests whether the given maps describe a vector matrix multiplication.
bool isRowMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major matmul.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
This represents an operation in an abstracted form, suitable for use with the builder APIs.