MLIR  17.0.0git
StructuredOpsUtils.cpp
Go to the documentation of this file.
1 //===- StructuredOpsUtils.cpp - Utilities used by structured ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/AffineMap.h"
11 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/IRMapping.h"
14 #include "llvm/ADT/StringSet.h"
15 
16 #include "mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc"
17 
18 using namespace mlir;
19 
20 bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) {
21  if (indexingMaps.size() != 3)
22  return false;
23 
24  auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue();
25  auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue();
26  auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue();
27 
28  if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
29  map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
30  map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
31  return false;
32  }
33 
34  // Extract dimensions for MxK * KxN -> MxN
35  AffineExpr m = map2.getResult(0);
36  AffineExpr n = map2.getResult(1);
37  AffineExpr k = map0.getResult(1);
38  auto *context = indexingMaps.getContext();
39  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
40  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
41  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
42  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
43  return indexingMaps == maps;
44 }
45 
46 bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) {
47  if (indexingMaps.size() != 3)
48  return false;
49 
50  auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue();
51  auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue();
52  auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue();
53 
54  if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
55  map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
56  map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
57  return false;
58  }
59 
60  // Extract dimensions for KxM * NxK -> NxM
61  AffineExpr n = map2.getResult(0);
62  AffineExpr m = map2.getResult(1);
63  AffineExpr k = map0.getResult(0);
64  auto *context = indexingMaps.getContext();
65  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, context));
66  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, context));
67  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
68  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
69  return indexingMaps == maps;
70 }
71 
72 bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
73  if (indexingMaps.size() != 3)
74  return false;
75 
76  auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue();
77  auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue();
78  auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue();
79 
80  if (map0.getNumResults() != 3 || map1.getNumResults() != 3 ||
81  map2.getNumResults() != 3 || map0.getNumInputs() != 4 ||
82  map1.getNumInputs() != 4 || map2.getNumInputs() != 4) {
83  return false;
84  }
85 
86  // Extract dimensions for BxMxK * BxKxN -> BxMxN
87  AffineExpr b = map2.getResult(0);
88  AffineExpr m = map2.getResult(1);
89  AffineExpr n = map2.getResult(2);
90  AffineExpr k = map0.getResult(2);
91  auto *context = indexingMaps.getContext();
92  auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, k}, context));
93  auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {b, k, n}, context));
94  auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, n}, context));
95  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
96  return indexingMaps == maps;
97 }
98 
100  ValueRange newOperands) {
101  IRMapping bvm;
102  OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
103  op->getAttrs());
104  for (Region &r : op->getRegions())
105  r.cloneInto(state.addRegion(), bvm);
106  return b.create(state);
107 }
108 
110  TypeRange newResultTypes,
111  ValueRange newOperands) {
112  OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
113  op->getAttrs());
114  for (size_t cnt = 0, e = op->getNumRegions(); cnt < e; ++cnt)
115  state.addRegion();
116  return b.create(state);
117 }
118 
121  llvm::StringSet<> elidedAttrsSet;
122  elidedAttrsSet.insert(elidedAttrs.begin(), elidedAttrs.end());
124  for (auto attr : op->getAttrs()) {
125  if (elidedAttrsSet.count(attr.getName()))
126  continue;
127  attrs.push_back(attr);
128  }
129  return attrs;
130 }
Base type for affine expression.
Definition: AffineExpr.h:68
MLIRContext * getContext() const
Definition: AffineExpr.cpp:25
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class helps build Operations.
Definition: Builders.h:202
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:432
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:537
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:207
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:418
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:540
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:103
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
Include the generated interface declarations.
bool isColumnMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a column major matmul.
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
bool isRowMajorBatchMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major batch matmul.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
bool isRowMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major matmul.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.