#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

using llvm::MapVector;

#define DEBUG_TYPE "linalg-promotion"
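/// Allocate a new buffer of `allocSize` x `width` bytes, where `width` is the
/// size the data `layout` reports for `elementType`. Uses memref.alloca when
/// options.useAlloca is set and memref.alloc otherwise, honoring the optional
/// alignment.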
static Value allocBuffer(ImplicitLocOpBuilder &b,
                         const LinalgPromotionOptions &options,
                         Type elementType, Value allocSize, DataLayout &layout,
                         std::optional<unsigned> alignment = std::nullopt) {
  llvm::TypeSize width = layout.getTypeSize(elementType);
  assert(!width.isScalable() &&
         "cannot allocate buffer for a scalable vector");

  IntegerAttr alignmentAttr;
  if (alignment.has_value())
    alignmentAttr = b.getI64IntegerAttr(alignment.value());

  Attribute memorySpaceAttr;
  if (options.memorySpace.has_value())
    memorySpaceAttr = *options.memorySpace;

  // Static i8 buffer when the allocation size is a known constant.
  if (std::optional<int64_t> cst = getConstantIntValue(allocSize)) {
    auto staticBufferType = MemRefType::get(width.getFixedValue() * cst.value(),
                                            b.getIntegerType(8));
    // ... (memory-space annotation and static memref.alloc/alloca elided)
  }

  // Otherwise allocate a dynamically sized i8 buffer of width * allocSize.
  auto dynamicBufferType =
      MemRefType::get(ShapedType::kDynamic, b.getIntegerType(8));
  // ... (memory-space annotation elided)
  Value mul = b.createOrFold<arith::MulIOp>(
      b.create<arith::ConstantIndexOp>(width), allocSize);
  if (options.useAlloca)
    return b.create<memref::AllocaOp>(dynamicBufferType, mul, alignmentAttr);
  return b.create<memref::AllocOp>(dynamicBufferType, mul, alignmentAttr);
}
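/// Default allocation callback: allocate a flat memref<? x i8> buffer big
/// enough for the bounding box of the subview, then return a memref.view of
/// the bounding shape into that buffer.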
static std::optional<Value> defaultAllocBufferCallBack(
    const LinalgPromotionOptions &options, OpBuilder &builder,
    memref::SubViewOp subView, ArrayRef<Value> boundingSubViewSize,
    std::optional<unsigned> alignment, DataLayout &layout) {
  ShapedType viewType = subView.getType();
  ImplicitLocOpBuilder b(subView.getLoc(), builder);
  auto zero = b.create<arith::ConstantIndexOp>(0);
  auto one = b.create<arith::ConstantIndexOp>(1);

  Attribute memorySpaceAttr;
  if (options.memorySpace.has_value())
    memorySpaceAttr = *options.memorySpace;

  // The backing buffer is flat: its size is the product of the bounding sizes.
  Value allocSize = one;
  for (const auto &size : llvm::enumerate(boundingSubViewSize))
    allocSize = b.createOrFold<arith::MulIOp>(allocSize, size.value());
  Value buffer = allocBuffer(b, options, viewType.getElementType(), allocSize,
                             layout, alignment);
  SmallVector<int64_t, 4> dynSizes(boundingSubViewSize.size(),
                                   ShapedType::kDynamic);
  auto viewMemRefType = MemRefType::get(dynSizes, viewType.getElementType());
  // ... (memory-space annotation elided)
  return b.createOrFold<memref::ViewOp>(viewMemRefType, buffer, zero,
                                        boundingSubViewSize);
}
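/// Default deallocation callback: free the i8 buffer backing the memref.view
/// produced by the default allocation callback (a no-op when alloca is used).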
static LogicalResult defaultDeallocBufferCallBack(
    const LinalgPromotionOptions &options, OpBuilder &b, Value fullLocalView) {
  if (!options.useAlloca) {
    auto viewOp = cast<memref::ViewOp>(fullLocalView.getDefiningOp());
    b.create<memref::DeallocOp>(viewOp.getSource().getLoc(),
                                viewOp.getSource());
  }
  return success();
}
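/// Per-instance promotion state: resolves the user-facing
/// LinalgPromotionOptions (operand numbers, optional callbacks) against one
/// concrete LinalgOp.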
struct LinalgOpInstancePromotionOptions {
  LinalgOpInstancePromotionOptions(LinalgOp op,
                                   const LinalgPromotionOptions &options);
  /// SubViews to promote, keyed by operand number.
  MapVector<int64_t, Value> subViews;
  /// Operand numbers whose data must be copied into the promoted buffer.
  llvm::SmallSet<int64_t, 4> operandsNumbersToCopyIn;
  // ... (allocation/deallocation/copy callbacks and other fields elided)
  std::optional<unsigned> alignment;
};
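/// The constructor decides which operands get promoted, which of those need a
/// copy-in, and installs the default allocation/deallocation/copy callbacks
/// when the user did not provide any.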
LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
    LinalgOp linalgOp, const LinalgPromotionOptions &options)
    : subViews(), alignment(options.alignment) {
  assert(linalgOp.hasPureBufferSemantics() &&
         "revisit usage of shaped operand");
  auto vUseFullTileBuffers =
      options.useFullTileBuffers.value_or(llvm::SmallBitVector());
  vUseFullTileBuffers.resize(linalgOp->getNumOperands(),
                             options.useFullTileBuffersDefault);

  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    int64_t operandNumber = opOperand.getOperandNumber();
    if (options.operandsToPromote &&
        !options.operandsToPromote->count(operandNumber))
      continue;
    Operation *op = opOperand.get().getDefiningOp();
    if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) {
      subViews[operandNumber] = sv;
      // For linalg.generic, copy in only if the payload actually reads the
      // value of this operand.
      if (!isa<linalg::GenericOp>(linalgOp) ||
          linalgOp.payloadUsesValueFromOperand(&opOperand))
        operandsNumbersToCopyIn.insert(operandNumber);
      useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber];
    }
  }

  if (options.allocationFn) {
    allocationFn = *options.allocationFn;
  } else {
    allocationFn = [&](OpBuilder &b, memref::SubViewOp subViewOp,
                       ArrayRef<Value> boundingSubViewSize,
                       DataLayout &layout) -> std::optional<Value> {
      return defaultAllocBufferCallBack(options, b, subViewOp,
                                        boundingSubViewSize, alignment, layout);
    };
  }

  if (options.deallocationFn)
    deallocationFn = *options.deallocationFn;
  // ... (otherwise defaultDeallocBufferCallBack is installed)

  // Save the location: `linalgOp` may go out of scope before the callbacks run.
  Location loc = linalgOp.getLoc();
  auto defaultCopyCallBack = [loc](OpBuilder &b, Value src,
                                   Value dst) -> LogicalResult {
    b.create<linalg::CopyOp>(loc, src, dst);
    return success();
  };
  copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack);
  copyOutFn = (options.copyOutFn ? *(options.copyOutFn) : defaultCopyCallBack);
}
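/// Allocate a promoted buffer for `subView`. For every non-dropped dimension,
/// try to extract a constant bounding size (falling back to the dynamic size
/// when no constant upper bound can be computed), allocate the full buffer via
/// `allocationFn`, and return both the full local view and the partial view
/// matching the subview's runtime size.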
FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
    OpBuilder &b, Location loc, memref::SubViewOp subView,
    const AllocBufferCallbackFn &allocationFn, DataLayout &layout) {
  auto viewType = subView.getType();
  auto rank = viewType.getRank();
  SmallVector<Value, 4> fullSizes;
  SmallVector<OpFoldResult> partialSizes;
  fullSizes.reserve(rank);
  partialSizes.reserve(rank);
  llvm::SmallBitVector droppedDims = subView.getDroppedDims();
  int64_t resultDimIdx = 0;
  for (const auto &en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
    if (droppedDims[en.index()])
      continue;
    auto rangeValue = en.value();
    // Try to extract a tight constant for the size of this range.
    LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
    Value size;
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
      size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
    } else {
      FailureOr<int64_t> upperBound =
          ValueBoundsConstraintSet::computeConstantBound(
              presburger::BoundType::UB, rangeValue.size,
              /*stopCondition=*/nullptr, /*closedUB=*/true);
      size = failed(upperBound)
                 ? rangeValue.size.get<Value>()
                 : b.create<arith::ConstantIndexOp>(loc, *upperBound);
    }
    LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
    fullSizes.push_back(size);
    partialSizes.push_back(
        b.createOrFold<memref::DimOp>(loc, subView, resultDimIdx++));
  }
  // Allocate the full buffer; bail out if the callback declines.
  std::optional<Value> fullLocalView =
      allocationFn(b, subView, fullSizes, layout);
  if (!fullLocalView)
    return failure();
  SmallVector<OpFoldResult> zeros(fullSizes.size(), b.getIndexAttr(0));
  SmallVector<OpFoldResult> ones(fullSizes.size(), b.getIndexAttr(1));
  auto partialLocalView = b.createOrFold<memref::SubViewOp>(
      loc, *fullLocalView, zeros, partialSizes, ones);
  return PromotionInfo{*fullLocalView, partialLocalView};
}
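/// Promote every requested subview for one op instance: allocate its buffer,
/// optionally fill the full tile buffer with the element type's zero value,
/// and copy the original data into the partial view for operands that require
/// a copy-in.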
static FailureOr<MapVector<int64_t, PromotionInfo>>
promoteSubViews(ImplicitLocOpBuilder &b,
                LinalgOpInstancePromotionOptions options, DataLayout &layout) {
  if (options.subViews.empty())
    return failure();

  MapVector<int64_t, PromotionInfo> promotionInfoMap;

  for (auto v : options.subViews) {
    memref::SubViewOp subView =
        cast<memref::SubViewOp>(v.second.getDefiningOp());
    auto promotionInfo = promoteSubviewAsNewBuffer(
        b, b.getLoc(), subView, options.allocationFn, layout);
    if (failed(promotionInfo))
      return failure();
    promotionInfoMap[v.first] = *promotionInfo;

    // Only fill the buffer if the full local view is used.
    if (!options.useFullTileBuffers[v.second])
      continue;
    Type subviewEltType = subView.getType().getElementType();
    Value fillVal =
        llvm::TypeSwitch<Type, Value>(subviewEltType)
            .Case([&](FloatType t) {
              return b.create<arith::ConstantOp>(FloatAttr::get(t, 0.0));
            })
            .Case([&](IntegerType t) {
              return b.create<arith::ConstantOp>(IntegerAttr::get(t, 0));
            })
            .Case([&](ComplexType t) {
              Value tmp;
              if (auto et = dyn_cast<FloatType>(t.getElementType()))
                tmp = b.create<arith::ConstantOp>(FloatAttr::get(et, 0.0));
              else if (auto et = cast<IntegerType>(t.getElementType()))
                tmp = b.create<arith::ConstantOp>(IntegerAttr::get(et, 0));
              return b.create<complex::CreateOp>(t, tmp, tmp);
            })
            .Default([](auto) { return Value(); });
    if (!fillVal)
      return failure();
    b.create<linalg::FillOp>(fillVal, promotionInfo->fullLocalView);
  }

  // Copy data into the promoted buffers, using the copy-in callback.
  for (auto v : options.subViews) {
    auto *info = promotionInfoMap.find(v.first);
    if (info == promotionInfoMap.end())
      continue;
    if (options.operandsNumbersToCopyIn.count(v.first) == 0)
      continue;
    if (failed(options.copyInFn(
            b, cast<memref::SubViewOp>(v.second.getDefiningOp()),
            info->second.partialLocalView)))
      return failure();
  }
  return promotionInfoMap;
}
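/// Rewrite `op` in place to run on the promoted buffers: substitute the full
/// or partial local views for the promoted operands, copy written outputs back
/// to the original subviews after the op, and deallocate the temporary
/// buffers.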
static FailureOr<LinalgOp>
promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
                LinalgOpInstancePromotionOptions options, DataLayout &layout) {
  assert(op.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");

  // 1. Promote the specified views and use them in the new op.
  auto promotedBuffersAndViews = promoteSubViews(b, options, layout);
  if (failed(promotedBuffersAndViews) ||
      promotedBuffersAndViews->size() != options.subViews.size())
    return failure();

  // 2. Append all other operands as they appear; keep a reference to the
  //    output buffers that must be written back.
  SmallVector<Value, 8> opViews;
  opViews.reserve(op->getNumOperands());
  SmallVector<std::pair<Value, Value>, 8> writebackViews;
  writebackViews.reserve(promotedBuffersAndViews->size());
  for (OpOperand &opOperand : op->getOpOperands()) {
    int64_t operandNumber = opOperand.getOperandNumber();
    if (options.subViews.count(operandNumber) != 0) {
      if (options.useFullTileBuffers[opOperand.get()])
        opViews.push_back(
            (*promotedBuffersAndViews)[operandNumber].fullLocalView);
      else
        opViews.push_back(
            (*promotedBuffersAndViews)[operandNumber].partialLocalView);
      if (operandNumber >= op.getNumDpsInputs())
        writebackViews.emplace_back(std::make_pair(
            opOperand.get(),
            (*promotedBuffersAndViews)[operandNumber].partialLocalView));
    } else {
      opViews.push_back(opOperand.get());
    }
  }
  op->setOperands(0, opViews.size(), opViews);

  // 3. Emit write-back of the promoted output partial views.
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPointAfter(op);
  for (auto viewAndPartialLocalView : writebackViews) {
    if (failed(options.copyOutFn(b, viewAndPartialLocalView.second,
                                 viewAndPartialLocalView.first)))
      return failure();
  }

  // 4. Deallocate all the promoted buffers.
  for (const auto &pi : *promotedBuffersAndViews)
    (void)options.deallocationFn(b, pi.second.fullLocalView);
  return op;
}
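/// Promotion is only attempted when the op has buffer semantics and at least
/// one of the requested operands is produced by a memref.subview.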
LogicalResult
mlir::linalg::promoteSubviewsPrecondition(Operation *op,
                                          LinalgPromotionOptions options) {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp || !linalgOp.hasPureBufferSemantics())
    return failure();
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    auto sv =
        isa_and_nonnull<memref::SubViewOp>(opOperand.get().getDefiningOp());
    if (sv && (!options.operandsToPromote ||
               options.operandsToPromote->count(opOperand.getOperandNumber())))
      return success();
  }
  return failure();
}
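/// Entry point mlir::linalg::promoteSubViews(OpBuilder &, LinalgOp,
/// const LinalgPromotionOptions &): build the per-instance options and the
/// closest data layout, then run the promotion on this op.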
  LinalgOpInstancePromotionOptions linalgOptions(linalgOp, options);
  auto layout = DataLayout::closest(linalgOp);
  ImplicitLocOpBuilder b(linalgOp.getLoc(), builder);
  return ::promoteSubViews(b, linalgOp, linalgOptions, layout);
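/// Shared helper for the GPU promotion callbacks: allocate a buffer with the
/// subview's constant bounding shape in the requested GPU address space
/// (memref.alloc for workgroup memory, memref.alloca for private memory).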
static std::optional<Value> allocateSubviewGPUMemoryInAddressSpace(
    OpBuilder &builder, memref::SubViewOp subview, ArrayRef<Value> sizeBounds,
    gpu::AddressSpace addressSpace) {
  OpBuilder::InsertionGuard guard(builder);
  func::FuncOp funcOp = subview->getParentOfType<func::FuncOp>();
  if (!funcOp)
    return std::nullopt;
  // The size bounds must be constants; they give the allocation's shape.
  SmallVector<int64_t> shape;
  for (Value bound : sizeBounds) {
    APInt value;
    if (!matchPattern(bound, m_ConstantInt(&value)))
      return std::nullopt;
    shape.push_back(value.getSExtValue());
  }
  builder.setInsertionPointToStart(&funcOp.getBody().front());
  auto type = MemRefType::get(
      shape, subview.getType().getElementType(), MemRefLayoutAttrInterface{},
      gpu::AddressSpaceAttr::get(builder.getContext(), addressSpace));
  Value buffer;
  if (addressSpace == gpu::GPUDialect::getWorkgroupAddressSpace()) {
    buffer = builder.create<memref::AllocOp>(funcOp.getLoc(), type);
  } else if (addressSpace == gpu::GPUDialect::getPrivateAddressSpace()) {
    buffer = builder.create<memref::AllocaOp>(funcOp.getLoc(), type);
  } else {
    return std::nullopt;
  }
  return buffer;
}
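/// mlir::linalg::allocateWorkgroupMemory: allocate the promoted subview in GPU
/// workgroup memory by delegating to the shared helper above.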
  return allocateSubviewGPUMemoryInAddressSpace(
      builder, subview, sizeBounds,
      gpu::GPUDialect::getWorkgroupAddressSpace());
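/// mlir::linalg::copyToWorkgroupMemory: emit a memref.copy from `src` to `dst`
/// guarded by gpu.barrier ops before and after, so every thread sees a
/// consistent view of the promoted buffer.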
  b.create<gpu::BarrierOp>(src.getLoc());
  Operation *copyOp = b.create<memref::CopyOp>(src.getLoc(), src, dst);
  b.create<gpu::BarrierOp>(copyOp->getLoc());
  return success();
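/// mlir::linalg::allocateGPUPrivateMemory: allocate the promoted subview in
/// GPU private memory by delegating to the shared helper above.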
  return allocateSubviewGPUMemoryInAddressSpace(
      builder, subview, sizeBounds, gpu::GPUDialect::getPrivateAddressSpace());