MLIR  20.0.0git
DistributionUtils.h
Go to the documentation of this file.
1 //===- DistributionUtils.h - Distribution Utilities -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBUTIONUTILS_H_
10 #define MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBUTIONUTILS_H_
11 
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/Value.h"
17 
18 namespace mlir::gpu {
19 struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
22 
23  virtual LogicalResult
24  matchAndRewrite(WarpExecuteOnLane0Op op,
25  PatternRewriter &rewriter) const override = 0;
26 
27 protected:
28  /// Return a value yielded by `warpOp` which statifies the filter lamdba
29  /// condition and is not dead.
30  OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
31  llvm::function_ref<bool(Operation *)> fn) const;
32 
33  /// Helper to create a new WarpExecuteOnLane0Op with different signature.
34  WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
35  RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
36  ValueRange newYieldedValues, TypeRange newReturnTypes) const;
37 
38  /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
39  /// `indices` return the index of each new output.
40  WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
41  RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
42  ValueRange newYieldedValues, TypeRange newReturnTypes,
43  SmallVector<size_t> &indices) const;
44 
45  /// Delinearize the given `laneId` into multiple dimensions, where each
46  /// dimension's size is determined by `originalShape` and `distributedShape`
47  /// together. This function expects the total numbers of threads needed for
48  /// distribution is equal to `warpSize`. Returns true and updates
49  /// `delinearizedIds` if so.
50  bool delinearizeLaneId(OpBuilder &builder, Location loc,
51  ArrayRef<int64_t> originalShape,
52  ArrayRef<int64_t> distributedShape, int64_t warpSize,
53  Value laneId,
54  SmallVectorImpl<Value> &delinearizedIds) const;
55 };
56 
57 } // namespace mlir::gpu
58 
59 #endif // MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBUTIONUTILS_H_
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:216
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
bool delinearizeLaneId(OpBuilder &builder, Location loc, ArrayRef< int64_t > originalShape, ArrayRef< int64_t > distributedShape, int64_t warpSize, Value laneId, SmallVectorImpl< Value > &delinearizedIds) const
Delinearize the given laneId into multiple dimensions, where each dimension's size is determined by o...
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes) const
Helper to create a new WarpExecuteOnLane0Op with different signature.
virtual LogicalResult matchAndRewrite(WarpExecuteOnLane0Op op, PatternRewriter &rewriter) const override=0
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead.