MLIR  18.0.0git
StructuredOpsUtils.h
Go to the documentation of this file.
1 //===- StructuredOpsUtils.h - Utilities used by structured ops --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities that operate on builtin types and are
10 // useful across multiple dialects that use structured ops abstractions. These
11 // abstractions consist of defining custom operations that encode and transport
12 // information about their semantics (e.g. type of iterators like parallel,
13 // reduction, etc.) as attributes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
18 #define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
19 
20 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/Location.h"
23 #include "mlir/IR/TypeRange.h"
24 #include "mlir/Support/LLVM.h"
25 
26 // Pull in all enum type definitions and utility function declarations.
27 #include "mlir/Dialect/Utils/DialectUtilsEnums.h.inc"
28 
29 namespace mlir {
30 
31 class OpBuilder;
32 class RewriterBase;
33 
34 /// Tests whether the given maps describe a row-major matmul. The test is
35 /// permutation-invariant. Note that this only inspects the affine maps of an
36 /// operation; it does not perform any checks on the math being performed within
37 /// the reduction.
38 bool isRowMajorMatmul(ArrayAttr indexingMaps);
39 
40 /// Tests whether the given maps describe a column-major matmul. The test is
41 /// permutation-invariant. Note that this only inspects the affine maps of an
42 /// operation; it does not perform any checks on the math being performed within
43 /// the reduction.
44 bool isColumnMajorMatmul(ArrayAttr indexingMaps);
45 
46 /// Tests whether the given maps describe a row-major batch matmul. The test is
47 /// permutation-invariant. Note that this only inspects the affine maps of an
48 /// operation; it does not perform any checks on the math being performed within
49 /// the reduction.
50 bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
51 
52 /// Return positions in `iteratorTypes` that match `iteratorTypeName`.
54  utils::IteratorType iteratorTypeName,
56  for (const auto &en : llvm::enumerate(iteratorTypes)) {
57  if (en.value() == iteratorTypeName)
58  res.push_back(en.index());
59  }
60 }
61 
62 /// Helper StructuredGenerator class to manipulate and rewrite ops with
63 /// `StructuredOpInterface`. This is templated for now because VectorOps do not
64 /// yet implement the StructuredOpInterface itself.
65 template <typename StructuredOpInterface, typename IteratorTypeT>
67 public:
69 
70  struct IteratorType {
71  IteratorType(IteratorTypeT iter) : iter(iter) {}
72  bool isOfType(IteratorTypeT expectedIter) const {
73  return expectedIter == iter;
74  }
75  IteratorTypeT iter;
76  };
77  struct Par : public IteratorType {
78  Par() : IteratorType(IteratorTypeT::parallel) {}
79  };
80  struct Red : public IteratorType {
81  Red() : IteratorType(IteratorTypeT::reduction) {}
82  };
83 
84  StructuredGenerator(RewriterBase &rewriter, StructuredOpInterface op)
85  : rewriter(rewriter), ctx(op.getContext()), loc(op.getLoc()),
86  iterators(op.getIteratorTypesArray()), maps(op.getIndexingMapsArray()),
87  op(op) {}
88 
90  if (its.size() != iterators.size())
91  return false;
92  for (int i = 0, e = its.size(); i != e; ++i) {
93  if (!its[i].isOfType(iterators[i]))
94  return false;
95  }
96  return true;
97  }
98 
99  bool layout(MapList l) {
100  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
101  return maps == infer(l);
102  }
103 
104 protected:
111 };
112 
113 // Clone the operation `op`, using `newOperands` as the operands of the clone.
114 // This is used to abstract away the optional underlying region creation.
115 // Note: this is a true builder that notifies the OpBuilder listener.
116 Operation *clone(OpBuilder &b, Operation *op, TypeRange newResultTypes,
117  ValueRange newOperands);
118 template <typename OpT>
119 OpT clone(OpBuilder &b, OpT op, TypeRange newResultTypes,
120  ValueRange newOperands) {
121  return cast<OpT>(clone(b, op.getOperation(), newResultTypes, newOperands));
122 }
123 
124 // Clone the operation `op` with the operands `newOperands`, but leave the regions of the clone empty.
125 // Note: this is a true builder that notifies the OpBuilder listener.
126 Operation *cloneWithoutRegions(OpBuilder &b, Operation *op,
127  TypeRange newResultTypes,
128  ValueRange newOperands);
129 
130 // Get the list of attributes associated with `op`, ignoring
131 // those whose names are listed in `elidedAttrs`.
132 SmallVector<NamedAttribute>
133 getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs);
134 
135 } // namespace mlir
136 
137 #endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
static MLIRContext * getContext(OpFoldResult val)
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:255
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:206
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
SmallVector< AffineMap, 4 > maps
bool iters(ArrayRef< IteratorType > its)
SmallVector< IteratorTypeT > iterators
StructuredGenerator(RewriterBase &rewriter, StructuredOpInterface op)
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
This header declares functions that assist transformations in the MemRef dialect.
bool isColumnMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a column major matmul.
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
bool isRowMajorBatchMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major batch matmul.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
void findPositionsOfType(ArrayRef< utils::IteratorType > iteratorTypes, utils::IteratorType iteratorTypeName, SmallVectorImpl< unsigned > &res)
Return positions in iteratorTypes that match iteratorTypeName.
bool isRowMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major matmul.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
bool isOfType(IteratorTypeT expectedIter) const