MLIR 22.0.0git
StructuredOpsUtils.h
Go to the documentation of this file.
1//===- StructuredOpsUtils.h - Utilities used by structured ops --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This header file defines utilities that operate on builtin types and are
// useful across multiple dialects that use structured ops abstractions. These
// abstractions consist of custom operations that encode and transport
// information about their semantics (e.g. type of iterators like parallel,
// reduction, etc.) as attributes.
14//
15//===----------------------------------------------------------------------===//
16
17#ifndef MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
18#define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
19
20#include "mlir/IR/AffineMap.h"
22#include "mlir/IR/Location.h"
23#include "mlir/IR/TypeRange.h"
24#include "mlir/Support/LLVM.h"
25
26// Pull in all enum type definitions and utility function declarations.
27#include "mlir/Dialect/Utils/DialectUtilsEnums.h.inc"
28
29namespace mlir {
30
31class OpBuilder;
32class RewriterBase;
33
34/// Tests whether the given maps describe a row major matmul. The test is
35/// permutation-invariant. Note that this only checks the affine maps from an
36/// operation, so does not perform any checks on the math being performed within
37/// the reduction.
38bool isRowMajorMatmul(ArrayAttr indexingMaps);
39
40/// Tests whether the given maps describe a column major matmul. The test is
41/// permutation-invariant. Note that this only checks the affine maps from an
42/// operation, so does not perform any checks on the math being performed within
43/// the reduction.
44bool isColumnMajorMatmul(ArrayAttr indexingMaps);
45
46/// Tests whether the given maps describe a row major batch matmul. The test is
47/// permutation-invariant. Note that this only checks the affine maps from an
48/// operation, so does not perform any checks on the math being performed within
49/// the reduction.
50bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
51
52/// Tests whether the given maps describe a vector matrix multiplication. The
53/// test is permutation-invariant. Note that this only checks the affine maps
54/// from an operation, so does not perform any checks on the math being
55/// performed within the reduction.
56bool isVecmat(ArrayAttr indexingMaps);
57
58/// Tests whether the given maps describe a batch vector matrix multiplication.
59/// The test is permutation-invariant. Note that this only checks the affine
60/// maps from an operation, so does not perform any checks on the math being
61/// performed within the reduction.
62bool isBatchVecmat(ArrayAttr indexingMaps);
63
64/// Tests whether the given maps describe a matrix vector multiplication. The
65/// test is permutation-invariant. Note that this only checks the affine maps
66/// from an operation, so does not perform any checks on the math being
67/// performed within the reduction.
68bool isMatvec(ArrayAttr indexingMaps);
69
70/// Tests whether the given maps describe a batch matrix vector multiplication.
71/// The test is permutation-invariant. Note that this only checks the affine
72/// maps from an operation, so does not perform any checks on the math being
73/// performed within the reduction.
74bool isBatchMatvec(ArrayAttr indexingMaps);
75
76/// Return positions in `iteratorTypes` that match `iteratorTypeName`.
78 utils::IteratorType iteratorTypeName,
80 for (const auto &en : llvm::enumerate(iteratorTypes)) {
81 if (en.value() == iteratorTypeName)
82 res.push_back(en.index());
83 }
84}
85
86/// Helper StructuredGenerator class to manipulate and rewrite ops with
87/// `StructuredOpInterface`. This is templated for now because VectorOps do not
88/// yet implement the StructuredOpInterface itself.
89template <typename StructuredOpInterface, typename IteratorTypeT>
91public:
93
94 struct IteratorType {
95 IteratorType(IteratorTypeT iter) : iter(iter) {}
96 bool isOfType(IteratorTypeT expectedIter) const {
97 return expectedIter == iter;
98 }
99 IteratorTypeT iter;
100 };
101 struct Par : public IteratorType {
102 Par() : IteratorType(IteratorTypeT::parallel) {}
103 };
104 struct Red : public IteratorType {
105 Red() : IteratorType(IteratorTypeT::reduction) {}
106 };
107
109 : rewriter(rewriter), ctx(op.getContext()), loc(op.getLoc()),
110 iterators(op.getIteratorTypesArray()), maps(op.getIndexingMapsArray()),
111 op(op) {}
112
114 if (its.size() != iterators.size())
115 return false;
116 for (int i = 0, e = its.size(); i != e; ++i) {
117 if (!its[i].isOfType(iterators[i]))
118 return false;
119 }
120 return true;
121 }
122
123 bool layout(MapList l) {
124 auto infer = [&](MapList m) {
126 };
127 return maps == infer(l);
128 }
129
130protected:
137};
138
139// Clone the current operation with the operands. This is used to abstract away
140// the optional underlying region creation.
141// Note: this is a true builder that notifies the OpBuilder listener.
142Operation *clone(OpBuilder &b, Operation *op, TypeRange newResultTypes,
143 ValueRange newOperands);
144template <typename OpT>
145OpT clone(OpBuilder &b, OpT op, TypeRange newResultTypes,
146 ValueRange newOperands) {
147 return cast<OpT>(clone(b, op.getOperation(), newResultTypes, newOperands));
148}
149
150// Clone the current operation with the operands but leave the regions empty.
151// Note: this is a true builder that notifies the OpBuilder listener.
152Operation *cloneWithoutRegions(OpBuilder &b, Operation *op,
153 TypeRange newResultTypes,
154 ValueRange newOperands);
155
156// Get the list of attributes associated with the op, ignoring
157// those with the provided name.
159getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs);
160
161} // namespace mlir
162
163#endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
SmallVector< AffineMap, 4 > maps
bool iters(ArrayRef< IteratorType > its)
SmallVector< IteratorTypeT > iterators
StructuredGenerator(RewriterBase &rewriter, StructuredOpInterface op)
ArrayRef< ArrayRef< AffineExpr > > MapList
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
Include the generated interface declarations.
bool isColumnMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a column major matmul.
bool isBatchMatvec(ArrayAttr indexingMaps)
Tests whether the given maps describe a batch matrix vector multiplication.
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
bool isMatvec(ArrayAttr indexingMaps)
Tests whether the given maps describe a matrix vector multiplication.
bool isBatchVecmat(ArrayAttr indexingMaps)
Tests whether the given maps describe a batch vector matrix multiplication.
bool isRowMajorBatchMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major batch matmul.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
bool isVecmat(ArrayAttr indexingMaps)
Tests whether the given maps describe a vector matrix multiplication.
void findPositionsOfType(ArrayRef< utils::IteratorType > iteratorTypes, utils::IteratorType iteratorTypeName, SmallVectorImpl< unsigned > &res)
Return positions in iteratorTypes that match iteratorTypeName.
bool isRowMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major matmul.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
bool isOfType(IteratorTypeT expectedIter) const