29 #include "llvm/ADT/MapVector.h"
30 #include "llvm/ADT/SmallBitVector.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
39 using llvm::MapVector;
41 #define DEBUG_TYPE "linalg-promotion"
50 std::optional<unsigned> alignment = std::nullopt) {
53 IntegerAttr alignmentAttr;
54 if (alignment.has_value())
58 if (
options.memorySpace.has_value())
59 memorySpaceAttr = *
options.memorySpace;
63 auto staticBufferType =
76 auto dynamicBufferType =
81 b.
create<arith::ConstantIndexOp>(width), allocSize);
83 return b.
create<memref::AllocaOp>(dynamicBufferType, mul, alignmentAttr);
84 return b.
create<memref::AllocOp>(dynamicBufferType, mul, alignmentAttr);
94 std::optional<unsigned> alignment,
DataLayout &layout) {
95 ShapedType viewType = subView.getType();
97 auto zero = b.
create<arith::ConstantIndexOp>(0);
98 auto one = b.
create<arith::ConstantIndexOp>(1);
101 if (
options.memorySpace.has_value())
102 memorySpaceAttr = *
options.memorySpace;
104 Value allocSize = one;
106 allocSize = b.
createOrFold<arith::MulIOp>(allocSize, size.value());
110 ShapedType::kDynamic);
112 auto viewMemRefType =
MemRefType::get(dynSizes, viewType.getElementType());
116 boundingSubViewSize);
127 auto viewOp = cast<memref::ViewOp>(fullLocalView.
getDefiningOp());
128 b.
create<memref::DeallocOp>(viewOp.getSource().getLoc(),
140 struct LinalgOpInstancePromotionOptions {
141 LinalgOpInstancePromotionOptions(LinalgOp op,
144 MapVector<int64_t, Value> subViews;
156 std::optional<unsigned> alignment;
160 LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
162 : subViews(), alignment(
options.alignment) {
163 assert(linalgOp.hasBufferSemantics() &&
"revisit usage of shaped operand");
164 auto vUseFullTileBuffers =
165 options.useFullTileBuffers.value_or(llvm::SmallBitVector());
166 vUseFullTileBuffers.resize(linalgOp->getNumOperands(),
167 options.useFullTileBuffersDefault);
169 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
170 int64_t operandNumber = opOperand.getOperandNumber();
171 if (
options.operandsToPromote &&
172 !
options.operandsToPromote->count(operandNumber))
174 Operation *op = opOperand.get().getDefiningOp();
175 if (
auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) {
176 subViews[operandNumber] = sv;
177 useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber];
182 allocationFn = *
options.allocationFn;
184 allocationFn = [&](
OpBuilder &b, memref::SubViewOp subViewOp,
188 boundingSubViewSize, alignment, layout);
193 deallocationFn = *
options.deallocationFn;
204 b.
create<memref::CopyOp>(loc, src, dst);
207 copyInFn = (
options.copyInFn ? *(
options.copyInFn) : defaultCopyCallBack);
208 copyOutFn = (
options.copyOutFn ? *(
options.copyOutFn) : defaultCopyCallBack);
231 auto viewType = subView.getType();
232 auto rank = viewType.getRank();
235 fullSizes.reserve(rank);
236 partialSizes.reserve(rank);
237 llvm::SmallBitVector droppedDims = subView.getDroppedDims();
238 int64_t resultDimIdx = 0;
239 for (
const auto &en :
llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
240 if (droppedDims[en.index()])
242 auto rangeValue = en.value();
245 LLVM_DEBUG(llvm::dbgs() <<
"Extract tightest: " << rangeValue.size <<
"\n");
247 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
250 Value materializedSize =
253 ValueBoundsConstraintSet::computeConstantBound(
254 presburger::BoundType::UB, materializedSize, std::nullopt,
260 LLVM_DEBUG(llvm::dbgs() <<
"Extracted tightest: " << size <<
"\n");
261 fullSizes.push_back(size);
262 partialSizes.push_back(
263 b.
createOrFold<memref::DimOp>(loc, subView, resultDimIdx++));
268 std::optional<Value> fullLocalView =
269 allocationFn(b, subView, fullSizes, layout);
274 auto partialLocalView = b.
createOrFold<memref::SubViewOp>(
275 loc, *fullLocalView, zeros, partialSizes, ones);
285 MapVector<int64_t, PromotionInfo> promotionInfoMap;
287 for (
auto v :
options.subViews) {
288 memref::SubViewOp subView =
289 cast<memref::SubViewOp>(v.second.getDefiningOp());
292 if (
failed(promotionInfo))
294 promotionInfoMap[v.first] = *promotionInfo;
297 if (!
options.useFullTileBuffers[v.second])
299 Type subviewEltType = subView.getType().getElementType();
305 .Case([&](IntegerType t) {
308 .Case([&](ComplexType t) {
310 if (
auto et = dyn_cast<FloatType>(t.getElementType()))
312 else if (
auto et = cast<IntegerType>(t.getElementType()))
314 return b.
create<complex::CreateOp>(t, tmp, tmp);
316 .Default([](
auto) {
return Value(); });
319 b.
create<linalg::FillOp>(fillVal, promotionInfo->fullLocalView);
323 for (
auto v :
options.subViews) {
324 auto info = promotionInfoMap.find(v.first);
325 if (info == promotionInfoMap.end())
328 b, cast<memref::SubViewOp>(v.second.getDefiningOp()),
329 info->second.partialLocalView)))
332 return promotionInfoMap;
338 assert(op.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
342 if (
failed(promotedBuffersAndViews) ||
343 promotedBuffersAndViews->size() !=
options.subViews.size())
352 writebackViews.reserve(promotedBuffersAndViews->size());
354 int64_t operandNumber = opOperand.getOperandNumber();
355 if (
options.subViews.count(operandNumber) != 0) {
356 if (
options.useFullTileBuffers[opOperand.get()])
358 (*promotedBuffersAndViews)[operandNumber].fullLocalView);
361 (*promotedBuffersAndViews)[operandNumber].partialLocalView);
362 if (operandNumber >= op.getNumDpsInputs())
363 writebackViews.emplace_back(std::make_pair(
365 (*promotedBuffersAndViews)[operandNumber].partialLocalView));
367 opViews.push_back(opOperand.get());
375 for (
auto viewAndPartialLocalView : writebackViews) {
376 if (
failed(
options.copyOutFn(b, viewAndPartialLocalView.second,
377 viewAndPartialLocalView.first)))
382 for (
const auto &pi : *promotedBuffersAndViews)
383 (void)
options.deallocationFn(b, pi.second.fullLocalView);
390 LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
392 if (!linalgOp || !linalgOp.hasBufferSemantics())
395 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
397 isa_and_nonnull<memref::SubViewOp>(opOperand.get().getDefiningOp());
399 if (!
options.operandsToPromote ||
400 options.operandsToPromote->count(opOperand.getOperandNumber()))
412 LinalgOpInstancePromotionOptions linalgOptions(linalgOp,
options);
413 auto layout = DataLayout::closest(linalgOp);
426 gpu::AddressSpace addressSpace) {
429 func::FuncOp funcOp = subview->getParentOfType<func::FuncOp>();
436 for (
Value bound : sizeBounds) {
440 shape.push_back(value.getSExtValue());
445 shape, subview.getType().getElementType(), MemRefLayoutAttrInterface{},
448 if (addressSpace == gpu::GPUDialect::getWorkgroupAddressSpace()) {
449 buffer = builder.
create<memref::AllocOp>(funcOp.getLoc(), type);
450 }
else if (addressSpace == gpu::GPUDialect::getPrivateAddressSpace()) {
451 buffer = builder.
create<memref::AllocaOp>(funcOp.getLoc(), type);
463 builder, subview, sizeBounds,
464 gpu::GPUDialect::getWorkgroupAddressSpace());
479 b.
create<gpu::BarrierOp>(copyOp->getLoc());
488 builder, subview, sizeBounds, gpu::GPUDialect::getPrivateAddressSpace());
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
The main mechanism for performing data layout queries.
unsigned getTypeSize(Type t) const
Returns the size of the given type in the current scope.
This class provides support for representing a failure result, or a valid value of type T.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
void createOrFold(llvm::SmallVectorImpl< Value > &results, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This is a builder type that keeps local references to arguments.
Builder & setMemorySpace(Attribute newMemorySpace)
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
unsigned getNumOperands()
MutableArrayRef< OpOperand > getOpOperands()
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns an integer of index type.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
std::function< std::optional< Value >(OpBuilder &b, memref::SubViewOp subView, ArrayRef< Value > boundingSubViewSize, DataLayout &layout)> AllocBufferCallbackFn
Callback function type used to perform the allocation for the promoted subView.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
std::function< LogicalResult(OpBuilder &b, Value buffer)> DeallocBufferCallbackFn
Callback function type used to deallocate the buffers used to hold the promoted subview.
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< PromotionInfo > promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView, const AllocBufferCallbackFn &allocationFn, DataLayout &layout)
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
std::function< LogicalResult(OpBuilder &b, Value src, Value dst)> CopyCallbackFn
Callback function type used to insert copy from original subview to subview of the promoted region fo...
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.