MLIR  15.0.0git
StructuredOpsUtils.h
Go to the documentation of this file.
1 //===- StructuredOpsUtils.h - Utilities used by structured ops --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities that operate on builtin types and are
10 // useful across multiple dialects that use structured ops abstractions. These
11 // abstractions consist of custom operations that encode and transport
12 // information about their semantics (e.g. type of iterators like parallel,
13 // reduction, etc.) as attributes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
18 #define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
19 
20 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/Location.h"
23 #include "mlir/Support/LLVM.h"
24 #include "llvm/ADT/StringRef.h"
25 
26 namespace mlir {
27 
28 class OpBuilder;
29 
30 /// Tests whether the given maps describe a row major matmul. The test is
31 /// permutation-invariant. Note that this only checks the affine maps from an
32 /// operation, so does not perform any checks on the math being performed within
33 /// the reduction.
34 bool isRowMajorMatmul(ArrayAttr indexingMaps);
35 
36 /// Tests whether the given maps describe a column major matmul. The test is
37 /// permutation-invariant. Note that this only checks the affine maps from an
38 /// operation, so does not perform any checks on the math being performed within
39 /// the reduction.
40 bool isColumnMajorMatmul(ArrayAttr indexingMaps);
41 
42 /// Tests whether the given maps describe a row major batch matmul. The test is
43 /// permutation-invariant. Note that this only checks the affine maps from an
44 /// operation, so does not perform any checks on the math being performed within
45 /// the reduction.
46 bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
47 
48 /// Attribute name for the AffineArrayAttr which encodes the relationship
49 /// between a structured op iterators' and its operands.
50 constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; }
51 
52 /// Attribute name for the StrArrayAttr which encodes the type of a structured
53 /// op's iterators.
54 constexpr StringRef getIteratorTypesAttrName() { return "iterator_types"; }
55 
56 /// Attribute name for the StrArrayAttr which encodes the distribution type for
57 /// `linalg.tiled_loop`.
58 constexpr StringRef getDistributionTypesAttrName() {
59  return "distribution_types";
60 }
61 
62 /// Attribute name for the StringAttr which encodes an optional documentation
63 /// string of the structured op.
64 constexpr StringRef getDocAttrName() { return "doc"; }
65 
66 /// Attribute name for the StrArrayAttr which encodes the external library
67 /// function that implements the structured op.
68 constexpr StringRef getLibraryCallAttrName() { return "library_call"; }
69 
70 /// Attribute name for the StrArrayAttr which encodes the value of strides.
71 constexpr StringRef getStridesAttrName() { return "strides"; }
72 
73 /// Attribute name for the StrArrayAttr which encodes the value of dilations.
74 constexpr StringRef getDilationsAttrName() { return "dilations"; }
75 
76 /// Attribute name for the StrArrayAttr which encodes the value of paddings.
77 constexpr StringRef getPaddingAttrName() { return "padding"; }
78 
79 /// Use to encode that a particular iterator type has parallel semantics.
80 constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
81 inline bool isParallelIterator(Attribute attr) {
82  auto strAttr = attr.dyn_cast_or_null<StringAttr>();
83  return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
84 }
85 
86 /// Use to encode that a particular iterator type has reduction semantics.
87 constexpr StringRef getReductionIteratorTypeName() { return "reduction"; }
88 inline bool isReductionIterator(Attribute attr) {
89  auto strAttr = attr.dyn_cast_or_null<StringAttr>();
90  return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
91 }
92 
93 /// Use to encode that a particular iterator type has window semantics.
94 constexpr StringRef getWindowIteratorTypeName() { return "window"; }
95 inline bool isWindowIterator(Attribute attr) {
96  auto strAttr = attr.dyn_cast_or_null<StringAttr>();
97  return strAttr && strAttr.getValue() == getWindowIteratorTypeName();
98 }
99 
100 /// Returns the names of all iterator types known to these utilities.
102  static constexpr StringRef names[3] = {getParallelIteratorTypeName(),
105  return llvm::makeArrayRef(names);
106 }
107 
108 /// Returns the iterator of a certain type.
109 inline unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes) {
110  auto names = getAllIteratorTypeNames();
111  (void)names;
112  assert(llvm::is_contained(names, name));
113  return llvm::count_if(iteratorTypes, [name](Attribute a) {
114  return a.cast<StringAttr>().getValue() == name;
115  });
116 }
117 
118 inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
119  unsigned res = 0;
120  for (auto n : getAllIteratorTypeNames())
121  res += getNumIterators(n, iteratorTypes);
122  return res;
123 }
124 
125 /// Typed representation for loop type strings.
127 
128 inline StringRef toString(IteratorType t) {
129  switch (t) {
134  }
135  llvm_unreachable("Unsupported IteratorType");
136 }
137 
138 /// Helper StructuredGenerator class to manipulate and rewrite ops with
139 /// `StructuredOpInterface`. This is templated for now because VectorOps do not
140 /// yet implement the StructuredOpInterface itself.
141 template <typename StructuredOpInterface>
143 public:
145 
146  struct IteratorType {
147  IteratorType(StringRef strRef) : strRef(strRef) {}
148  bool isOfType(Attribute attr) const {
149  auto sAttr = attr.dyn_cast<StringAttr>();
150  return sAttr && sAttr.getValue() == strRef;
151  }
152  StringRef strRef;
153  };
154  struct Par : public IteratorType {
156  };
157  struct Red : public IteratorType {
159  };
160  struct Win : public IteratorType {
162  };
163 
164  StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
165  : builder(builder), ctx(op.getContext()), loc(op.getLoc()),
166  iterators(op.getIteratorTypes()), maps(op.getIndexingMaps()), op(op) {}
167 
169  if (its.size() != iterators.size())
170  return false;
171  for (int i = 0, e = its.size(); i != e; ++i) {
172  if (!its[i].isOfType(iterators[i]))
173  return false;
174  }
175  return true;
176  }
177 
178  bool layout(MapList l) {
179  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
180  return maps == infer(l);
181  }
182 
183 protected:
187  ArrayAttr iterators;
190 };
191 
192 } // namespace mlir
193 
194 #endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
Include the generated interface declarations.
ArrayRef< StringRef > getAllIteratorTypeNames()
Returns the names of all iterator types known to these utilities.
constexpr StringRef getParallelIteratorTypeName()
Use to encode that a particular iterator type has parallel semantics.
U cast() const
Definition: Attributes.h:130
U dyn_cast_or_null() const
Definition: Attributes.h:127
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
bool isRowMajorBatchMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major batch matmul.
bool isParallelIterator(Attribute attr)
constexpr StringRef getPaddingAttrName()
Attribute name for the StrArrayAttr which encodes the value of paddings.
constexpr StringRef getWindowIteratorTypeName()
Use to encode that a particular iterator type has window semantics.
IteratorType
Typed representation for loop type strings.
constexpr StringRef getIteratorTypesAttrName()
Attribute name for the StrArrayAttr which encodes the type of a structured op's iterators.
bool isColumnMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a column major matmul.
constexpr StringRef getDocAttrName()
Attribute name for the StringAttr which encodes an optional documentation string of the structured op...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
constexpr StringRef getDistributionTypesAttrName()
Attribute name for the StrArrayAttr which encodes the distribution type for linalg.tiled_loop.
unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes)
Returns the number of iterators of a certain type.
constexpr StringRef getDilationsAttrName()
Attribute name for the StrArrayAttr which encodes the value of dilations.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
Reduction
SmallVector< AffineMap, 4 > maps
StringRef strRef
bool isReductionIterator(Attribute attr)
constexpr StringRef getLibraryCallAttrName()
Attribute name for the StrArrayAttr which encodes the external library function that implements the s...
constexpr StringRef getIndexingMapsAttrName()
Attribute name for the AffineArrayAttr which encodes the relationship between a structured op iterato...
U dyn_cast() const
Definition: Attributes.h:124
constexpr StringRef getStridesAttrName()
Attribute name for the StrArrayAttr which encodes the value of strides.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
bool isWindowIterator(Attribute attr)
StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
StringRef toString(IteratorType t)
constexpr StringRef getReductionIteratorTypeName()
Use to encode that a particular iterator type has reduction semantics.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:235
This class helps build Operations.
Definition: Builders.h:177
bool isRowMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major matmul.
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
bool iters(ArrayRef< IteratorType > its)