MLIR  14.0.0git
LinalgToSPIRV.cpp
Go to the documentation of this file.
1 //===- LinalgToSPIRV.cpp - Linalg to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
17 #include "mlir/IR/AffineExpr.h"
19 
20 using namespace mlir;
21 
22 //===----------------------------------------------------------------------===//
23 // Utilities
24 //===----------------------------------------------------------------------===//
25 
26 /// Returns a `Value` containing the `dim`-th dimension's size of SPIR-V
27 /// location invocation ID. This function will create necessary operations with
28 /// `builder` at the proper region containing `op`.
29 static Value getLocalInvocationDimSize(Operation *op, int dim, Type integerType,
30  Location loc, OpBuilder *builder) {
31  assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions");
33  op, spirv::BuiltIn::LocalInvocationId, integerType, *builder);
34  Type xType = invocation.getType().cast<ShapedType>().getElementType();
35  return builder->create<spirv::CompositeExtractOp>(
36  loc, xType, invocation, builder->getI32ArrayAttr({dim}));
37 }
38 
39 //===----------------------------------------------------------------------===//
40 // Reduction (single workgroup)
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
44 
45 /// A pattern to convert a linalg.generic op to SPIR-V ops under the condition
46 /// that the linalg.generic op is performing reduction with a workload size that
47 /// can fit in one workgroup.
48 struct SingleWorkgroupReduction final
49  : public OpConversionPattern<linalg::GenericOp> {
51 
52  /// Matches the given linalg.generic op as performing reduction and returns
53  /// the binary op kind if successful.
55  matchAsPerformingReduction(linalg::GenericOp genericOp);
56 
58  matchAndRewrite(linalg::GenericOp genericOp, OpAdaptor adaptor,
59  ConversionPatternRewriter &rewriter) const override;
60 };
61 
62 } // namespace
63 
65 SingleWorkgroupReduction::matchAsPerformingReduction(
66  linalg::GenericOp genericOp) {
67  Operation *op = genericOp.getOperation();
68 
69  // Make sure the linalg.generic is working on memrefs.
70  if (!genericOp.hasBufferSemantics())
71  return llvm::None;
72 
73  // Make sure this is reduction with one input and one output.
74  if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1)
75  return llvm::None;
76 
77  auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
78  auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();
79 
80  // Make sure the original input has one dimension.
81  if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1)
82  return llvm::None;
83  // Make sure the original output has one element.
84  if (!originalOutputType.hasStaticShape() ||
85  originalOutputType.getNumElements() != 1)
86  return llvm::None;
87 
88  if (!genericOp.hasSingleReductionLoop())
89  return llvm::None;
90 
91  if (genericOp.indexing_maps().getValue().size() != 2)
92  return llvm::None;
93 
94  // TODO: create utility functions for these checks in Linalg
95  // and use them.
96  auto inputMap = genericOp.indexing_maps().getValue()[0].cast<AffineMapAttr>();
97  auto outputMap =
98  genericOp.indexing_maps().getValue()[1].cast<AffineMapAttr>();
99  // The indexing map for the input should be `(i) -> (i)`.
100  if (inputMap.getValue() !=
101  AffineMap::get(1, 0, getAffineDimExpr(0, op->getContext())))
102  return llvm::None;
103  // The indexing map for the input should be `(i) -> (0)`.
104  if (outputMap.getValue() !=
106  return llvm::None;
107 
109 }
110 
111 LogicalResult SingleWorkgroupReduction::matchAndRewrite(
112  linalg::GenericOp genericOp, OpAdaptor adaptor,
113  ConversionPatternRewriter &rewriter) const {
114  Operation *op = genericOp.getOperation();
115  auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
116  auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();
117 
118  auto binaryOpKind = matchAsPerformingReduction(genericOp);
119  if (!binaryOpKind)
120  return failure();
121 
122  // Query the shader interface for local workgroup size to make sure the
123  // invocation configuration fits with the input memref's shape.
125  if (!localSize)
126  return failure();
127 
128  if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
129  return failure();
130  if (llvm::any_of(llvm::drop_begin(localSize.getValues<APInt>(), 1),
131  [](const APInt &size) { return !size.isOneValue(); }))
132  return failure();
133 
134  // TODO: Query the target environment to make sure the current
135  // workload fits in a local workgroup.
136 
137  Value convertedInput = adaptor.getOperands()[0];
138  Value convertedOutput = adaptor.getOperands()[1];
139  Location loc = genericOp.getLoc();
140 
141  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
142  auto indexType = typeConverter->getIndexType();
143 
144  // Get the invocation ID.
145  Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, indexType, loc,
146  &rewriter);
147 
148  // TODO: Load to Workgroup storage class first.
149 
150 
151  // Get the input element accessed by this invocation.
152  Value inputElementPtr = spirv::getElementPtr(
153  *typeConverter, originalInputType, convertedInput, {x}, loc, rewriter);
154  Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr);
155 
156  // Perform the group reduction operation.
157  Value groupOperation;
158 #define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp) \
159  case linalg::RegionMatcher::BinaryOpKind::opKind: { \
160  groupOperation = rewriter.create<spirv::spvOp>( \
161  loc, originalInputType.getElementType(), spirv::Scope::Subgroup, \
162  spirv::GroupOperation::Reduce, inputElement, \
163  /*cluster_size=*/nullptr); \
164  } break
165  switch (*binaryOpKind) {
166  CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp);
167  }
168 #undef CREATE_GROUP_NON_UNIFORM_BIN_OP
169 
170  // Get the output element accessed by this reduction.
171  Value zero = spirv::ConstantOp::getZero(indexType, loc, rewriter);
172  SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
173  Value outputElementPtr =
174  spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput,
175  zeroIndices, loc, rewriter);
176 
177  // Write out the final reduction result. This should be only conducted by one
178  // invocation. We use spv.GroupNonUniformElect to find the invocation with the
179  // lowest ID.
180  //
181  // ```
182  // if (spv.GroupNonUniformElect) { output = ... }
183  // ```
184 
185  Value condition = rewriter.create<spirv::GroupNonUniformElectOp>(
186  loc, spirv::Scope::Subgroup);
187 
188  auto createAtomicOp = [&](OpBuilder &builder) {
189 #define CREATE_ATOMIC_BIN_OP(opKind, spvOp) \
190  case linalg::RegionMatcher::BinaryOpKind::opKind: { \
191  builder.create<spirv::spvOp>(loc, outputElementPtr, spirv::Scope::Device, \
192  spirv::MemorySemantics::AcquireRelease, \
193  groupOperation); \
194  } break
195  switch (*binaryOpKind) { CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp); }
196 #undef CREATE_ATOMIC_BIN_OP
197  };
198 
199  spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, rewriter);
200 
201  rewriter.eraseOp(genericOp);
202  return success();
203 }
204 
205 //===----------------------------------------------------------------------===//
206 // Pattern population
207 //===----------------------------------------------------------------------===//
208 
210  RewritePatternSet &patterns) {
211  patterns.add<SingleWorkgroupReduction>(typeConverter, patterns.getContext());
212 }
Include the generated interface declarations.
DenseIntElementsAttr lookupLocalWorkGroupSize(Operation *op)
Queries the local workgroup size from entry point ABI on the nearest function-like op containing the ...
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
#define CREATE_ATOMIC_BIN_OP(opKind, spvOp)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:516
Value getOperand(unsigned idx)
Definition: Operation.h:219
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:639
static Value getLocalInvocationDimSize(Operation *op, int dim, Type integerType, Location loc, OpBuilder *builder)
Returns a Value containing the dim-th dimension's size of SPIR-V location invocation ID...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:99
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder)
Returns the value for the given builtin variable.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
void populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Linalg ops to SPIR-V ops...
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:215
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:491
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
Type getType() const
Return the type of this value.
Definition: Value.h:117
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr...
#define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp)
This class implements a pattern rewriter for use with ConversionPatterns.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class helps build Operations.
Definition: Builders.h:177
iterator begin() const
Iterator access to the integer element values.
static Optional< BinaryOpKind > matchAsScalarBinaryOp(GenericOp op)
Matches the given linalg op if its body is performing binary operation on int or float scalar values ...
Definition: Utils.cpp:97
MLIRContext * getContext() const
Definition: PatternMatch.h:906
Type conversion from builtin types to SPIR-V types for shader interface.
An attribute that represents a reference to a dense integer vector or tensor object.
U cast() const
Definition: Types.h:250