MLIR  16.0.0git
StructuredOpsUtils.cpp
Go to the documentation of this file.
1 //===- StructuredOpsUtils.cpp - Utilities used by structured ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Builders.h"
14 
15 #include "mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc"
16 
17 using namespace mlir;
18 
19 bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) {
20  if (indexingMaps.size() != 3)
21  return false;
22 
23  auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue();
24  auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue();
25  auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue();
26 
27  if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
28  map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
29  map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
30  return false;
31  }
32 
33  // Extract dimensions for MxK * KxN -> MxN
34  AffineExpr m = map2.getResult(0);
35  AffineExpr n = map2.getResult(1);
36  AffineExpr k = map0.getResult(1);
37  auto *context = indexingMaps.getContext();
38  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
39  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
40  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
41  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
42  return indexingMaps == maps;
43 }
44 
45 bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) {
46  if (indexingMaps.size() != 3)
47  return false;
48 
49  auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue();
50  auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue();
51  auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue();
52 
53  if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
54  map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
55  map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
56  return false;
57  }
58 
59  // Extract dimensions for KxM * NxK -> NxM
60  AffineExpr n = map2.getResult(0);
61  AffineExpr m = map2.getResult(1);
62  AffineExpr k = map0.getResult(0);
63  auto *context = indexingMaps.getContext();
64  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, context));
65  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, context));
66  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
67  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
68  return indexingMaps == maps;
69 }
70 
71 bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
72  if (indexingMaps.size() != 3)
73  return false;
74 
75  auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue();
76  auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue();
77  auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue();
78 
79  if (map0.getNumResults() != 3 || map1.getNumResults() != 3 ||
80  map2.getNumResults() != 3 || map0.getNumInputs() != 4 ||
81  map1.getNumInputs() != 4 || map2.getNumInputs() != 4) {
82  return false;
83  }
84 
85  // Extract dimensions for BxMxK * BxKxN -> BxMxN
86  AffineExpr b = map2.getResult(0);
87  AffineExpr m = map2.getResult(1);
88  AffineExpr n = map2.getResult(2);
89  AffineExpr k = map0.getResult(2);
90  auto *context = indexingMaps.getContext();
91  auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, k}, context));
92  auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {b, k, n}, context));
93  auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, n}, context));
94  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
95  return indexingMaps == maps;
96 }
97 
99  ValueRange newOperands) {
101  OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
102  op->getAttrs());
103  for (Region &r : op->getRegions())
104  r.cloneInto(state.addRegion(), bvm);
105  return b.create(state);
106 }
107 
109  TypeRange newResultTypes,
110  ValueRange newOperands) {
111  OperationState state(op->getLoc(), op->getName(), newOperands, newResultTypes,
112  op->getAttrs());
113  for (size_t cnt = 0, e = op->getNumRegions(); cnt < e; ++cnt)
114  state.addRegion();
115  return b.create(state);
116 }
Base type for affine expression.
Definition: AffineExpr.h:68
MLIRContext * getContext() const
Definition: AffineExpr.cpp:24
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class helps build Operations.
Definition: Builders.h:198
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:477
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:356
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class provides an abstraction over the various kinds of ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
Include the generated interface declarations.
bool isColumnMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a column major matmul.
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
bool isRowMajorBatchMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major batch matmul.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
bool isRowMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major matmul.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.