MLIR 23.0.0git
XeGPUTransformOps.cpp
Go to the documentation of this file.
1//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "llvm/ADT/SmallVectorExtras.h"
16
17#include <optional>
18
19#include "llvm/Support/DebugLog.h"
20#define DEBUG_TYPE "xegpu-transforms"
21
22using namespace mlir;
23using namespace mlir::transform;
24
25/// Assuming that `ofr` is an index attr or a param of index type
26/// or a transform dialect handle mapped to exactly one op
27/// with one index result, get that value and cast it to int type.
// NOTE(review): generated listing — the declaration line
// `static DiagnosedSilenceableFailure convertMixedValuesToInt(` and the
// `SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs` parameter
// line are missing here (doxygen lines 28/30); the full signature appears in
// the tooltip text at the bottom of this file.
29 transform::TransformState &state, TransformOpInterface transformOp,
// Each entry of `ofrs` is resolved to a plain int32 and appended to `result`.
31 for (OpFoldResult ofr : ofrs) {
32 // Attribute case.
33 if (auto attr = dyn_cast<Attribute>(ofr)) {
34 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
35 result.push_back(intAttr.getInt());
36 continue;
37 }
38 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
39 }
40
41 // Transform param case.
42 Value transformValue = cast<Value>(ofr);
43 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
44 ArrayRef<Attribute> params = state.getParams(transformValue);
45 if (params.size() != 1)
46 return transformOp.emitDefiniteFailure()
47 << "requires exactly one parameter associated";
48 result.push_back(
49 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
50 continue;
51 }
52
53 // Payload value case.
54 auto payloadOps = state.getPayloadOps(transformValue);
55 if (!llvm::hasSingleElement(payloadOps)) {
// NOTE(review): the `DiagnosedSilenceableFailure diag =` declaration line
// (doxygen 56) is missing from this listing.
57 transformOp.emitSilenceableError()
58 << "handle must be mapped to exactly one payload op";
59 diag.attachNote(transformValue.getLoc())
60 << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
61 return diag;
62 }
63
64 Operation *op = *payloadOps.begin();
65 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
// NOTE(review): the `DiagnosedSilenceableFailure diag =` declaration line
// (doxygen 66) is likewise missing from this listing.
67 transformOp.emitSilenceableError()
68 << "payload op must have exactly 1 index result";
69 diag.attachNote(op->getLoc())
70 << "has " << op->getNumResults() << " results";
71 return diag;
72 }
73
// The payload op's single index result must fold to a constant.
74 IntegerAttr intAttr;
75 if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
76 return transformOp.emitSilenceableError()
77 << "requires param or handle to be the result of a constant like "
78 "op";
79
80 result.push_back(intAttr.getInt());
81 }
// NOTE(review): the `return DiagnosedSilenceableFailure::success();` line
// (doxygen 82) is missing from this listing.
83}
84
85/// Find producer operation of type T for the given value.
86/// It's assumed that producer ops are chained through their first operand.
87/// Producer chain is traced trough loop block arguments (init values).
88template <typename T>
89static std::optional<T> findProducerOfType(Value val) {
90 Value currentValue = val;
91 if (!currentValue.getDefiningOp()) {
92 // Value may be a block argument initialized outside a loop.
93 if (val.getNumUses() == 0) {
94 LDBG() << "Failed to find producer op, value has no uses.";
95 return std::nullopt;
96 }
97 auto userOp = val.getUsers().begin();
98 auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
99 if (!parentLoop) {
100 LDBG() << "Failed to find producer op, not in a loop.";
101 return std::nullopt;
102 }
103 int64_t iterArgIdx;
104 if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
105 auto numInductionVars = parentLoop.getLoopInductionVars()->size();
106 iterArgIdx = iterArg.getArgNumber() - numInductionVars;
107 currentValue = parentLoop.getInits()[iterArgIdx];
108 } else {
109 LDBG() << "Failed to find producer op, value not in init values.";
110 return std::nullopt;
111 }
112 }
113 Operation *producerOp = currentValue.getDefiningOp();
114
115 if (auto matchingOp = dyn_cast<T>(producerOp))
116 return matchingOp;
117
118 if (producerOp->getNumOperands() == 0)
119 return std::nullopt;
120
121 return findProducerOfType<T>(producerOp->getOperand(0));
122}
123
124/// Create a layout attribute from the given parameters.
125static xegpu::LayoutAttr
// NOTE(review): generated listing — the parameter line
// `createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,` is missing
// here (doxygen 126); the full signature appears in the tooltip text at the
// bottom of this file.
127 ArrayRef<int32_t> sgData,
128 std::optional<ArrayRef<int32_t>> instData) {
129 return xegpu::LayoutAttr::get(
130 ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
131 DenseI32ArrayAttr::get(ctx, sgData),
// inst_data is optional; pass a null attribute when it is absent.
132 instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
133 /*lane_layout=*/nullptr,
134 /*lane_data=*/nullptr,
135 /*order=*/nullptr);
136}
137
138/// Generate `xegpu::LayoutAttr` from op mixed layout values.
// NOTE(review): generated listing — the return-type/name line
// (`DiagnosedSilenceableFailure getLayoutAttrFromOperands(MLIRContext *ctx,`)
// and the `mixedSgData` parameter line (doxygen 139/140/143) are missing
// here; the full signature appears in the tooltip text at the bottom of this
// file. On success the constructed attribute is written to `layoutAttr`.
141 TransformOpInterface transformOp,
142 ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
144 ArrayRef<::mlir::OpFoldResult> mixedInstData,
145 xegpu::LayoutAttr &layoutAttr) {
// Resolve each mixed (attr/param/handle) list to plain int32 values,
// forwarding the first conversion failure unchanged.
146 SmallVector<int32_t> sgLayout, sgData, instData;
147 auto status =
148 convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
149 if (!status.succeeded())
150 return status;
151
152 status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
153 if (!status.succeeded())
154 return status;
155
156 status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
157 if (!status.succeeded())
158 return status;
// inst_data is optional; an empty list means "not provided".
159 auto maybeInstData = instData.empty()
160 ? std::nullopt
161 : std::optional<ArrayRef<int32_t>>(instData);
162
163 layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData);
164
// NOTE(review): the `return DiagnosedSilenceableFailure::success();` line
// (doxygen 165) is missing from this listing.
166}
167
168/// Replace xegpu.create_nd_desc op with a new one with the given layout.
169static xegpu::CreateNdDescOp
// NOTE(review): generated listing — the first parameter line
// (`setDescLayout(transform::TransformRewriter &rewriter,`) is missing here
// (doxygen 170).
171 xegpu::CreateNdDescOp descOp,
172 xegpu::DistributeLayoutAttr layout) {
// Offsets are not carried over to the rebuilt op, so reject them up front.
173 assert(descOp.getMixedOffsets().size() == 0 &&
174 "create desc op with offsets is not supported");
// Rebuild the tensor-desc type, preserving everything except the layout.
175 auto oldTensorDesc = descOp.getType();
176 auto descType = xegpu::TensorDescType::get(
177 oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
178 /*array_length=*/oldTensorDesc.getArrayLength(),
179 /*boundary_check=*/oldTensorDesc.getBoundaryCheck(),
180 /*memory_space=*/oldTensorDesc.getMemorySpace(),
181 /*layout=*/layout);
182
// Replace the old op in place; all uses are redirected to the new op.
183 rewriter.setInsertionPointAfter(descOp);
184 auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
185 descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
186 descOp.getMixedStrides());
187 return newDescOp;
188}
189
// Resolves the single target value handle to the xegpu.create_nd_desc op
// that (transitively) produces it and maps that op to this op's result.
// NOTE(review): generated listing — the `DiagnosedSilenceableFailure` return
// type line and the `state`/`results` parameter lines (doxygen 190/192/193)
// are missing here.
191transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
194 auto targetValues = state.getPayloadValues(getTarget());
195 if (!llvm::hasSingleElement(targetValues)) {
196 return emitDefiniteFailure()
197 << "requires exactly one target value handle (got "
198 << llvm::range_size(targetValues) << ")";
199 }
200
201 auto maybeDescOp =
202 findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin());
203 if (!maybeDescOp) {
204 return emitSilenceableFailure(getLoc())
205 << "Could not find a matching descriptor op when walking the "
206 "producer chain of the first operand.";
207 }
208
209 results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
// NOTE(review): the `return DiagnosedSilenceableFailure::success();` line
// (doxygen 210) is missing from this listing.
211}
212
// Builder: splits each mixed (static-or-dynamic) layout list into the Value
// and int64_t components expected by the generated builder.
213void transform::SetDescLayoutOp::build(OpBuilder &builder,
// NOTE(review): generated listing — the `OperationState &result, Value
// target,` parameter line (doxygen 214) is missing here.
215 ArrayRef<OpFoldResult> mixedSgLayout,
216 ArrayRef<OpFoldResult> mixedSgData,
217 ArrayRef<OpFoldResult> mixedInstData,
218 ArrayRef<int64_t> sliceDims) {
219 SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
220 SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
221 dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
222 dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
223 dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
224 build(builder, result, target.getType(),
225 /*target=*/target,
226 /*sg_layout=*/dynamicSgLayout,
227 /*sg_data=*/dynamicSgData,
228 /*inst_data=*/dynamicInstData,
229 /*static_sg_layout=*/staticSgLayout,
230 /*static_sg_data=*/staticSgData,
231 /*static_inst_data=*/staticInstData,
232 /*slice_dims=*/sliceDims);
233}
234
// Builds a layout attribute from this op's operands (optionally wrapped in a
// SliceAttr) and installs it on the targeted xegpu.create_nd_desc op by
// replacing that op. The replacement is mapped to the `transformed` handle.
// NOTE(review): generated listing — the `DiagnosedSilenceableFailure` return
// type line and the `state`/`results` parameter lines (doxygen 235/237/238)
// are missing here.
236transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
239 auto targetOps = state.getPayloadOps(getTarget());
240 if (!llvm::hasSingleElement(targetOps)) {
241 return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
242 << llvm::range_size(targetOps) << ")";
243 }
244 Operation *target = *targetOps.begin();
245
// Resolve sg_layout / sg_data / inst_data operands into a LayoutAttr.
246 xegpu::LayoutAttr layoutAttr = nullptr;
247 auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
248 getMixedSgLayout(), getMixedSgData(),
249 getMixedInstData(), layoutAttr);
250 if (!status.succeeded())
251 return status;
252
253 xegpu::DistributeLayoutAttr layout = layoutAttr;
254 auto sliceDims = getSliceDims();
255 if (sliceDims.size() > 0) {
256 // Wrap layoutAttr in a slice attribute.
257 layout = xegpu::SliceAttr::get(
258 getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
259 }
260
261 // For now only create_nd_desc op is supported.
262 auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
263 if (!descOp) {
264 auto diag = emitSilenceableFailure(getLoc())
265 << "Expected a xegpu.create_nd_desc op, but got: "
266 << target->getName();
267 diag.attachNote(target->getLoc()) << "target op";
268 return diag;
269 }
270
271 // Set layout attr in desc op's return type. Replaces old desc op.
272 auto newdescOp = setDescLayout(rewriter, descOp, layout);
273
274 // Map result handles.
275 results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
276
// NOTE(review): the `return DiagnosedSilenceableFailure::success();` line
// (doxygen 277) is missing from this listing.
278}
279
// Effects: the target handle is consumed (its op is replaced), the layout
// operands are only read, a new handle is produced, and payload IR changes.
280void transform::SetDescLayoutOp::getEffects(
// NOTE(review): generated listing — the `SmallVectorImpl<
// MemoryEffects::EffectInstance> &effects) {` line (doxygen 281) is missing.
282 consumesHandle(getTargetMutable(), effects);
283 onlyReadsHandle(getSgLayoutMutable(), effects);
284 onlyReadsHandle(getSgDataMutable(), effects);
285 onlyReadsHandle(getInstDataMutable(), effects);
286 producesHandle(getOperation()->getOpResults(), effects);
287 modifiesPayload(effects);
288}
289
290void transform::SetOpLayoutAttrOp::build(
291 OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
292 ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
293 ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims,
294 bool result, bool operand) {
295 SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
296 SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
297 dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
298 dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
299 dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
300 build(builder, ostate, target.getType(),
301 /*target=*/target,
302 /*index=*/index,
303 /*sg_layout=*/dynamicSgLayout,
304 /*sg_data=*/dynamicSgData,
305 /*inst_data=*/dynamicInstData,
306 /*static_sg_layout=*/staticSgLayout,
307 /*static_sg_data=*/staticSgData,
308 /*static_inst_data=*/staticInstData,
309 /*slice_dims=*/sliceDims,
310 /*result=*/result,
311 /*operand=*/operand);
312}
313
// Builds a layout attribute from this op's operands (optionally wrapped in a
// SliceAttr) and attaches it to the selected anchor of the single target op:
// a result (`result` flag), an operand (`operand` flag), a dpas op's
// layout_a/b/cd property (by index), or the op's anchor layout interface.
// NOTE(review): generated listing — the `DiagnosedSilenceableFailure` return
// type line and the `state`/`results` parameter lines (doxygen 314/316/317)
// are missing here.
315transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
318 auto targetOps = state.getPayloadOps(getTarget());
319 if (!llvm::hasSingleElement(targetOps)) {
320 return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
321 << llvm::range_size(targetOps) << ")";
322 }
323 Operation *target = *targetOps.begin();
324
325 bool resultTarget = getResult();
326 bool operandTarget = getOperand();
327
// NOTE(review): the declaration of `index` (doxygen 328, presumably
// `int64_t index = getIndex();` — TODO confirm against upstream) is missing
// from this listing.
329 if (resultTarget && index >= target->getNumResults()) {
330 return emitSilenceableFailure(getLoc())
331 << "Index exceeds the number of op results";
332 }
333 if (operandTarget && index >= target->getNumOperands()) {
334 return emitSilenceableFailure(getLoc())
335 << "Index exceeds the number of op operands";
336 }
337
338 xegpu::LayoutAttr layoutAttr = nullptr;
339 auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
340 getMixedSgLayout(), getMixedSgData(),
341 getMixedInstData(), layoutAttr);
342 if (!status.succeeded())
343 return status;
344
345 xegpu::DistributeLayoutAttr layout = layoutAttr;
346 auto sliceDims = getSliceDims();
347 if (sliceDims.size() > 0) {
348 // Wrap layoutAttr in a slice attribute.
349 layout = xegpu::SliceAttr::get(
350 getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
351 }
352
353 // Set layout attribute
354 if (resultTarget) {
355 // op result
356 xegpu::setDistributeLayoutAttr(target->getResult(index), layout);
357 } else if (operandTarget) {
358 // op operand
359 xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout);
360 } else if (auto dpasOp = dyn_cast<xegpu::DpasOp>(target)) {
361 // dpas op is a special case where layout needs to be set for A, B, and C
362 if (index == 0)
363 dpasOp.getProperties().layout_a = layout;
364 else if (index == 1)
365 dpasOp.getProperties().layout_b = layout;
366 else if (index == 2)
367 dpasOp.getProperties().layout_cd = layout;
368 else {
369 auto diag = emitSilenceableFailure(getLoc())
370 << "Invalid index for setting dpas op layout: " << index;
371 diag.attachNote(target->getLoc()) << "target op";
372 return diag;
373 }
374 } else {
375 // op's anchor layout.
376 auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(target);
377 if (!anchorOp) {
378 auto diag = emitSilenceableFailure(getLoc())
379 << "Cannot set anchor layout to op: " << target->getName();
380 diag.attachNote(target->getLoc()) << "target op";
381 return diag;
382 }
383 anchorOp.setAnchorLayout(layout);
384 }
// NOTE(review): the `return DiagnosedSilenceableFailure::success();` line
// (doxygen 385) is missing from this listing.
386}
387
// Effects: all handles are only read (the target op is modified in place,
// not replaced) and the payload IR is modified.
388void transform::SetOpLayoutAttrOp::getEffects(
// NOTE(review): generated listing — the `SmallVectorImpl<
// MemoryEffects::EffectInstance> &effects) {` line (doxygen 389) is missing.
390 onlyReadsHandle(getTargetMutable(), effects);
391 onlyReadsHandle(getSgLayoutMutable(), effects);
392 onlyReadsHandle(getSgDataMutable(), effects);
393 onlyReadsHandle(getInstDataMutable(), effects);
394 modifiesPayload(effects);
395}
396
397LogicalResult transform::SetOpLayoutAttrOp::verify() {
398 if (getResult() && getOperand()) {
399 return emitOpError("Cannot set both result and operand simultaneously.");
400 }
401 return success();
402}
403
404void transform::SetGPULaunchThreadsOp::build(
405 OpBuilder &builder, OperationState &ostate, Value target,
406 ArrayRef<OpFoldResult> mixedThreads) {
407 SmallVector<int64_t> staticThreads;
408 SmallVector<Value> dynamicThreads;
409 dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads);
410 build(builder, ostate, target.getType(),
411 /*target=*/target,
412 /*threads=*/dynamicThreads,
413 /*static_threads=*/staticThreads);
414}
415
// Resolves the three thread-size operands to constants and rewrites the
// targeted gpu.launch op's block sizes in place with arith.constant indices.
// NOTE(review): generated listing — the `DiagnosedSilenceableFailure` return
// type line and the `state`/`results` parameter lines (doxygen 416/418/419)
// are missing here.
417transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
420 auto targetOps = state.getPayloadOps(getTarget());
421 if (!llvm::hasSingleElement(targetOps)) {
422 return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
423 << llvm::range_size(targetOps) << ")";
424 }
425 Operation *target = *targetOps.begin();
426
427 auto launchOp = dyn_cast<gpu::LaunchOp>(target);
428 if (!launchOp) {
429 auto diag = emitSilenceableFailure(getLoc())
430 << "Expected a gpu.launch op, but got: " << target->getName();
431 diag.attachNote(target->getLoc()) << "target op";
432 return diag;
433 }
434
435 SmallVector<int32_t> threads;
// NOTE(review): the declaration line for `status` (doxygen 436) is missing
// from this listing.
437 convertMixedValuesToInt(state, (*this), threads, getMixedThreads());
438 if (!status.succeeded())
439 return status;
440
// Exactly three values are expected: block sizes along x, y, and z.
441 if (threads.size() != 3) {
442 return emitSilenceableFailure(getLoc())
443 << "Expected threads argument to consist of three values (got "
444 << threads.size() << ")";
445 }
446
447 rewriter.setInsertionPoint(launchOp);
448 auto createConstValue = [&](int value) {
449 return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value);
450 };
451
452 // Replace threads in-place.
453 launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
454 launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
455 launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
456
// NOTE(review): the `return DiagnosedSilenceableFailure::success();` line
// (doxygen 457) is missing from this listing.
458}
459
// Effects: handles are only read; the gpu.launch op is modified in place.
460void transform::SetGPULaunchThreadsOp::getEffects(
// NOTE(review): generated listing — the `SmallVectorImpl<
// MemoryEffects::EffectInstance> &effects) {` line (doxygen 461) is missing.
462 onlyReadsHandle(getTargetMutable(), effects);
463 onlyReadsHandle(getThreadsMutable(), effects);
464 modifiesPayload(effects);
465}
466
// Inserts cooperative xegpu.prefetch_nd ops for the load feeding the target
// value: the descriptor op is cloned above the enclosing scf.for, an
// initial-prefetch loop (fully unrolled) issues `nbPrefetch` prefetches
// ahead of the main loop, and a steady-state prefetch is emitted at the top
// of the main loop body, `nbPrefetch` iterations ahead of the load.
// NOTE(review): generated listing — the `DiagnosedSilenceableFailure` return
// type line and the `state`/`results` parameter lines (doxygen 467/469/470)
// are missing here.
468transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
471 auto targetValues = state.getPayloadValues(getTarget());
472 if (!llvm::hasSingleElement(targetValues))
473 return emitDefiniteFailure()
474 << "requires exactly one target value handle (got "
475 << llvm::range_size(targetValues) << ")";
476 auto value = *targetValues.begin();
477
478 int64_t nbPrefetch = getStaticNbPrefetch();
479 if (getDynamicNbPrefetch()) {
480 // Get dynamic prefetch count from transform param or handle.
481 SmallVector<int32_t> dynamicNbPrefetch;
482 auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
483 {getDynamicNbPrefetch()});
484 if (!status.succeeded())
485 return status;
486 if (dynamicNbPrefetch.size() != 1)
487 return emitDefiniteFailure()
488 << "requires exactly one value for dynamic_nb_prefetch";
489 nbPrefetch = dynamicNbPrefetch[0];
490 }
491 if (nbPrefetch <= 0)
492 return emitSilenceableFailure(getLoc())
493 << "nb_prefetch must be a positive integer.";
494
495 // Find load operation of the operand.
496 auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
497 if (!maybeLoadOp)
498 return emitSilenceableFailure(getLoc()) << "Could not find load op.";
499 auto loadOp = *maybeLoadOp;
500 if (loadOp.getMixedOffsets().size() == 0) {
501 auto diag = emitSilenceableFailure(getLoc())
502 << "Load op must have offsets.";
503 diag.attachNote(loadOp.getLoc()) << "load op";
504 return diag;
505 }
506
507 // Find the parent scf.for loop.
508 auto forOp = loadOp->getParentOfType<scf::ForOp>();
509 if (!forOp) {
510 auto diag = emitSilenceableFailure(getLoc())
511 << "Load op is not contained in a scf.for loop.";
512 diag.attachNote(loadOp.getLoc()) << "load op";
513 return diag;
514 }
515
516 // Find descriptor op.
517 auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
518 if (!maybeDescOp)
519 return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
520 auto descOp = *maybeDescOp;
521 if (descOp.getMixedOffsets().size() > 0) {
522 auto diag = emitSilenceableFailure(getLoc())
523 << "desc op with offsets is not supported.";
524 diag.attachNote(descOp.getLoc()) << "desc op";
    // FIX: the diagnostic was created but never returned, so the transform
    // silently continued with an unsupported desc op (and dropped an
    // unchecked DiagnosedSilenceableFailure). Return it like every other
    // failure path in this function.
    return diag;
525 }
526
527 // Clone desc op outside the loop.
528 rewriter.setInsertionPoint(forOp);
529 auto newDescOp =
530 cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
531
532 // Clone reduction loop to emit initial prefetches.
533 // Compute upper bound of the init loop: start + nbPrefetch * step.
534 auto nbPrefetchCst =
535 arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
536 auto nbStep = rewriter.createOrFold<arith::MulIOp>(
537 forOp.getLoc(), nbPrefetchCst, forOp.getStep());
538 auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
539 forOp.getLoc(), forOp.getLowerBound(), nbStep);
540 auto initForOp =
541 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
542 initUpBound, forOp.getStep());
543
544 auto ctx = rewriter.getContext();
545 auto readCacheHint =
546 xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
547
548 // Modify loadOp mixedOffsets by replacing the for loop induction variable
549 // with the given value.
550 auto getPrefetchOffsets =
551 [&](Value replacementVal) -> SmallVector<OpFoldResult> {
552 IRMapping mapping;
553 mapping.map(forOp.getInductionVar(), replacementVal);
554 SmallVector<Value> dynamicOffsets =
555 llvm::map_to_vector(loadOp.getOffsets(), [&](Value v) {
556 return mapping.lookupOrDefault(v);
557 });
558 auto constOffsets = loadOp.getConstOffsets().value();
559 return getMixedValues(constOffsets, dynamicOffsets, ctx);
560 };
561
562 // Insert prefetch op in init loop.
563 // Replace induction var with the init loop induction var.
564 rewriter.setInsertionPointToStart(initForOp.getBody());
565 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
566 newDescOp.getResult(),
567 getPrefetchOffsets(initForOp.getInductionVar()),
568 readCacheHint, readCacheHint, readCacheHint,
569 /*layout=*/nullptr);
570
571 // Insert prefetch op in main loop.
572 // Calculate prefetch offset after the init prefetches have been issued.
573 rewriter.setInsertionPointToStart(forOp.getBody());
574 auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
575 forOp.getInductionVar(), nbStep);
576 // Replace induction var with correct offset.
577 xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
578 newDescOp.getResult(),
579 getPrefetchOffsets(prefetchOffset), readCacheHint,
580 readCacheHint, readCacheHint, /*layout=*/nullptr);
581
582 // Unroll the init loop.
583 if (failed(loopUnrollFull(initForOp)))
584 return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";
585
586 results.set(llvm::cast<OpResult>(getResult()), {newDescOp});
587
// NOTE(review): the `return DiagnosedSilenceableFailure::success();` line
// (doxygen 588) is missing from this listing.
589}
590
// Effects: handles are only read, a new handle (the cloned desc op) is
// produced, and the payload IR is modified.
591void transform::InsertPrefetchOp::getEffects(
// NOTE(review): generated listing — the `SmallVectorImpl<
// MemoryEffects::EffectInstance> &effects) {` line (doxygen 592) is missing.
593 onlyReadsHandle(getTargetMutable(), effects);
594 onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
595 producesHandle(getOperation()->getOpResults(), effects);
596 modifiesPayload(effects);
597}
598
599void transform::ConvertLayoutOp::build(
600 OpBuilder &builder, OperationState &ostate, Value target,
601 ArrayRef<OpFoldResult> mixedInputSgLayout,
602 ArrayRef<OpFoldResult> mixedInputSgData,
603 ArrayRef<OpFoldResult> mixedInputInstData,
604 ArrayRef<OpFoldResult> mixedTargetSgLayout,
605 ArrayRef<OpFoldResult> mixedTargetSgData,
606 ArrayRef<OpFoldResult> mixedTargetInstData) {
607 SmallVector<int64_t> staticInputSgLayout, staticInputSgData,
608 staticInputInstData;
609 SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData,
610 dynamicInputInstData;
611 dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout,
612 staticInputSgLayout);
613 dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData,
614 staticInputSgData);
615 dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData,
616 staticInputInstData);
617 SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData,
618 staticTargetInstData;
619 SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData,
620 dynamicTargetInstData;
621 dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout,
622 staticTargetSgLayout);
623 dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData,
624 staticTargetSgData);
625 dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData,
626 staticTargetInstData);
627 build(builder, ostate, target.getType(),
628 /*target=*/target,
629 /*input_sg_layout=*/dynamicInputSgLayout,
630 /*input_sg_data=*/dynamicInputSgData,
631 /*input_inst_data=*/dynamicInputInstData,
632 /*target_sg_layout=*/dynamicTargetSgLayout,
633 /*target_sg_data=*/dynamicTargetSgData,
634 /*target_inst_data=*/dynamicTargetInstData,
635 /*static_input_sg_layout=*/staticInputSgLayout,
636 /*static_input_sg_data=*/staticInputSgData,
637 /*static_input_inst_data=*/staticInputInstData,
638 /*static_target_sg_layout=*/staticTargetSgLayout,
639 /*static_target_sg_data=*/staticTargetSgData,
640 /*static_target_inst_data=*/staticTargetInstData);
641}
642
// Builds input and target layout attributes from this op's operands and
// inserts an xegpu.convert_layout before the target value's first user;
// all other uses of the value are rewired to the converted result.
// NOTE(review): generated listing — the `DiagnosedSilenceableFailure` return
// type line and the `state`/`results` parameter lines (doxygen 643/645/646)
// are missing here.
644transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
647 auto targetValues = state.getPayloadValues(getTarget());
648 if (!llvm::hasSingleElement(targetValues))
649 return emitDefiniteFailure()
650 << "requires exactly one target value handle (got "
651 << llvm::range_size(targetValues) << ")";
652 auto value = *targetValues.begin();
653
654 // Construct layout attributes.
655 xegpu::LayoutAttr inputLayoutAttr = nullptr;
656 auto status = getLayoutAttrFromOperands(
657 getContext(), state, (*this), getMixedInputSgLayout(),
658 getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr);
659 if (!status.succeeded())
660 return status;
661
662 xegpu::LayoutAttr targetLayoutAttr = nullptr;
// NOTE(review): the `status = getLayoutAttrFromOperands(` line (doxygen 663)
// is missing from this listing.
664 getContext(), state, (*this), getMixedTargetSgLayout(),
665 getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr);
666 if (!status.succeeded())
667 return status;
668
669 // Find first user op to define insertion point for layout conversion.
670 if (value.use_empty())
671 return emitSilenceableFailure(getLoc())
672 << "Value has no users to insert layout conversion.";
673 Operation *userOp = *value.getUsers().begin();
674
675 // Emit convert_layout op.
676 rewriter.setInsertionPoint(userOp);
677 auto convLayoutOp =
678 xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(),
679 value, inputLayoutAttr, targetLayoutAttr);
680 // Replace load op result with the converted layout.
681 rewriter.replaceUsesWithIf(
682 value, convLayoutOp.getResult(), [&](OpOperand &use) {
683 return use.getOwner() != convLayoutOp.getOperation();
684 });
685
686 results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
// NOTE(review): the `return DiagnosedSilenceableFailure::success();` line
// (doxygen 687) is missing from this listing.
688}
689
// Effects: all operand handles are only read, a new handle (the emitted
// convert_layout op) is produced, and the payload IR is modified.
690void transform::ConvertLayoutOp::getEffects(
// NOTE(review): generated listing — the `SmallVectorImpl<
// MemoryEffects::EffectInstance> &effects) {` line (doxygen 691) is missing.
692 onlyReadsHandle(getTargetMutable(), effects);
693 onlyReadsHandle(getInputSgLayoutMutable(), effects);
694 onlyReadsHandle(getInputSgDataMutable(), effects);
695 onlyReadsHandle(getInputInstDataMutable(), effects);
696 onlyReadsHandle(getTargetSgLayoutMutable(), effects);
697 onlyReadsHandle(getTargetSgDataMutable(), effects);
698 onlyReadsHandle(getTargetInstDataMutable(), effects);
699 producesHandle(getOperation()->getOpResults(), effects);
700 modifiesPayload(effects);
701}
702
703namespace {
// Transform-dialect extension that registers the XeGPU transform ops and
// declares the dialects those ops may generate.
704class XeGPUTransformDialectExtension
// NOTE(review): generated listing — the base-class line (doxygen 705,
// `: public transform::TransformDialectExtension<`) is missing here.
706 XeGPUTransformDialectExtension> {
707public:
708 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
709
710 using Base::Base;
711
712 void init();
713};
714
715void XeGPUTransformDialectExtension::init() {
// Dialects whose ops may be created by the transforms in this extension.
716 declareGeneratedDialect<scf::SCFDialect>();
717 declareGeneratedDialect<arith::ArithDialect>();
718 declareGeneratedDialect<xegpu::XeGPUDialect>();
719
// Register all tablegen-generated transform ops from the .cpp.inc file.
720 registerTransformOps<
721#define GET_OP_LIST
722#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
723 >();
724}
725} // namespace
726
727#define GET_OP_CLASSES
728#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
729
// Public entry point: installs the XeGPU transform-dialect extension into
// the given dialect registry.
// NOTE(review): generated listing — the signature line (doxygen 730,
// `void mlir::xegpu::registerTransformDialectExtension(DialectRegistry
// &registry) {`) is missing here.
731 registry.addExtensions<XeGPUTransformDialectExtension>();
732}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
b getContext())
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
static std::optional< T > findProducerOfType(Value val)
Find producer operation of type T for the given value.
static DiagnosedSilenceableFailure convertMixedValuesToInt(transform::TransformState &state, TransformOpInterface transformOp, SmallVectorImpl< int32_t > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
static xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter, xegpu::CreateNdDescOp descOp, xegpu::DistributeLayoutAttr layout)
Replace xegpu.create_nd_desc op with a new one with the given layout.
DiagnosedSilenceableFailure getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state, TransformOpInterface transformOp, ArrayRef<::mlir::OpFoldResult > mixedSgLayout, ArrayRef<::mlir::OpFoldResult > mixedSgData, ArrayRef<::mlir::OpFoldResult > mixedInstData, xegpu::LayoutAttr &layoutAttr)
Generate xegpu::LayoutAttr from op mixed layout values.
static xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef< int32_t > sgLayout, ArrayRef< int32_t > sgData, std::optional< ArrayRef< int32_t > > instData)
Create a layout attribute from the given parameters.
MLIRContext * getContext() const
Definition Builders.h:56
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool succeeded() const
Returns true if this is a success.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
bool isIndex() const
Definition Types.cpp:56
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
unsigned getNumUses() const
This method computes the number of uses of this Value.
Definition Value.cpp:52
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void registerTransformDialectExtension(DialectRegistry &registry)
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
Definition Utils.cpp:495
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This represents an operation in an abstracted form, suitable for use with the builder APIs.