#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::linalg;

using llvm::MapVector;

#define DEBUG_TYPE "linalg-promotion"
/// Allocate a buffer of `allocSize` x `width` bytes, where `width` is the size
/// of `elementType` according to the data `layout`. Uses memref.alloca or
/// memref.alloc depending on `options`, with an optional alignment.
static Value allocBuffer(ImplicitLocOpBuilder &b,
                         const LinalgPromotionOptions &options,
                         Type elementType, Value allocSize, DataLayout &layout,
                         std::optional<unsigned> alignment = std::nullopt) {
  llvm::TypeSize width = layout.getTypeSize(elementType);
  assert(!width.isScalable() && "cannot allocate buffer for a scalable vector");

  IntegerAttr alignmentAttr;
  if (alignment.has_value())
    alignmentAttr = b.getI64IntegerAttr(alignment.value());

  Attribute memorySpaceAttr;
  if (options.memorySpace.has_value())
    memorySpaceAttr = *options.memorySpace;

  // Static buffer: the total byte size is a compile-time constant.
  if (std::optional<int64_t> cst = getConstantIntValue(allocSize)) {
    auto staticBufferType = MemRefType::get(width.getFixedValue() * cst.value(),
                                            b.getIntegerType(8));
    staticBufferType =
        MemRefType::Builder(staticBufferType).setMemorySpace(memorySpaceAttr);
    if (options.useAlloca)
      return memref::AllocaOp::create(b, staticBufferType, ValueRange{},
                                      alignmentAttr);
    return memref::AllocOp::create(b, staticBufferType, ValueRange{},
                                   alignmentAttr);
  }

  // Fallback dynamic buffer: allocate `width * allocSize` bytes.
  auto dynamicBufferType =
      MemRefType::get(ShapedType::kDynamic, b.getIntegerType(8));
  dynamicBufferType =
      MemRefType::Builder(dynamicBufferType).setMemorySpace(memorySpaceAttr);
  Value mul = b.createOrFold<arith::MulIOp>(
      arith::ConstantIndexOp::create(b, width.getFixedValue()), allocSize);
  if (options.useAlloca)
    return memref::AllocaOp::create(b, dynamicBufferType, mul, alignmentAttr);
  return memref::AllocOp::create(b, dynamicBufferType, mul, alignmentAttr);
}
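// A minimal usage sketch, not part of the original file: the knobs consumed by
// allocBuffer above (useAlloca, alignment, memorySpace) normally arrive through
// LinalgPromotionOptions; the exact setter spellings below are an assumption
// taken from the public options API.
static LinalgPromotionOptions stackPromotionOptionsExample() {
  LinalgPromotionOptions opts;
  opts.setUseAlloca(true) // allocBuffer then emits memref.alloca.
      .setAlignment(16);  // forwarded as the alignment attribute above.
  return opts;
}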
/// Default allocation callback: allocate a flat buffer large enough for the
/// promoted view and wrap it in a dynamically-shaped memref.view.
static std::optional<Value> defaultAllocBufferCallBack(
    const LinalgPromotionOptions &options, OpBuilder &builder,
    memref::SubViewOp subView, ArrayRef<Value> boundingSubViewSize,
    std::optional<unsigned> alignment, DataLayout &layout) {
  ShapedType viewType = subView.getType();
  ImplicitLocOpBuilder b(subView.getLoc(), builder);
  auto zero = arith::ConstantIndexOp::create(b, 0);
  auto one = arith::ConstantIndexOp::create(b, 1);

  Attribute memorySpaceAttr;
  if (options.memorySpace.has_value())
    memorySpaceAttr = *options.memorySpace;

  // Total number of elements is the product of the bounding sizes.
  Value allocSize = one;
  for (const auto &size : llvm::enumerate(boundingSubViewSize))
    allocSize = b.createOrFold<arith::MulIOp>(allocSize, size.value());
  Value buffer = allocBuffer(b, options, viewType.getElementType(), allocSize,
                             layout, alignment);
  SmallVector<int64_t, 4> dynSizes(boundingSubViewSize.size(),
                                   ShapedType::kDynamic);
  auto viewMemRefType = MemRefType::get(dynSizes, viewType.getElementType());
  viewMemRefType =
      MemRefType::Builder(viewMemRefType).setMemorySpace(memorySpaceAttr);
  Value view = b.createOrFold<memref::ViewOp>(viewMemRefType, buffer, zero,
                                              boundingSubViewSize);
  return view;
}
static LogicalResult defaultDeallocBufferCallBack(
    const LinalgPromotionOptions &options, OpBuilder &b, Value fullLocalView) {
  if (!options.useAlloca) {
    auto viewOp = cast<memref::ViewOp>(fullLocalView.getDefiningOp());
    memref::DeallocOp::create(b, viewOp.getSource().getLoc(),
                              viewOp.getSource());
  }
  return success();
}
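// Note: the two default callbacks above form a pair. defaultAllocBufferCallBack
// returns the memref.view it creates, and defaultDeallocBufferCallBack expects
// exactly that value so it can dealloc the underlying i8 buffer; with
// options.useAlloca the buffer lives on the stack and no dealloc is emitted.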
/// Bridges the user-facing, positional LinalgPromotionOptions to the values
/// and callbacks used when promoting one LinalgOp instance.
struct LinalgOpInstancePromotionOptions {
  LinalgOpInstancePromotionOptions(LinalgOp op,
                                   const LinalgPromotionOptions &options);
  /// SubViews to promote, keyed by operand number.
  MapVector<int64_t, Value> subViews;
  /// Operand numbers whose promoted buffer must be initialized by a copy-in.
  llvm::SmallSet<int64_t, 4> operandsNumbersToCopyIn;
  DenseMap<Value, bool> useFullTileBuffers;
  bool useOriginalSubviewSize;
  AllocBufferCallbackFn allocationFn;
  DeallocBufferCallbackFn deallocationFn;
  CopyCallbackFn copyInFn;
  CopyCallbackFn copyOutFn;
  std::optional<unsigned> alignment;
};
LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
    LinalgOp linalgOp, const LinalgPromotionOptions &options)
    : subViews(), alignment(options.alignment) {
  assert(linalgOp.hasPureBufferSemantics() &&
         "revisit usage of shaped operand");
  auto vUseFullTileBuffers =
      options.useFullTileBuffers.value_or(llvm::SmallBitVector());
  vUseFullTileBuffers.resize(linalgOp->getNumOperands(),
                             options.useFullTileBuffersDefault);
  useOriginalSubviewSize = options.useOriginalSubviewSize;

  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    int64_t operandNumber = opOperand.getOperandNumber();
    if (options.operandsToPromote &&
        !options.operandsToPromote->count(operandNumber))
      continue;
    Operation *op = opOperand.get().getDefiningOp();
    if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) {
      subViews[operandNumber] = sv;
      // For linalg.generic, copy in only if the subview is actually read in
      // the payload.
      if (!isa<linalg::GenericOp>(linalgOp) ||
          linalgOp.payloadUsesValueFromOperand(&opOperand))
        operandsNumbersToCopyIn.insert(operandNumber);
      useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber];
    }
  }

  if (options.allocationFn) {
    allocationFn = *options.allocationFn;
  } else {
    allocationFn = [&](OpBuilder &b, memref::SubViewOp subViewOp,
                       ArrayRef<Value> boundingSubViewSize,
                       DataLayout &layout) -> std::optional<Value> {
      return defaultAllocBufferCallBack(options, b, subViewOp,
                                        boundingSubViewSize, alignment, layout);
    };
  }

  if (options.deallocationFn) {
    deallocationFn = *options.deallocationFn;
  } else {
    deallocationFn = [&](OpBuilder &b, Value buffer) {
      return defaultDeallocBufferCallBack(options, b, buffer);
    };
  }

  // Capture the location by value; the lambda may outlive `linalgOp`.
  Location loc = linalgOp.getLoc();
  auto defaultCopyCallBack = [loc](OpBuilder &b, Value src,
                                   Value dst) -> LogicalResult {
    linalg::CopyOp::create(b, loc, src, dst);
    return success();
  };
  copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack);
  copyOutFn = (options.copyOutFn ? *(options.copyOutFn) : defaultCopyCallBack);
}
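// Note: operandsNumbersToCopyIn deliberately skips linalg.generic operands
// whose value is never read in the payload (e.g. outputs that are only
// written); their promoted buffers do not need the initial copy-in.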
FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
    OpBuilder &b, Location loc, memref::SubViewOp subView,
    bool useOriginalSubviewSize, const AllocBufferCallbackFn &allocationFn,
    DataLayout &layout) {
  auto viewType = subView.getType();
  auto rank = viewType.getRank();
  SmallVector<Value, 4> fullSizes;
  SmallVector<OpFoldResult> partialSizes;
  fullSizes.reserve(rank);
  partialSizes.reserve(rank);
  llvm::SmallBitVector droppedDims = subView.getDroppedDims();
  int64_t resultDimIdx = 0;
  for (const auto &en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
    if (droppedDims[en.index()])
      continue;
    auto rangeValue = en.value();
    // Try to extract a tight constant bound for this range's size.
    LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
    Value size;
    if (llvm::isa_and_present<Attribute>(rangeValue.size) ||
        useOriginalSubviewSize) {
      size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
    } else {
      FailureOr<int64_t> upperBound =
          ValueBoundsConstraintSet::computeConstantBound(
              presburger::BoundType::UB, rangeValue.size,
              /*stopCondition=*/nullptr, /*closedUB=*/true);
      size = failed(upperBound)
                 ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
                 : arith::ConstantIndexOp::create(b, loc, *upperBound);
    }
    LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
    fullSizes.push_back(size);
    partialSizes.push_back(
        b.createOrFold<memref::DimOp>(loc, subView, resultDimIdx++));
  }
  // Allocate the full bounding buffer through the provided callback.
  std::optional<Value> fullLocalView =
      allocationFn(b, subView, fullSizes, layout);
  if (!fullLocalView)
    return failure();
  SmallVector<OpFoldResult> zeros(fullSizes.size(), b.getIndexAttr(0));
  SmallVector<OpFoldResult> ones(fullSizes.size(), b.getIndexAttr(1));
  auto partialLocalView = b.createOrFold<memref::SubViewOp>(
      loc, *fullLocalView, zeros, partialSizes, ones);
  return PromotionInfo{*fullLocalView, partialLocalView};
}
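// Note: the "full" local view above covers the constant bounding size of every
// range, so its shape is loop-invariant; the "partial" local view is a subview
// of it sized by the runtime memref.dim values, and matches the original
// subview at tile boundaries.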
static FailureOr<MapVector<int64_t, PromotionInfo>>
promoteSubViews(ImplicitLocOpBuilder &b,
                LinalgOpInstancePromotionOptions options, DataLayout &layout) {
  if (options.subViews.empty())
    return failure();

  MapVector<int64_t, PromotionInfo> promotionInfoMap;

  for (auto v : options.subViews) {
    memref::SubViewOp subView =
        cast<memref::SubViewOp>(v.second.getDefiningOp());
    auto promotionInfo = promoteSubviewAsNewBuffer(
        b, b.getLoc(), subView, options.useOriginalSubviewSize,
        options.allocationFn, layout);
    if (failed(promotionInfo))
      return failure();
    promotionInfoMap[v.first] = *promotionInfo;

    // Only fill the buffer if the full local view is used.
    if (!options.useFullTileBuffers[v.second])
      continue;
    Type subviewEltType = subView.getType().getElementType();
    Value fillVal =
        llvm::TypeSwitch<Type, Value>(subviewEltType)
            .Case([&](FloatType t) {
              return arith::ConstantOp::create(b, FloatAttr::get(t, 0.0));
            })
            .Case([&](IntegerType t) {
              return arith::ConstantOp::create(b, IntegerAttr::get(t, 0));
            })
            .Case([&](ComplexType t) {
              Value tmp;
              if (auto et = dyn_cast<FloatType>(t.getElementType()))
                tmp = arith::ConstantOp::create(b, FloatAttr::get(et, 0.0));
              else if (auto et = cast<IntegerType>(t.getElementType()))
                tmp = arith::ConstantOp::create(b, IntegerAttr::get(et, 0));
              return complex::CreateOp::create(b, t, tmp, tmp);
            })
            .Default([](auto) { return Value(); });
    if (!fillVal)
      return failure();
    linalg::FillOp::create(b, fillVal, promotionInfo->fullLocalView);
  }

  // Copy data into the promoted buffers, using the copy-in callback.
  for (auto v : options.subViews) {
    auto *info = promotionInfoMap.find(v.first);
    if (info == promotionInfoMap.end())
      continue;
    if (options.operandsNumbersToCopyIn.count(v.first) == 0)
      continue;
    if (failed(options.copyInFn(
            b, cast<memref::SubViewOp>(v.second.getDefiningOp()),
            info->second.partialLocalView)))
      return failure();
  }
  return promotionInfoMap;
}
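// Note: the zero fill above runs only for operands marked useFullTileBuffers,
// where the op will read the whole bounding buffer; elements outside the
// partial view must then hold a defined value.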
static FailureOr<LinalgOp>
promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
                LinalgOpInstancePromotionOptions options, DataLayout &layout) {
  assert(op.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");

  // 1. Promote the specified views and build the copy-ins.
  auto promotedBuffersAndViews = promoteSubViews(b, options, layout);
  if (failed(promotedBuffersAndViews) ||
      promotedBuffersAndViews->size() != options.subViews.size())
    return failure();

  // 2. Rewrite the operand list: promoted operands use their local views,
  // other operands are kept as-is. Remember the output views that need a
  // write-back.
  SmallVector<Value, 8> opViews;
  opViews.reserve(op->getNumOperands());
  SmallVector<std::pair<Value, Value>, 8> writebackViews;
  writebackViews.reserve(promotedBuffersAndViews->size());
  for (OpOperand &opOperand : op->getOpOperands()) {
    int64_t operandNumber = opOperand.getOperandNumber();
    if (options.subViews.count(operandNumber) != 0) {
      if (options.useFullTileBuffers[opOperand.get()])
        opViews.push_back(
            (*promotedBuffersAndViews)[operandNumber].fullLocalView);
      else
        opViews.push_back(
            (*promotedBuffersAndViews)[operandNumber].partialLocalView);
      if (operandNumber >= op.getNumDpsInputs())
        writebackViews.emplace_back(std::make_pair(
            opOperand.get(),
            (*promotedBuffersAndViews)[operandNumber].partialLocalView));
    } else {
      opViews.push_back(opOperand.get());
    }
  }
  op->setOperands(0, opViews.size(), opViews);

  // 3. Emit write-backs for the promoted output views after the op.
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPointAfter(op);
  for (auto viewAndPartialLocalView : writebackViews) {
    if (failed(options.copyOutFn(b, viewAndPartialLocalView.second,
                                 viewAndPartialLocalView.first)))
      return failure();
  }

  // 4. Dealloc all local buffers.
  for (const auto &pi : *promotedBuffersAndViews)
    (void)options.deallocationFn(b, pi.second.fullLocalView);
  return op;
}
LogicalResult
mlir::linalg::promoteSubviewsPrecondition(Operation *op,
                                          LinalgPromotionOptions options) {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // The transformation applies to buffers only.
  if (!linalgOp || !linalgOp.hasPureBufferSemantics())
    return failure();
  // Succeed if at least one requested operand is produced by a subview.
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    bool isSubview =
        isa_and_nonnull<memref::SubViewOp>(opOperand.get().getDefiningOp());
    if (isSubview &&
        (!options.operandsToPromote ||
         options.operandsToPromote->count(opOperand.getOperandNumber())))
      return success();
  }
  return failure();
}

FailureOr<LinalgOp>
mlir::linalg::promoteSubViews(OpBuilder &builder, LinalgOp linalgOp,
                              const LinalgPromotionOptions &options) {
  LinalgOpInstancePromotionOptions linalgOptions(linalgOp, options);
  auto layout = DataLayout::closest(linalgOp);
  ImplicitLocOpBuilder b(linalgOp.getLoc(), builder);
  auto res = ::promoteSubViews(b, linalgOp, linalgOptions, layout);
  if (failed(res))
    return failure();
  return res;
}
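// A minimal driver sketch, not part of the original file, showing how the two
// entry points above are typically combined from a pass or rewrite pattern.
// The helper name and the promoted operand indices {0, 1} are assumptions for
// illustration only.
static FailureOr<LinalgOp> promoteTwoInputsExample(OpBuilder &builder,
                                                   LinalgOp linalgOp) {
  LinalgPromotionOptions options;
  options.setOperandsToPromote({0, 1}); // promote the first two operands.
  if (failed(promoteSubviewsPrecondition(linalgOp, options)))
    return failure();
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPoint(linalgOp);
  return promoteSubViews(builder, linalgOp, options);
}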
// Allocate `subview` in the requested GPU address space, at the start of the
// enclosing function.
static std::optional<Value> allocateSubviewGPUMemoryInAddressSpace(
    OpBuilder &builder, memref::SubViewOp subview, ArrayRef<Value> sizeBounds,
    gpu::AddressSpace addressSpace) {
  OpBuilder::InsertionGuard guard(builder);
  func::FuncOp funcOp = subview->getParentOfType<func::FuncOp>();
  if (!funcOp)
    return std::nullopt;
  // The promoted buffer shape is taken from the constant size bounds.
  SmallVector<int64_t> shape;
  for (Value bound : sizeBounds) {
    APInt value;
    if (!matchPattern(bound, m_ConstantInt(&value)))
      return std::nullopt;
    shape.push_back(value.getSExtValue());
  }
  builder.setInsertionPointToStart(&funcOp.front());
  auto type = MemRefType::get(
      shape, subview.getType().getElementType(), MemRefLayoutAttrInterface{},
      gpu::AddressSpaceAttr::get(builder.getContext(), addressSpace));
  Value buffer;
  if (addressSpace == gpu::GPUDialect::getWorkgroupAddressSpace()) {
    buffer = memref::AllocOp::create(builder, funcOp.getLoc(), type);
  } else if (addressSpace == gpu::GPUDialect::getPrivateAddressSpace()) {
    buffer = memref::AllocaOp::create(builder, funcOp.getLoc(), type);
  } else {
    return std::nullopt;
  }
  return buffer;
}

std::optional<Value> mlir::linalg::allocateWorkgroupMemory(
    OpBuilder &builder, memref::SubViewOp subview, ArrayRef<Value> sizeBounds,
    DataLayout &) {
  return allocateSubviewGPUMemoryInAddressSpace(
      builder, subview, sizeBounds,
      gpu::GPUDialect::getWorkgroupAddressSpace());
}

// gpu.barrier guards before and after the copy keep the promoted buffer
// consistent across the workgroup.
LogicalResult mlir::linalg::copyToWorkgroupMemory(OpBuilder &b, Value src,
                                                  Value dst) {
  gpu::BarrierOp::create(b, src.getLoc());
  Operation *copyOp = memref::CopyOp::create(b, src.getLoc(), src, dst);
  gpu::BarrierOp::create(b, copyOp->getLoc());
  return success();
}

std::optional<Value> mlir::linalg::allocateGPUPrivateMemory(
    OpBuilder &builder, memref::SubViewOp subview, ArrayRef<Value> sizeBounds,
    DataLayout &) {
  return allocateSubviewGPUMemoryInAddressSpace(
      builder, subview, sizeBounds, gpu::GPUDialect::getPrivateAddressSpace());
}

// Plain copy between src and dst for private memory.
LogicalResult mlir::linalg::copyToGPUPrivateMemory(OpBuilder &b, Value src,
                                                   Value dst) {
  memref::CopyOp::create(b, src.getLoc(), src, dst);
  return success();
}
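// A minimal sketch, not part of the original file: wiring the GPU callbacks
// above into the promotion driver. The LinalgPromotionOptions setter names are
// assumptions from the public options API; deallocateWorkgroupMemory is the
// matching no-op dealloc callback declared alongside these helpers.
static LinalgPromotionOptions workgroupPromotionOptionsExample() {
  LinalgPromotionOptions options;
  options
      .setAllocationDeallocationFns(allocateWorkgroupMemory,
                                    deallocateWorkgroupMemory)
      .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
      .setUseFullTileBuffers({true, true});
  return options;
}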