MLIR  16.0.0git
StructuredOpsUtils.h
Go to the documentation of this file.
1 //===- StructuredOpsUtils.h - Utilities used by structured ops --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This header file defines utilities that operate on builtin types and are
// useful across multiple dialects that use structured ops abstractions. These
// abstractions consist of defining custom operations that encode and transport
// information about their semantics (e.g. type of iterators like parallel,
// reduction, etc.) as attributes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
18 #define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
19 
20 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/Location.h"
23 #include "mlir/Support/LLVM.h"
24 #include "llvm/ADT/StringRef.h"
25 
26 // Pull in all enum type definitions and utility function declarations.
27 #include "mlir/Dialect/Utils/DialectUtilsEnums.h.inc"
28 
29 namespace mlir {
30 
31 class OpBuilder;
32 
33 /// Tests whether the given maps describe a row major matmul. The test is
34 /// permutation-invariant. Note that this only checks the affine maps from an
35 /// operation, so does not perform any checks on the math being performed within
36 /// the reduction.
37 bool isRowMajorMatmul(ArrayAttr indexingMaps);
38 
39 /// Tests whether the given maps describe a column major matmul. The test is
40 /// permutation-invariant. Note that this only checks the affine maps from an
41 /// operation, so does not perform any checks on the math being performed within
42 /// the reduction.
43 bool isColumnMajorMatmul(ArrayAttr indexingMaps);
44 
45 /// Tests whether the given maps describe a row major batch matmul. The test is
46 /// permutation-invariant. Note that this only checks the affine maps from an
47 /// operation, so does not perform any checks on the math being performed within
48 /// the reduction.
49 bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
50 
51 /// Attribute name for the AffineArrayAttr which encodes the relationship
52 /// between a structured op iterators' and its operands.
53 constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; }
54 
55 /// Attribute name for the StrArrayAttr which encodes the type of a structured
56 /// op's iterators.
57 constexpr StringRef getIteratorTypesAttrName() { return "iterator_types"; }
58 
59 /// Attribute name for the StrArrayAttr which encodes the distribution type for
60 /// `linalg.tiled_loop`.
61 constexpr StringRef getDistributionTypesAttrName() {
62  return "distribution_types";
63 }
64 
65 /// Attribute name for the StringAttr which encodes an optional documentation
66 /// string of the structured op.
67 constexpr StringRef getDocAttrName() { return "doc"; }
68 
69 /// Attribute name for the StrArrayAttr which encodes the external library
70 /// function that implements the structured op.
71 constexpr StringRef getLibraryCallAttrName() { return "library_call"; }
72 
73 /// Attribute name for the StrArrayAttr which encodes the value of strides.
74 constexpr StringRef getStridesAttrName() { return "strides"; }
75 
76 /// Attribute name for the StrArrayAttr which encodes the value of dilations.
77 constexpr StringRef getDilationsAttrName() { return "dilations"; }
78 
79 /// Attribute name for the StrArrayAttr which encodes the value of paddings.
80 constexpr StringRef getPaddingAttrName() { return "padding"; }
81 
82 /// Use to encode that a particular iterator type has parallel semantics.
83 constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
84 
85 /// Use to encode that a particular iterator type has reduction semantics.
86 constexpr StringRef getReductionIteratorTypeName() { return "reduction"; }
87 
88 /// Use to encode that a particular iterator type has window semantics.
89 constexpr StringRef getWindowIteratorTypeName() { return "window"; }
90 
/// Returns the names of all known iterator types.
93  static constexpr StringRef names[3] = {getParallelIteratorTypeName(),
96  return llvm::makeArrayRef(names);
97 }
98 
99 /// Returns the iterator of a certain type.
100 inline unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes) {
101  auto names = getAllIteratorTypeNames();
102  (void)names;
103  assert(llvm::is_contained(names, name));
104  return llvm::count_if(iteratorTypes, [name](Attribute a) {
105  return a.cast<StringAttr>().getValue() == name;
106  });
107 }
108 
109 inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
110  unsigned res = 0;
111  for (auto n : getAllIteratorTypeNames())
112  res += getNumIterators(n, iteratorTypes);
113  return res;
114 }
115 
116 /// Return positions in `iteratorTypes` that match `iteratorTypeName`.
117 inline void findPositionsOfType(ArrayAttr iteratorTypes,
118  StringRef iteratorTypeName,
120  for (const auto &en :
121  llvm::enumerate(iteratorTypes.getAsValueRange<StringAttr>())) {
122  if (en.value() == iteratorTypeName)
123  res.push_back(en.index());
124  }
125 }
126 
127 /// Helper StructuredGenerator class to manipulate and rewrite ops with
128 /// `StructuredOpInterface`. This is templated for now because VectorOps do not
129 /// yet implement the StructuredOpInterface itself.
130 template <typename StructuredOpInterface>
132 public:
134 
135  struct IteratorType {
136  IteratorType(StringRef strRef) : strRef(strRef) {}
137  bool isOfType(StringRef typeName) const { return typeName == strRef; }
138  StringRef strRef;
139  };
140  struct Par : public IteratorType {
142  };
143  struct Red : public IteratorType {
145  };
146  struct Win : public IteratorType {
148  };
149 
150  StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
151  : builder(builder), ctx(op.getContext()), loc(op.getLoc()),
152  iterators(op.getIteratorTypeNames()), maps(op.getIndexingMapsArray()),
153  op(op) {}
154 
156  if (its.size() != iterators.size())
157  return false;
158  for (int i = 0, e = its.size(); i != e; ++i) {
159  if (!its[i].isOfType(iterators[i]))
160  return false;
161  }
162  return true;
163  }
164 
165  bool layout(MapList l) {
166  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
167  return maps == infer(l);
168  }
169 
170 protected:
177 };
178 
179 } // namespace mlir
180 
181 #endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
Include the generated interface declarations.
bool isOfType(StringRef typeName) const
ArrayRef< StringRef > getAllIteratorTypeNames()
Returns the names of all known iterator types.
constexpr StringRef getParallelIteratorTypeName()
Use to encode that a particular iterator type has parallel semantics.
U cast() const
Definition: Attributes.h:136
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
bool isRowMajorBatchMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major batch matmul.
constexpr StringRef getPaddingAttrName()
Attribute name for the StrArrayAttr which encodes the value of paddings.
constexpr StringRef getWindowIteratorTypeName()
Use to encode that a particular iterator type has window semantics.
constexpr StringRef getIteratorTypesAttrName()
Attribute name for the StrArrayAttr which encodes the type of a structured op's iterators.
bool isColumnMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a column major matmul.
constexpr StringRef getDocAttrName()
Attribute name for the StringAttr which encodes an optional documentation string of the structured op...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
void findPositionsOfType(ArrayAttr iteratorTypes, StringRef iteratorTypeName, SmallVectorImpl< unsigned > &res)
Return positions in iteratorTypes that match iteratorTypeName.
constexpr StringRef getDistributionTypesAttrName()
Attribute name for the StrArrayAttr which encodes the distribution type for linalg.tiled_loop.
SmallVector< StringRef > iterators
unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes)
Returns the number of iterators of a certain type.
constexpr StringRef getDilationsAttrName()
Attribute name for the StrArrayAttr which encodes the value of dilations.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:232
SmallVector< AffineMap, 4 > maps
constexpr StringRef getLibraryCallAttrName()
Attribute name for the StrArrayAttr which encodes the external library function that implements the s...
constexpr StringRef getIndexingMapsAttrName()
Attribute name for the AffineArrayAttr which encodes the relationship between a structured op iterato...
constexpr StringRef getStridesAttrName()
Attribute name for the StrArrayAttr which encodes the value of strides.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
constexpr StringRef getReductionIteratorTypeName()
Use to encode that a particular iterator type has reduction semantics.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:235
This class helps build Operations.
Definition: Builders.h:196
bool isRowMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major matmul.
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
bool iters(ArrayRef< IteratorType > its)